/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.transform.nlpmodel.llm;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.apache.seatunnel.api.configuration.ReadonlyConfig;
import org.apache.seatunnel.api.table.catalog.CatalogTable;
import org.apache.seatunnel.api.table.catalog.Column;
import org.apache.seatunnel.api.table.catalog.PhysicalColumn;
import org.apache.seatunnel.api.table.catalog.SeaTunnelDataTypeConvertorUtil;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.transform.common.SeaTunnelRowAccessor;
import org.apache.seatunnel.transform.common.SingleFieldOutputTransform;
import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
import org.apache.seatunnel.transform.nlpmodel.llm.LLMTransformConfig;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.Model;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;

/**
 * Transform that appends a single output column whose value is produced by running an LLM
 * inference over each input row.
 *
 * <p>The concrete remote {@link Model} is selected from {@code ModelTransformConfig.MODEL_PROVIDER}.
 * It is built by {@link #open()}, or lazily on the first row if {@code open()} was not called.
 * The first string returned by the model is converted to the SQL type configured in
 * {@code LLMTransformConfig.OUTPUT_DATA_TYPE}.
 */
public class LLMTransform extends SingleFieldOutputTransform {
    /** Configuration this transform was created with. */
    private final ReadonlyConfig config;
    /** Type of the appended output column, parsed from {@code OUTPUT_DATA_TYPE}. */
    private final SeaTunnelDataType<?> outputDataType;
    /** Remote model client; {@code null} before {@link #open()} and after {@link #close()}. */
    private Model model;

    /**
     * Creates the transform.
     *
     * @param config transform configuration
     * @param inputCatalogTable schema of the upstream table
     * @throws NullPointerException if either argument is {@code null}
     */
    public LLMTransform(@NonNull ReadonlyConfig config, @NonNull CatalogTable inputCatalogTable) {
        super(inputCatalogTable);
        if (config == null) {
            throw new NullPointerException("config is marked non-null but is null");
        }
        if (inputCatalogTable == null) {
            throw new NullPointerException("inputCatalogTable is marked non-null but is null");
        }
        this.config = config;
        this.outputDataType =
                SeaTunnelDataTypeConvertorUtil.deserializeSeaTunnelDataType(
                        "output", config.get(LLMTransformConfig.OUTPUT_DATA_TYPE).toString());
    }

    /** Builds the model on first use if {@link #open()} has not been called yet. */
    private void tryOpen() {
        if (this.model == null) {
            this.open();
        }
    }

    /** @return the plugin identifier, {@code "LLM"} */
    public String getPluginName() {
        return "LLM";
    }

    /**
     * Instantiates the remote model client for the configured provider.
     *
     * @throws IllegalArgumentException if the provider is unsupported, or if the provider is
     *     {@code CUSTOM} and no custom config is present
     */
    public void open() {
        ModelProvider provider =
                (ModelProvider) ((Object) this.config.get(ModelTransformConfig.MODEL_PROVIDER));
        // Constructor arguments shared by every provider-specific model implementation.
        SqlType outputSqlType = this.outputDataType.getSqlType();
        List inferenceColumns = (List) this.config.get(LLMTransformConfig.INFERENCE_COLUMNS);
        String prompt = (String) this.config.get(LLMTransformConfig.PROMPT);
        String modelName = (String) this.config.get(LLMTransformConfig.MODEL);
        switch (provider) {
            case CUSTOM: {
                ReadonlyConfig customConfig =
                        this.config
                                .getOptional(ModelTransformConfig.CustomRequestConfig.CUSTOM_CONFIG)
                                .map(ReadonlyConfig::fromMap)
                                .orElseThrow(
                                        () -> new IllegalArgumentException("Custom config can't be null"));
                this.model =
                        new CustomModel(
                                this.inputCatalogTable.getSeaTunnelRowType(),
                                outputSqlType,
                                inferenceColumns,
                                prompt,
                                modelName,
                                provider.usedLLMPath((String) this.config.get(LLMTransformConfig.API_PATH)),
                                (Map) customConfig.get(ModelTransformConfig.CustomRequestConfig.CUSTOM_REQUEST_HEADERS),
                                (Map) customConfig.get(ModelTransformConfig.CustomRequestConfig.CUSTOM_REQUEST_BODY),
                                (String) customConfig.get(ModelTransformConfig.CustomRequestConfig.CUSTOM_RESPONSE_PARSE));
                break;
            }
            case MICROSOFT: {
                this.model =
                        new MicrosoftModel(
                                this.inputCatalogTable.getSeaTunnelRowType(),
                                outputSqlType,
                                inferenceColumns,
                                prompt,
                                modelName,
                                (String) this.config.get(LLMTransformConfig.API_KEY),
                                provider.usedLLMPath((String) this.config.get(LLMTransformConfig.API_PATH)));
                break;
            }
            case OPENAI:
            case DOUBAO: {
                this.model =
                        new OpenAIModel(
                                this.inputCatalogTable.getSeaTunnelRowType(),
                                outputSqlType,
                                inferenceColumns,
                                prompt,
                                modelName,
                                (String) this.config.get(LLMTransformConfig.API_KEY),
                                provider.usedLLMPath((String) this.config.get(LLMTransformConfig.API_PATH)));
                // Previously missing: without it OPENAI/DOUBAO fell through and the
                // OpenAIModel was overwritten by a KimiAIModel.
                break;
            }
            case KIMIAI: {
                this.model =
                        new KimiAIModel(
                                this.inputCatalogTable.getSeaTunnelRowType(),
                                outputSqlType,
                                inferenceColumns,
                                prompt,
                                modelName,
                                (String) this.config.get(LLMTransformConfig.API_KEY),
                                provider.usedLLMPath((String) this.config.get(LLMTransformConfig.API_PATH)));
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported model provider: " + provider);
            }
        }
    }

    /**
     * Runs the model on one row and converts its first result to the configured output type.
     *
     * @param inputRow row being transformed
     * @return the converted inference result
     * @throws RuntimeException wrapping any inference, empty-result, parse, or unsupported-type
     *     failure, with the offending row in the message
     */
    @Override
    protected Object getOutputFieldValue(SeaTunnelRowAccessor inputRow) {
        this.tryOpen();
        SeaTunnelRow seaTunnelRow = new SeaTunnelRow(inputRow.getFields());
        try {
            List<String> values = this.model.inference(Collections.singletonList(seaTunnelRow));
            // Guard explicitly so an empty response yields a clear cause instead of an IOOBE.
            if (values == null || values.isEmpty()) {
                throw new IllegalStateException("Model returned no inference result");
            }
            String value = values.get(0);
            switch (this.outputDataType.getSqlType()) {
                case STRING:
                    return String.valueOf(value);
                case INT:
                    return Integer.parseInt(value);
                case BIGINT:
                    return Long.parseLong(value);
                case DOUBLE:
                    return Double.parseDouble(value);
                case BOOLEAN:
                    return Boolean.parseBoolean(value);
                default:
                    throw new IllegalArgumentException("Unsupported output data type: " + this.outputDataType);
            }
        } catch (Exception e) {
            throw new RuntimeException(String.format("Failed to inference model with row %s", seaTunnelRow), e);
        }
    }

    /**
     * Describes the appended column: nullable, typed by {@code OUTPUT_DATA_TYPE}.
     *
     * @return the output column definition
     * @throws IllegalArgumentException if {@code OUTPUT_COLUMN_NAME} collides with an input field
     */
    @Override
    protected Column getOutputColumn() {
        String customFieldName = (String) this.config.get(LLMTransformConfig.OUTPUT_COLUMN_NAME);
        String[] fieldNames = this.inputCatalogTable.getTableSchema().getFieldNames();
        if (Arrays.asList(fieldNames).contains(customFieldName)) {
            throw new IllegalArgumentException(
                    String.format("llm inference field name %s already exists", customFieldName));
        }
        return PhysicalColumn.of(customFieldName, this.outputDataType, (Long) null, true, null, "Output column of LLM");
    }

    /** Closes the model client if one was opened, and forgets it so a later use re-opens it. */
    public void close() {
        if (this.model != null) {
            try {
                this.model.close();
            } finally {
                this.model = null;
            }
        }
    }
}

