/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Stack;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionParser;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.spark.RandSPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageParser;

/**
 * Builds and executes backward (gradient) computations from a lineage trace.
 *
 * <p>The trace is parsed into a {@link LineageItem} DAG and walked bottom-up. Each supported
 * operation is translated into the HOP DAG that computes its gradient with respect to the seed
 * matrix {@code mo}. Supported operations are matrix multiplication (producing "dX" and "dW") and
 * element-wise "+" (producing "dB"). Each gradient DAG is then recompiled and executed, and the
 * results are returned as a named list.
 */
public class AutoDiff {
    /** Name prefix of creation items that refer to an already-built operand by id. */
    private static final String ADVARPREFIX = "adVar";
    private static final boolean DEBUG = false;

    /**
     * Computes the gradients recorded in the given lineage and returns them as a named list.
     *
     * @param mo      seed matrix bound to the variable "X" and used as the upstream gradient
     * @param lineage list whose first element's {@code toString()} is the lineage trace
     * @param adec    execution context used to bind inputs and run the gradient instructions
     * @return a list of gradient values, named e.g. "dX", "dW", "dB" in computation order
     * @throws DMLRuntimeException if {@code lineage} is null or empty, or translation/execution fails
     */
    public static ListObject getBackward(MatrixObject mo, ArrayList<Data> lineage, ExecutionContext adec) {
        if (lineage == null || lineage.isEmpty()) {
            throw new DMLRuntimeException("AutoDiff requires a non-empty lineage input.");
        }
        ArrayList<String> names = new ArrayList<>();
        // NOTE(review): strips every occurrence of the literal "foo" from the trace; presumably a
        // marker flag added upstream -- confirm, since it would also alter any token containing "foo".
        String lin = lineage.get(0).toString().replace("foo", "");
        List<Data> data = parseNComputeAutoDiffFromLineage(mo, lin, names, adec);
        return new ListObject(data, names);
    }

    /**
     * Parses a lineage trace, builds one gradient HOP DAG per supported operation, and executes each.
     *
     * @param mo        seed matrix; bound to "X" in {@code ec} and read back via a transient read
     * @param mainTrace serialized lineage trace
     * @param names     output list that receives one name per computed gradient
     * @param ec        execution context used for variable binding and execution
     * @return the computed gradient values, aligned index-by-index with {@code names}
     */
    public static List<Data> parseNComputeAutoDiffFromLineage(MatrixObject mo, String mainTrace, ArrayList<String> names, ExecutionContext ec) {
        LineageItem root = LineageParser.parseLineageTrace(mainTrace);
        root.resetVisitStatusNR();
        Map<Long, Hop> operands = new HashMap<>();
        // Bind the seed so the transient read of "X" resolves when the DAGs are executed.
        ec.setVariable("X", mo);
        DataOp input = HopRewriteUtils.createTransientRead("X", mo);
        ArrayList<Hop> allHops = constructHopsNR(root, operands, input, names);
        List<Data> results = new ArrayList<>(allHops.size());
        for (int i = 0; i < allHops.size(); ++i) {
            String varName = "advar" + i;
            // Write each gradient to its own variable, compile, run, and read the value back.
            DataOp dop = HopRewriteUtils.createTransientWrite(varName, allHops.get(i));
            ArrayList<Instruction> dInst = Recompiler.recompileHopsDag(dop, ec.getVariables(), null, true, true, 0L);
            executeInst(dInst, ec);
            results.add(ec.getVariable(varName));
        }
        return results;
    }

    /**
     * Walks the lineage DAG in post-order without recursion and translates every item.
     *
     * <p>An explicit stack of (item, next-input-position) pairs is used. An item is translated only
     * after all of its inputs have been translated. Items already marked visited are skipped, so
     * shared sub-DAGs are processed once.
     *
     * @param item     root lineage item (its visit status should be reset beforehand)
     * @param operands map from lineage item id to the HOP built for it (filled in place)
     * @param mo       HOP that reads the seed/upstream gradient
     * @param names    receives one name per gradient HOP appended to the result
     * @return the gradient HOPs, in the order they were created
     */
    public static ArrayList<Hop> constructHopsNR(LineageItem item, Map<Long, Hop> operands, Hop mo, ArrayList<String> names) {
        ArrayList<Hop> allHops = new ArrayList<>();
        Stack<LineageItem> stackItem = new Stack<>();
        Stack<MutableInt> stackPos = new Stack<>();
        stackItem.push(item);
        stackPos.push(new MutableInt(0));
        while (!stackItem.empty()) {
            LineageItem tmpItem = stackItem.peek();
            MutableInt tmpPos = stackPos.peek();
            if (tmpItem.isVisited()) {
                // Already translated through another path of the DAG.
                stackItem.pop();
                stackPos.pop();
                continue;
            }
            LineageItem[] inputs = tmpItem.getInputs();
            if (inputs == null || inputs.length <= tmpPos.intValue()) {
                // All inputs are done (or there are none): translate this item now.
                constructSingleHop(tmpItem, operands, mo, allHops, names);
                stackItem.pop();
                stackPos.pop();
                tmpItem.setVisited();
                continue;
            }
            // Descend into the next unprocessed input.
            stackItem.push(inputs[tmpPos.intValue()]);
            tmpPos.increment();
            stackPos.push(new MutableInt(0));
        }
        return allHops;
    }

    /**
     * Translates a single lineage item, dispatching on its type.
     *
     * @throws DMLRuntimeException for lineage types other than Creation, Instruction and Literal
     */
    private static void constructSingleHop(LineageItem item, Map<Long, Hop> operands, Hop mo, ArrayList<Hop> allHops, ArrayList<String> names) {
        switch (item.getType()) {
            case Creation:
                constructCreationHop(item, operands);
                break;
            case Instruction:
                constructInstructionHop(item, operands, mo, allHops, names);
                break;
            case Literal: {
                CPOperand op = new CPOperand(item.getData());
                operands.put(item.getId(), ScalarObjectFactory.createLiteralOp(op.getValueType(), op.getName()));
                break;
            }
            default:
                throw new DMLRuntimeException("Lineage type " + item.getType() + " is not supported");
        }
    }

    /**
     * Translates a Creation item into a data-producing HOP, or re-keys a placeholder operand.
     *
     * <p>Creation instructions not listed below are ignored and produce no operand.
     */
    private static void constructCreationHop(LineageItem item, Map<Long, Hop> operands) {
        String data = item.getData();
        if (data.startsWith(ADVARPREFIX)) {
            // Placeholder "adVar<id>": move the operand registered under <id> to this item's id.
            // The prefix length (not a hard-coded offset) must be skipped to reach the numeric id.
            long phId = Long.parseLong(data.substring(ADVARPREFIX.length()));
            Hop input = operands.remove(phId);
            operands.put(item.getId(), input);
            return;
        }
        Instruction inst = InstructionParser.parseSingleInstruction(data);
        if (inst instanceof DataGenCPInstruction) {
            DataGenCPInstruction rand = (DataGenCPInstruction) inst;
            HashMap<String, Hop> params = new HashMap<>();
            // Only "rand" gets parameters here; other data-gen opcodes are built with an empty map.
            if (rand.getOpcode().equals("rand")) {
                if (rand.output.getDataType() == Types.DataType.TENSOR) {
                    params.put("dims", new LiteralOp(rand.getDims()));
                } else {
                    params.put("rows", new LiteralOp(rand.getRows()));
                    params.put("cols", new LiteralOp(rand.getCols()));
                }
                params.put("min", new LiteralOp(rand.getMinValue()));
                params.put("max", new LiteralOp(rand.getMaxValue()));
                params.put("pdf", new LiteralOp(rand.getPdf()));
                params.put("lambda", new LiteralOp(rand.getPdfParams()));
                params.put("sparsity", new LiteralOp(rand.getSparsity()));
                params.put("seed", new LiteralOp(rand.getSeed()));
            }
            // Locale.ROOT keeps enum lookup stable under locales with special casing (e.g. Turkish).
            Types.OpOpDG dgOp = Types.OpOpDG.valueOf(rand.getOpcode().toUpperCase(Locale.ROOT));
            DataGenOp datagen = new DataGenOp(dgOp, new DataIdentifier("tmp"), params);
            datagen.setBlocksize(rand.getBlocksize());
            operands.put(item.getId(), datagen);
        } else if (inst instanceof VariableCPInstruction && ((VariableCPInstruction) inst).isCreateVariable()) {
            // Rebuild a persistent read from the createvar instruction's fields.
            String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst.toString());
            Types.DataType dt = Types.DataType.valueOf(parts[4]);
            Types.ValueType vt = dt == Types.DataType.MATRIX ? Types.ValueType.FP64 : Types.ValueType.STRING;
            HashMap<String, Hop> params = new HashMap<>();
            params.put("iofilename", new LiteralOp(parts[2]));
            params.put("rows", new LiteralOp(Long.parseLong(parts[6])));
            params.put("cols", new LiteralOp(Long.parseLong(parts[7])));
            params.put("nnz", new LiteralOp(Long.parseLong(parts[8])));
            params.put("format", new LiteralOp(parts[5]));
            // NOTE(review): drops a fixed 5-character prefix from the variable name -- confirm format.
            DataOp pread = new DataOp(parts[1].substring(5), dt, vt, Types.OpOpData.PERSISTENTREAD, params);
            pread.setFileName(parts[2]);
            operands.put(item.getId(), pread);
        } else if (inst instanceof RandSPInstruction) {
            RandSPInstruction rand = (RandSPInstruction) inst;
            HashMap<String, Hop> params = new HashMap<>();
            if (rand.output.getDataType() == Types.DataType.TENSOR) {
                params.put("dims", new LiteralOp(rand.getDims()));
            } else {
                params.put("rows", new LiteralOp(rand.getRows()));
                params.put("cols", new LiteralOp(rand.getCols()));
            }
            params.put("min", new LiteralOp(rand.getMinValue()));
            params.put("max", new LiteralOp(rand.getMaxValue()));
            params.put("pdf", new LiteralOp(rand.getPdf()));
            params.put("lambda", new LiteralOp(rand.getPdfParams()));
            params.put("sparsity", new LiteralOp(rand.getSparsity()));
            params.put("seed", new LiteralOp(rand.getSeed()));
            DataGenOp datagen = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("tmp"), params);
            datagen.setBlocksize(rand.getBlocksize());
            operands.put(item.getId(), datagen);
        }
    }

    /**
     * Translates an Instruction item into its gradient HOP(s), using {@code mo} as upstream gradient.
     *
     * <p>Opcodes without a CP type are ignored.
     *
     * @throws DMLRuntimeException for unsupported instruction types or binary opcodes
     */
    private static void constructInstructionHop(LineageItem item, Map<Long, Hop> operands, Hop mo, ArrayList<Hop> allHops, ArrayList<String> names) {
        CPInstruction.CPType ctype = InstructionUtils.getCPTypeByOpcode(item.getOpcode());
        if (ctype == null) {
            return;
        }
        switch (ctype) {
            case AggregateBinary: {
                // For Y = X %*% W: dX = dY %*% t(W) and dW = t(X) %*% dY.
                Hop input1 = operands.get(item.getInputs()[0].getId());
                Hop input2 = operands.get(item.getInputs()[1].getId());
                ReorgOp transX = HopRewriteUtils.createTranspose(input1);
                ReorgOp transW = HopRewriteUtils.createTranspose(input2);
                AggBinaryOp dX = HopRewriteUtils.createMatrixMultiply(mo, transW);
                AggBinaryOp dW = HopRewriteUtils.createMatrixMultiply(transX, mo);
                operands.put(item.getId(), dX);
                // NOTE(review): id + 1 may coincide with another lineage item's id and overwrite
                // its operand -- confirm this key is safe for the traces being processed.
                operands.put(item.getId() + 1L, dW);
                allHops.add(dX);
                allHops.add(dW);
                names.add("dX");
                names.add("dW");
                return;
            }
            case Binary: {
                String opcode = item.getOpcode();
                if (!opcode.equals("+")) {
                    // Previously a null HOP was recorded here and failed later with an NPE.
                    throw new DMLRuntimeException("Unsupported autoDiff instruction type: " + ctype.name() + " (" + opcode + ").");
                }
                // For Y = X + b: dB is the column-wise sum of the upstream gradient.
                AggUnaryOp dB = HopRewriteUtils.createAggUnaryOp(mo, Types.AggOp.SUM, Types.Direction.Col);
                operands.put(item.getId(), dB);
                allHops.add(dB);
                names.add("dB");
                return;
            }
            default:
                throw new DMLRuntimeException("Unsupported autoDiff instruction type: " + ctype.name() + " (" + item.getOpcode() + ").");
        }
    }

    /**
     * Executes the given instructions in a throw-away program block against {@code lrwec}.
     *
     * @throws DMLRuntimeException wrapping any failure raised during execution
     */
    private static void executeInst(ArrayList<Instruction> newInst, ExecutionContext lrwec) {
        try {
            BasicProgramBlock pb = new BasicProgramBlock(new Program());
            pb.setInstructions(newInst);
            pb.execute(lrwec);
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Error executing autoDiff instruction", e);
        }
    }
}

