/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriter;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.parser.ConstIdentifier;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.Expression;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.dml.DmlSyntacticValidator;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.ProgramConverter;

/**
 * CP instruction for {@code eval}: calls a function whose name is only known at runtime.
 *
 * <p>The first input operand holds the function name (optionally qualified as
 * {@code namespace::name}); all remaining operands are the call arguments. If the function
 * is not yet part of the runtime program but is a known builtin, it is loaded and compiled
 * on demand. A single list argument is expanded into individual arguments unless the
 * target function itself takes exactly one list parameter.
 */
public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
    /** Separator between namespace and function name in a qualified function key. */
    private static final String KEY_DELIM = "::";
    /** Namespace label used in error messages when no namespace was given. */
    private static final String DEFAULT_NS = ".defaultNS";
    /** Namespace under which on-demand compiled builtin functions are registered. */
    private static final String BUILTIN_NS = ".builtinNS";
    /** Prefix of the thread-id suffix used for per-thread function copies. */
    private static final String THREAD_SUFFIX = "_t";

    /** Thread id set via {@link #updateInstructionThreadID(String, String)}; 0 until then. */
    private int _threadID = 0;

    public EvalNaryCPInstruction(Operator op, String opcode, String istr, CPOperand output, CPOperand ... inputs) {
        super(op, opcode, istr, output, inputs);
    }

    /**
     * Resolves the target function, binds the arguments, executes the call through a
     * {@link FunctionCallCPInstruction}, and binds the result to this instruction's output.
     *
     * <p>If the output is a matrix, a scalar or frame result is converted to a matrix;
     * otherwise all function outputs are wrapped into a list named by the output
     * parameter names.
     *
     * @param ec execution context providing the program and the symbol table
     * @throws DMLRuntimeException if the function neither exists in the program nor is a
     *         builtin, if list expansion does not match the function signature, or if a
     *         matrix output receives a value that is not a matrix, frame, or scalar
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        // function name and optional namespace come from the first operand
        String funcName = ec.getScalarInput(this.inputs[0]).getStringValue();
        String nsName = null;
        if (funcName.contains(KEY_DELIM)) {
            String[] parts = DMLProgram.splitFunctionKey(funcName);
            funcName = parts[1];
            nsName = parts[0];
        }

        // all operands after the function name are call arguments
        CPOperand[] boundInputs = Arrays.copyOfRange(this.inputs, 1, this.inputs.length);

        // copy the output matrix object up front, before the function call overwrites the variable
        MatrixObject outputMO = !this.output.isMatrix() ? null : new MatrixObject(ec.getMatrixObject(this.output.getName()));

        // The data type of the first argument selects the internal builtin name. A call
        // without arguments has no first argument to inspect, so fall back to MATRIX (the
        // same default as for list arguments) instead of failing with an
        // ArrayIndexOutOfBoundsException.
        Types.DataType dt1 = (boundInputs.length == 0 || boundInputs[0].getDataType().isList())
            ? Types.DataType.MATRIX : boundInputs[0].getDataType();
        String builtinName = Builtins.getInternalFName(funcName, dt1);

        final Program prog = ec.getProgram();
        if (!prog.containsFunctionProgramBlock(nsName, funcName)) {
            if (!Builtins.contains(funcName, true, false) && !prog.containsFunctionProgramBlock(BUILTIN_NS, builtinName)) {
                String msgNs = nsName == null ? DEFAULT_NS : nsName;
                throw new DMLRuntimeException("Function '" + DMLProgram.constructFunctionKey(msgNs, funcName) + "' (called through eval) is non-existing.");
            }
            nsName = BUILTIN_NS;
            // lock on the program and re-check inside the lock before compiling
            synchronized (prog) {
                if (!prog.containsFunctionProgramBlock(nsName, builtinName)) {
                    EvalNaryCPInstruction.compileFunctionProgramBlock(funcName, dt1, prog);
                }
            }
            funcName = builtinName;
        }

        // when running under a thread id, call a per-thread deep copy of the function
        FunctionProgramBlock fpb = prog.getFunctionProgramBlock(nsName, funcName, false);
        if (ProgramBlock.isThreadID(this._threadID)) {
            String funcNameParfor = funcName + THREAD_SUFFIX + this._threadID;
            if (!prog.containsFunctionProgramBlock(nsName, funcNameParfor, false)) {
                fpb = ProgramConverter.createDeepCopyFunctionProgramBlock(fpb, new HashSet<String>(), new HashSet<String>(), this._threadID);
                prog.addFunctionProgramBlock(nsName, funcNameParfor, fpb, false);
                ec.addTmpParforFunction(DMLProgram.constructFunctionKey(nsName, funcNameParfor));
            }
            fpb = prog.getFunctionProgramBlock(nsName, funcNameParfor, false);
            funcName = funcNameParfor;
        }

        // expand a single list argument into individual arguments, unless the function
        // itself expects exactly one list parameter
        CPOperand[] expandedInputs = null;
        LineageItem[] lineageInputs = null;
        boolean singleListArg = boundInputs.length == 1 && boundInputs[0].getDataType().isList();
        if (singleListArg && (fpb.getInputParams().size() != 1 || !fpb.getInputParams().get(0).getDataType().isList())) {
            ListObject lo = ec.getListObject(boundInputs[0]);
            lo = lo.isNamedList()
                ? EvalNaryCPInstruction.appendNamedDefaults(lo, fpb.getStatementBlock())
                : EvalNaryCPInstruction.appendPositionalDefaults(lo, fpb.getStatementBlock());
            EvalNaryCPInstruction.checkValidArguments(lo.getData(), lo.getNames(), fpb.getInputParamNames());
            if (lo.isNamedList()) {
                lo = EvalNaryCPInstruction.reorderNamedListForFunctionCall(lo, fpb.getInputParamNames());
            }
            // register every list entry under a fresh temporary variable name
            expandedInputs = new CPOperand[lo.getLength()];
            for (int i = 0; i < lo.getLength(); ++i) {
                Data in = lo.getData(i);
                String varName = Dag.getNextUniqueVarname(in.getDataType());
                ec.getVariables().put(varName, in);
                expandedInputs[i] = new CPOperand(varName, in);
            }
            boundInputs = expandedInputs;
            lineageInputs = !DMLScript.LINEAGE ? null : lo.getLineageItems().toArray(new LineageItem[lo.getLength()]);
        }

        // matrix output: bind directly to the output variable; otherwise collect all outputs
        ArrayList<String> boundOutputNames = new ArrayList<String>();
        if (this.output.getDataType().isMatrix()) {
            boundOutputNames.add(this.output.getName());
        } else {
            boundOutputNames.addAll(fpb.getOutputParamNames());
        }

        FunctionCallCPInstruction fcpi = new FunctionCallCPInstruction(nsName, funcName, false, boundInputs, lineageInputs, fpb.getInputParamNames(), boundOutputNames, "eval func");
        fcpi.processInstruction(ec);

        if (this.output.getDataType().isMatrix()) {
            // convert scalar/frame results into the expected matrix output
            Data newOutput = ec.getVariable(this.output);
            if (!(newOutput instanceof MatrixObject)) {
                MatrixBlock mb;
                if (newOutput instanceof ScalarObject) {
                    mb = new MatrixBlock(((ScalarObject)newOutput).getDoubleValue());
                } else if (newOutput instanceof FrameObject) {
                    mb = DataConverter.convertToMatrixBlock((FrameBlock)((FrameObject)newOutput).acquireRead());
                    ec.cleanupCacheableData((FrameObject)newOutput);
                } else {
                    throw new DMLRuntimeException("Invalid eval return type: " + newOutput.getDataType().name() + " (valid: matrix/frame/scalar; where frames or scalars are converted to output matrices)");
                }
                outputMO.acquireModify(mb);
                outputMO.release();
                ec.setVariable(this.output.getName(), outputMO);
            }
        } else {
            // wrap all function outputs into a list named by the output parameter names
            Data[] ldata = boundOutputNames.stream().map(n -> ec.getVariable(n)).toArray(Data[]::new);
            String[] lnames = boundOutputNames.toArray(new String[0]);
            ListObject listOutput = new ListObject(ldata, lnames);
            ec.setVariable(this.output.getName(), listOutput);
        }

        // remove the temporary variables created for list expansion
        if (expandedInputs != null) {
            for (CPOperand op : expandedInputs) {
                VariableCPInstruction.processRmvarInstruction(ec, op.getName());
            }
        }
    }

    /**
     * Stores the thread id parsed from {@code replace}, which is expected to be
     * {@code "_t"} followed by an integer.
     *
     * @param pattern unused
     * @param replace thread suffix of the form {@code _t<id>}
     */
    @Override
    public void updateInstructionThreadID(String pattern, String replace) {
        this._threadID = Integer.parseInt(replace.substring(THREAD_SUFFIX.length()));
    }

    /**
     * Loads, compiles, and registers a builtin function (and every other function the
     * loader returns along with it) in the builtin namespace of {@code prog}.
     *
     * @param name builtin function name as passed to eval
     * @param dt   data type used to derive the internal function name
     * @param prog runtime program that receives the compiled function blocks
     * @throws DMLRuntimeException if loading/parsing yields no function
     */
    private static void compileFunctionProgramBlock(String name, Types.DataType dt, Program prog) {
        String nsName = BUILTIN_NS;
        Map<String, FunctionStatementBlock> fsbs = DmlSyntacticValidator.loadAndParseBuiltinFunction(name, nsName, true);
        if (fsbs.isEmpty()) {
            throw new DMLRuntimeException("Failed to compile function '" + name + "'.");
        }

        // prefer the DML program attached to the runtime program; otherwise use the parsed one
        final DMLProgram dmlp = prog.getDMLProg() != null
            ? prog.getDMLProg() : fsbs.get(Builtins.getInternalFName(name, dt)).getDMLProg();

        // skip functions that the builtin dictionary already contains
        if (dmlp.getBuiltinFunctionDictionary() != null) {
            fsbs = fsbs.entrySet().stream()
                .filter(e -> !dmlp.getBuiltinFunctionDictionary().containsFunction(e.getKey()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        }

        // register the remaining functions in the DML program
        for (Map.Entry<String, FunctionStatementBlock> fsb : fsbs.entrySet()) {
            dmlp.createNamespace(nsName);
            dmlp.addFunctionStatementBlock(nsName, fsb.getKey(), fsb.getValue());
            fsb.getValue().setDMLProg(dmlp);
        }

        DMLTranslator dmlt = new DMLTranslator(dmlp);
        ProgramRewriter rewriter = new ProgramRewriter(true, false);
        ProgramRewriter rewriter2 = new ProgramRewriter(false, true);

        // analyze and validate all functions before constructing any HOPs
        for (FunctionStatementBlock fsb : fsbs.values()) {
            dmlt.liveVariableAnalysisFunction(dmlp, fsb);
            dmlt.validateFunction(dmlp, fsb, true);
        }

        // HOP construction, rewrites (visit status reset after each pass), then LOP construction
        for (FunctionStatementBlock fsb : fsbs.values()) {
            dmlt.constructHops(fsb);
            rewriter.rewriteHopDAGsFunction(fsb, false);
            DMLTranslator.resetHopsDAGVisitStatus(fsb);
            rewriter.rewriteHopDAGsFunction(fsb, true);
            DMLTranslator.resetHopsDAGVisitStatus(fsb);
            rewriter2.rewriteHopDAGsFunction(fsb, true);
            DMLTranslator.resetHopsDAGVisitStatus(fsb);
            HopRewriteUtils.setUnoptimizedFunctionCalls(fsb);
            DMLTranslator.resetHopsDAGVisitStatus(fsb);
            DMLTranslator.refreshMemEstimates(fsb);
            dmlt.constructLops(fsb);
        }

        // add the runtime blocks under both values of the boolean flag, if not yet present
        // NOTE(review): the flag presumably selects optimized vs. unoptimized variants - confirm
        for (Map.Entry<String, FunctionStatementBlock> entry : fsbs.entrySet()) {
            FunctionProgramBlock fpb = (FunctionProgramBlock)dmlt.createRuntimeProgramBlock(prog, entry.getValue(), ConfigurationManager.getDMLConfig());
            if (!prog.containsFunctionProgramBlock(nsName, entry.getKey(), true)) {
                prog.addFunctionProgramBlock(nsName, entry.getKey(), fpb, true);
            }
            if (!prog.containsFunctionProgramBlock(nsName, entry.getKey(), false)) {
                prog.addFunctionProgramBlock(nsName, entry.getKey(), fpb, false);
            }
        }
    }

    /**
     * Returns a copy of a named argument list, extended by constant scalar defaults for
     * all function parameters the list does not contain.
     *
     * @param params named argument list
     * @param sb     statement block of the target function, may be {@code null}
     * @return {@code params} itself if it is unnamed or {@code sb} is {@code null},
     *         otherwise the extended copy
     */
    private static ListObject appendNamedDefaults(ListObject params, StatementBlock sb) {
        if (!params.isNamedList() || sb == null) {
            return params;
        }
        FunctionStatement fstmt = (FunctionStatement)sb.getStatement(0);
        String[] paramNames = fstmt.getInputParamNames();
        ListObject ret = new ListObject(params);
        for (int i = 0; i < fstmt.getInputParams().size(); ++i) {
            String param = paramNames[i];
            // only fill in parameters that are missing and have a scalar default
            if (ret.contains(param) || fstmt.getInputDefaults().get(i) == null || !fstmt.getInputParams().get(i).getDataType().isScalar()) {
                continue;
            }
            Types.ValueType vt = fstmt.getInputParams().get(i).getValueType();
            Expression expr = fstmt.getInputDefaults().get(i);
            // non-constant defaults are silently skipped
            if (!(expr instanceof ConstIdentifier)) {
                continue;
            }
            ScalarObject sobj = ScalarObjectFactory.createScalarObject(vt, expr.toString());
            LineageItem litem = !DMLScript.LINEAGE ? null : LineageItemUtils.createScalarLineageItem(ScalarObjectFactory.createLiteralOp(sobj));
            ret.add(param, sobj, litem);
        }
        return ret;
    }

    /**
     * Returns a copy of an unnamed argument list, extended by constant scalar defaults
     * for all trailing function parameters beyond the list's length.
     *
     * @param params positional argument list
     * @param sb     statement block of the target function, may be {@code null}
     * @return {@code params} itself if {@code sb} is {@code null}, otherwise the extended copy
     * @throws DMLRuntimeException if a missing parameter has no constant scalar default
     */
    private static ListObject appendPositionalDefaults(ListObject params, StatementBlock sb) {
        if (sb == null) {
            return params;
        }
        FunctionStatement fstmt = (FunctionStatement)sb.getStatement(0);
        String[] paramNames = fstmt.getInputParamNames();
        ListObject ret = new ListObject(params);
        for (int i = ret.getLength(); i < fstmt.getInputParams().size(); ++i) {
            String param = paramNames[i];
            Expression expr = fstmt.getInputDefaults().get(i);
            if (expr == null || !fstmt.getInputParams().get(i).getDataType().isScalar() || !(expr instanceof ConstIdentifier)) {
                throw new DMLRuntimeException("Unable to append positional scalar default for '" + param + "'");
            }
            Types.ValueType vt = fstmt.getInputParams().get(i).getValueType();
            ScalarObject sobj = ScalarObjectFactory.createScalarObject(vt, expr.toString());
            LineageItem litem = !DMLScript.LINEAGE ? null : LineageItemUtils.createScalarLineageItem(ScalarObjectFactory.createLiteralOp(sobj));
            ret.add(sobj, litem);
        }
        return ret;
    }

    /**
     * Verifies that an expanded list matches the function signature: same number of
     * entries as parameters and, for named lists, only names that appear in the signature.
     *
     * @param loData    list entries
     * @param loNames   entry names, or {@code null} for an unnamed list
     * @param fArgNames parameter names of the target function
     * @throws DMLRuntimeException on a size mismatch or an unknown argument name
     */
    private static void checkValidArguments(List<Data> loData, List<String> loNames, List<String> fArgNames) {
        int listSize = loNames != null ? loNames.size() : loData.size();
        if (listSize != fArgNames.size()) {
            throw new DMLRuntimeException("Failed to expand list for function call (mismatching number of arguments: " + listSize + " vs. " + fArgNames.size() + ").");
        }
        if (loNames == null) {
            return;
        }
        HashSet<String> probe = new HashSet<String>(fArgNames);
        for (String argName : loNames) {
            if (!probe.contains(argName)) {
                throw new DMLRuntimeException("List argument named '" + argName + "' not in function signature.");
            }
        }
    }

    /**
     * Builds a new named list whose entries follow the order of the function's parameters.
     *
     * @param in        named argument list
     * @param fArgNames parameter names of the target function, in signature order
     * @return list with data (and lineage items, if lineage is enabled) in signature order
     */
    private static ListObject reorderNamedListForFunctionCall(ListObject in, List<String> fArgNames) {
        ArrayList<Data> sortedData = new ArrayList<Data>();
        ArrayList<LineageItem> sortedLI = DMLScript.LINEAGE ? new ArrayList<LineageItem>() : null;
        for (String name : fArgNames) {
            sortedData.add(in.getData(name));
            if (DMLScript.LINEAGE) {
                sortedLI.add(in.getLineageItem(name));
            }
        }
        return new ListObject(sortedData, new ArrayList<String>(fArgNames), sortedLI);
    }
}

