/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.transform.nlpmodel.embadding;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.apache.seatunnel.api.configuration.ReadonlyConfig;
import org.apache.seatunnel.api.table.catalog.CatalogTable;
import org.apache.seatunnel.api.table.catalog.Column;
import org.apache.seatunnel.api.table.catalog.PhysicalColumn;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.VectorType;
import org.apache.seatunnel.transform.common.MultipleFieldOutputTransform;
import org.apache.seatunnel.transform.common.SeaTunnelRowAccessor;
import org.apache.seatunnel.transform.exception.TransformCommonError;
import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
import org.apache.seatunnel.transform.nlpmodel.embadding.EmbeddingTransformConfig;
import org.apache.seatunnel.transform.nlpmodel.embadding.remote.Model;
import org.apache.seatunnel.transform.nlpmodel.embadding.remote.custom.CustomModel;
import org.apache.seatunnel.transform.nlpmodel.embadding.remote.doubao.DoubaoModel;
import org.apache.seatunnel.transform.nlpmodel.embadding.remote.openai.OpenAIModel;
import org.apache.seatunnel.transform.nlpmodel.embadding.remote.qianfan.QianfanModel;

/**
 * Transform that converts selected input fields into embedding vectors.
 *
 * <p>{@link EmbeddingTransformConfig#VECTORIZATION_FIELDS} maps each output field name to the
 * name of the source field whose value is vectorized. One {@link VectorType#VECTOR_FLOAT_TYPE}
 * output column is declared per entry. The embedding {@link Model} is chosen from
 * {@link ModelTransformConfig#MODEL_PROVIDER}. It is created by {@link #open()}, or lazily
 * before the first row is processed if {@code open()} has not been called.
 */
public class EmbeddingTransform extends MultipleFieldOutputTransform {
    private final ReadonlyConfig config;
    /** Output vector column names, in output order. */
    private List<String> fieldNames;
    /** Input-row index of the source field for each entry of {@link #fieldNames}. */
    private List<Integer> fieldOriginalIndexes;
    /** Embedding client; {@code null} until {@link #open()} succeeds and again after {@link #close()}. */
    private Model model;
    /** Vector dimension reported by the model; {@code null} until {@link #open()} succeeds. */
    private Integer dimension;

    /**
     * Resolves the configured source fields against the input table's physical row type.
     *
     * @param config transform configuration
     * @param inputCatalogTable upstream table whose physical columns are looked up
     * @throws NullPointerException if either argument is {@code null}
     */
    public EmbeddingTransform(@NonNull ReadonlyConfig config, @NonNull CatalogTable inputCatalogTable) {
        super(inputCatalogTable);
        if (config == null) {
            throw new NullPointerException("config is marked non-null but is null");
        }
        if (inputCatalogTable == null) {
            throw new NullPointerException("inputCatalogTable is marked non-null but is null");
        }
        this.config = config;
        initOutputFields(
                inputCatalogTable.getTableSchema().toPhysicalRowDataType(),
                config.get(EmbeddingTransformConfig.VECTORIZATION_FIELDS));
    }

    /** Opens the model on demand when {@link #open()} has not (successfully) run yet. */
    private void tryOpen() {
        if (model == null) {
            open();
        }
    }

    /**
     * Creates the embedding model for the configured provider and caches its vector dimension.
     *
     * <p>The model and its dimension are published together, and only after both were obtained.
     * A failure therefore leaves this transform unopened, so {@link #tryOpen()} can retry later,
     * instead of holding a model with an unknown dimension. A model built during a failed
     * attempt is closed before the exception propagates.
     *
     * @throws IllegalArgumentException if the provider is unsupported, or if the {@code CUSTOM}
     *     provider is selected without a custom request config
     * @throws RuntimeException wrapping the {@link IOException} raised while initializing the model
     */
    public void open() {
        ModelProvider provider = config.get(ModelTransformConfig.MODEL_PROVIDER);
        Model candidate = null;
        boolean published = false;
        try {
            candidate = createModel(provider);
            Integer resolvedDimension = candidate.dimension();
            this.model = candidate;
            this.dimension = resolvedDimension;
            published = true;
        } catch (IOException e) {
            throw new RuntimeException("Failed to initialize model", e);
        } finally {
            // Release a model that was built but never published (any failure after construction).
            if (!published && candidate != null) {
                closeQuietly(candidate);
            }
        }
    }

    /**
     * Builds the {@link Model} implementation for {@code provider} from this transform's config.
     *
     * @param provider the configured model provider
     * @return a new, not yet published model instance
     * @throws IllegalArgumentException for an unsupported provider, or a missing custom config
     * @throws IOException if constructing the model reports an I/O failure
     */
    private Model createModel(ModelProvider provider) throws IOException {
        String modelName = config.get(ModelTransformConfig.MODEL);
        String apiPath = config.get(ModelTransformConfig.API_PATH);
        Integer singleVectorizedInputNumber =
                config.get(EmbeddingTransformConfig.SINGLE_VECTORIZED_INPUT_NUMBER);
        switch (provider) {
            case CUSTOM: {
                ReadonlyConfig customConfig =
                        config.getOptional(ModelTransformConfig.CustomRequestConfig.CUSTOM_CONFIG)
                                .map(ReadonlyConfig::fromMap)
                                .orElseThrow(
                                        () -> new IllegalArgumentException("Custom config can't be null"));
                return new CustomModel(
                        modelName,
                        provider.usedEmbeddingPath(apiPath),
                        customConfig.get(ModelTransformConfig.CustomRequestConfig.CUSTOM_REQUEST_HEADERS),
                        customConfig.get(ModelTransformConfig.CustomRequestConfig.CUSTOM_REQUEST_BODY),
                        customConfig.get(ModelTransformConfig.CustomRequestConfig.CUSTOM_RESPONSE_PARSE),
                        singleVectorizedInputNumber);
            }
            case OPENAI:
                return new OpenAIModel(
                        config.get(ModelTransformConfig.API_KEY),
                        modelName,
                        provider.usedEmbeddingPath(apiPath),
                        singleVectorizedInputNumber);
            case DOUBAO:
                return new DoubaoModel(
                        config.get(ModelTransformConfig.API_KEY),
                        modelName,
                        provider.usedEmbeddingPath(apiPath),
                        singleVectorizedInputNumber);
            case QIANFAN:
                return new QianfanModel(
                        config.get(ModelTransformConfig.API_KEY),
                        config.get(ModelTransformConfig.SECRET_KEY),
                        modelName,
                        provider.usedEmbeddingPath(apiPath),
                        config.get(ModelTransformConfig.OAUTH_PATH),
                        singleVectorizedInputNumber);
            default:
                throw new IllegalArgumentException("Unsupported model provider: " + provider);
        }
    }

    /** Closes {@code candidate}, ignoring failures so the error already propagating from open() wins. */
    private static void closeQuietly(Model candidate) {
        try {
            candidate.close();
        } catch (Exception ignored) {
            // Best-effort cleanup of a model that was never published; nothing else to report here.
        }
    }

    /**
     * Resolves each configured source field name to its index in {@code inputRowType}.
     *
     * @param inputRowType physical row type of the upstream table
     * @param fields mapping of output field name to source field name
     * @throws RuntimeException produced by {@link TransformCommonError#cannotFindInputFieldError}
     *     when a source field does not exist in {@code inputRowType}
     */
    private void initOutputFields(SeaTunnelRowType inputRowType, Map<String, String> fields) {
        List<String> names = new ArrayList<>(fields.size());
        List<Integer> indexes = new ArrayList<>(fields.size());
        for (Map.Entry<String, String> field : fields.entrySet()) {
            String srcField = field.getValue();
            int srcFieldIndex;
            try {
                srcFieldIndex = inputRowType.indexOf(srcField);
            } catch (IllegalArgumentException e) {
                throw TransformCommonError.cannotFindInputFieldError(getPluginName(), srcField);
            }
            names.add(field.getKey());
            indexes.add(srcFieldIndex);
        }
        this.fieldNames = names;
        this.fieldOriginalIndexes = indexes;
    }

    /**
     * Vectorizes the configured source fields of {@code inputRow} in a single model call.
     *
     * @param inputRow the row being transformed
     * @return one vector ({@link ByteBuffer}) per output column, in {@link #fieldNames} order
     * @throws RuntimeException if vectorization fails
     * @throws IllegalStateException if the model does not return exactly one vector per output column
     */
    @Override
    protected Object[] getOutputFieldValues(SeaTunnelRowAccessor inputRow) {
        tryOpen();
        List<ByteBuffer> vectors;
        try {
            Object[] fieldArray = new Object[fieldOriginalIndexes.size()];
            for (int i = 0; i < fieldArray.length; i++) {
                fieldArray[i] = inputRow.getField(fieldOriginalIndexes.get(i));
            }
            vectors = model.vectorization(fieldArray);
        } catch (Exception e) {
            throw new RuntimeException("Failed to data vectorization", e);
        }
        // The output row must line up with getOutputColumns(); reject a wrong-sized result early.
        if (vectors == null || vectors.size() != fieldNames.size()) {
            throw new IllegalStateException(
                    "Embedding model returned "
                            + (vectors == null ? "no" : String.valueOf(vectors.size()))
                            + " vectors for "
                            + fieldNames.size()
                            + " output fields");
        }
        return vectors.toArray();
    }

    /**
     * Declares one nullable float-vector column per configured output field.
     *
     * <p>{@link #dimension} is only assigned by {@link #open()}. If this is called before the
     * model is opened, the column is declared with a {@code null} dimension.
     */
    @Override
    protected Column[] getOutputColumns() {
        Column[] columns = new Column[fieldNames.size()];
        for (int i = 0; i < columns.length; i++) {
            columns[i] =
                    PhysicalColumn.of(
                            fieldNames.get(i),
                            VectorType.VECTOR_FLOAT_TYPE,
                            null,
                            dimension,
                            true,
                            "",
                            "");
        }
        return columns;
    }

    public String getPluginName() {
        return "Embedding";
    }

    /** Closes the model if one is open. Safe to call more than once. */
    public void close() {
        Model current = model;
        // Clear first so a repeated close() never closes the same model twice.
        model = null;
        if (current != null) {
            current.close();
        }
    }
}

