/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction2;
import org.apache.sysds.runtime.instructions.spark.functions.ReorgMapFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.CommonThreadPool;
import scala.Tuple2;

/**
 * Spark instruction for a cross-product based matrix multiplication (CPMM) of two
 * blocked, distributed matrices.
 * <p>
 * {@link #processInstruction(ExecutionContext)} chooses between two strategies:
 * <ul>
 *   <li>a ZIPMM-like matrix-vector product, used when the left input is hash-partitioned,
 *       has a single row block, and the right input has a single column. The inputs are
 *       joined directly on their block indexes, and the result is always a single block.</li>
 *   <li>a general CPMM. The left input is re-keyed by column-block index and the right input
 *       by row-block index. The matching blocks are joined and multiplied, and the partial
 *       products are summed either into one block or by output block index.</li>
 * </ul>
 * When {@link ConfigurationManager#isMaxPrallelizeEnabled()} is set, the single-block
 * variants are submitted to {@link CommonThreadPool} and the output is registered as a future.
 */
public class CpmmSPInstruction extends AggregateBinarySPInstruction {
    /** Whether empty output blocks must be kept (parsed from the instruction string). */
    private final boolean _outputEmptyBlocks;
    /** Aggregation type of the result (e.g. single block vs. blocked RDD). */
    private final AggBinaryOp.SparkAggType _aggtype;

    private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, boolean outputEmptyBlocks, AggBinaryOp.SparkAggType aggtype, String opcode, String istr) {
        super(SPInstruction.SPType.CPMM, op, in1, in2, out, opcode, istr);
        this._outputEmptyBlocks = outputEmptyBlocks;
        this._aggtype = aggtype;
    }

    /**
     * Parses a {@code cpmm} instruction string.
     * <p>
     * The expected fields after the opcode are: input 1, input 2, output, the
     * output-empty-blocks flag, and the {@link AggBinaryOp.SparkAggType} name.
     *
     * @param str instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code cpmm}
     */
    public static CpmmSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("cpmm")) {
            throw new DMLRuntimeException("CpmmSPInstruction.parseInstruction(): Unknown opcode " + opcode);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
        boolean outputEmptyBlocks = Boolean.parseBoolean(parts[4]);
        AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[5]);
        return new CpmmSPInstruction(aggbin, in1, in2, out, outputEmptyBlocks, aggtype, opcode, str);
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;

        // Fetch the RDD inputs and their data characteristics.
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input2.getName());
        DataCharacteristics mc1 = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics mc2 = sec.getDataCharacteristics(this.input2.getName());

        // Prune empty input blocks unless empty output blocks must be produced.
        // For single-block aggregation, empty inputs never contribute anyway.
        // NOTE(review): pruning also happens when either input reports isNoEmptyBlocks();
        // confirm the intended semantics of that flag.
        if (!this._outputEmptyBlocks
            || this._aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK
            || mc1.isNoEmptyBlocks() || mc2.isNoEmptyBlocks()) {
            in1 = in1.filter(new FilterNonEmptyBlocksFunction());
            in2 = in2.filter(new FilterNonEmptyBlocksFunction());
        }

        if (SparkUtils.isHashPartitioned(in1) && mc1.getNumRowBlocks() == 1L && mc2.getCols() == 1L) {
            // ZIPMM-like matrix-vector product: join on block indexes, no index re-mapping.
            CpmmMatrixVectorTask task = new CpmmMatrixVectorTask(in1, in2);
            if (ConfigurationManager.isMaxPrallelizeEnabled()) {
                this.submitSingleBlockOutput(sec, task);
            } else {
                // Single-block output, so no lineage RDDs are attached.
                sec.setMatrixOutput(this.output.getName(), task.call());
            }
            return;
        }

        // General CPMM: bound the join parallelism by the number of distinct join keys.
        int numPreferred = CpmmSPInstruction.getPreferredParJoin(mc1, mc2, in1.getNumPartitions(), in2.getNumPartitions());
        int numPartJoin = Math.min(CpmmSPInstruction.getMaxParJoin(mc1, mc2), numPreferred);

        if (this._aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            CpmmMatrixMatrixTask task = new CpmmMatrixMatrixTask(in1, in2, numPartJoin);
            if (ConfigurationManager.isMaxPrallelizeEnabled()) {
                this.submitSingleBlockOutput(sec, task);
            } else {
                sec.setMatrixOutput(this.output.getName(), task.call());
            }
            return;
        }

        // Multi-block output: sum the partial products per output block index.
        JavaPairRDD<Long, IndexedMatrixValue> tmp1 = in1.mapToPair(new CpmmIndexFunction(true));
        JavaPairRDD<Long, IndexedMatrixValue> tmp2 = in2.mapToPair(new CpmmIndexFunction(false));
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = tmp1.join(tmp2, numPartJoin).mapToPair(new CpmmMultiplyFunction());
        if (!this._outputEmptyBlocks || mc1.isNoEmptyBlocks() || mc2.isNoEmptyBlocks()) {
            out = out.filter(new FilterNonEmptyBlocksFunction());
        }
        out = RDDAggregateUtils.sumByKeyStable(out, false);
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
        sec.addLineageRDD(this.output.getName(), this.input2.getName());
        this.updateBinaryMMOutputDataCharacteristics(sec, true);
    }

    /**
     * Submits a single-block CPMM task to the dynamic thread pool and registers the
     * resulting future as this instruction's output.
     * <p>
     * The instruction's lineage item is attached whenever lineage reuse is enabled
     * (i.e. {@code ReuseCacheType} is not {@code NONE}). Otherwise {@code null} is passed.
     *
     * @param sec  Spark execution context receiving the output
     * @param task callable that computes the single output block
     * @throws DMLRuntimeException wrapping any failure during submission or registration
     */
    private void submitSingleBlockOutput(SparkExecutionContext sec, Callable<MatrixBlock> task) {
        try {
            Future<MatrixBlock> futureOut = CommonThreadPool.getDynamicPool().submit(task);
            LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone()
                ? (LineageItem) this.getLineageItem(sec).getValue()
                : null;
            sec.setMatrixOutputAndLineage(this.output.getName(), futureOut, li);
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /** @return the aggregation type this instruction was parsed with */
    public AggBinaryOp.SparkAggType getAggType() {
        return this._aggtype;
    }

    /**
     * Computes the preferred number of join partitions.
     * <p>
     * The base value is the larger input partition count. If both inputs have known
     * dimensions, it is raised to the sum of their preferred partition counts when that is
     * larger. If the result exceeds half of the default parallelism, it is raised to at
     * least the default parallelism.
     */
    private static int getPreferredParJoin(DataCharacteristics mc1, DataCharacteristics mc2, int numPar1, int numPar2) {
        int defPar = SparkExecutionContext.getDefaultParallelism(true);
        int maxParIn = Math.max(numPar1, numPar2);
        int maxSizeIn = SparkUtils.getNumPreferredPartitions(mc1) + SparkUtils.getNumPreferredPartitions(mc2);
        int tmp = mc1.dimsKnown(true) && mc2.dimsKnown(true) ? Math.max(maxSizeIn, maxParIn) : maxParIn;
        return tmp > defPar / 2 ? Math.max(tmp, defPar) : tmp;
    }

    /**
     * Upper bound for the join parallelism, derived from the number of distinct join keys.
     * <p>
     * The bound is the column-block count of the left input, or otherwise the row-block
     * count of the right input. If neither is known, the join is unbounded.
     */
    private static int getMaxParJoin(DataCharacteristics mc1, DataCharacteristics mc2) {
        if (mc1.colsKnown()) {
            return (int) mc1.getNumColBlocks();
        }
        return mc2.rowsKnown() ? (int) mc2.getNumRowBlocks() : Integer.MAX_VALUE;
    }

    /**
     * General CPMM reduced to a single output block.
     * <p>
     * Re-keys both inputs by their common block dimension, joins them, multiplies the
     * matching blocks, drops empty products, and sums the rest.
     */
    private static class CpmmMatrixMatrixTask implements Callable<MatrixBlock> {
        private final JavaPairRDD<MatrixIndexes, MatrixBlock> _in1;
        private final JavaPairRDD<MatrixIndexes, MatrixBlock> _in2;
        private final int _numPartJoin;

        CpmmMatrixMatrixTask(JavaPairRDD<MatrixIndexes, MatrixBlock> in1, JavaPairRDD<MatrixIndexes, MatrixBlock> in2, int nPartJoin) {
            this._in1 = in1;
            this._in2 = in2;
            this._numPartJoin = nPartJoin;
        }

        @Override
        public MatrixBlock call() {
            JavaPairRDD<Long, IndexedMatrixValue> tmp1 = this._in1.mapToPair(new CpmmIndexFunction(true));
            JavaPairRDD<Long, IndexedMatrixValue> tmp2 = this._in2.mapToPair(new CpmmIndexFunction(false));
            JavaPairRDD<MatrixIndexes, MatrixBlock> out = tmp1.join(tmp2, this._numPartJoin).mapToPair(new CpmmMultiplyFunction());
            out = out.filter(new FilterNonEmptyBlocksFunction());
            return RDDAggregateUtils.sumStable(out);
        }
    }

    /**
     * Matrix-vector CPMM with a single-row-block left input.
     * <p>
     * The vector blocks are mapped with {@code ReorgMapFunction("r'")} so that their keys
     * line up with the left input's block keys. {@link Cpmm2MultiplyFunction} transposes
     * them back before multiplying, and the non-empty products are summed into one block.
     */
    private static class CpmmMatrixVectorTask implements Callable<MatrixBlock> {
        private final JavaPairRDD<MatrixIndexes, MatrixBlock> _in1;
        private final JavaPairRDD<MatrixIndexes, MatrixBlock> _in2;

        CpmmMatrixVectorTask(JavaPairRDD<MatrixIndexes, MatrixBlock> in1, JavaPairRDD<MatrixIndexes, MatrixBlock> in2) {
            this._in1 = in1;
            this._in2 = in2;
        }

        @Override
        public MatrixBlock call() {
            JavaRDD<MatrixBlock> out = this._in1
                .join(this._in2.mapToPair(new ReorgMapFunction("r'")))
                .values()
                .map(new Cpmm2MultiplyFunction())
                .filter(new FilterNonEmptyBlocksFunction2());
            return RDDAggregateUtils.sumStable(out);
        }
    }

    /**
     * Multiplies a left block with the transpose of the given right block.
     * <p>
     * The operators are created lazily on first use, so they are not carried through
     * serialization.
     */
    private static class Cpmm2MultiplyFunction implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = -3718880362385713416L;
        private AggregateBinaryOperator _op = null;
        private ReorgOperator _rop = null;

        private Cpmm2MultiplyFunction() {
        }

        @Override
        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> arg0) throws Exception {
            if (this._op == null) {
                this._op = InstructionUtils.getMatMultOperator(1);
                this._rop = new ReorgOperator(SwapIndex.getSwapIndexFnObject());
            }
            MatrixBlock left = arg0._1();
            // Undo the transpose applied for key alignment before multiplying.
            MatrixBlock right = arg0._2().reorgOperations(this._rop, new MatrixBlock(), 0, 0, 0);
            return OperationsOnMatrixValues.matMult(left, right, new MatrixBlock(), this._op);
        }
    }

    /**
     * Multiplies one joined pair of blocks.
     * <p>
     * The output is keyed by the left block's row index and the right block's column index.
     */
    private static class CpmmMultiplyFunction implements PairFunction<Tuple2<Long, Tuple2<IndexedMatrixValue, IndexedMatrixValue>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -2009255629093036642L;
        private AggregateBinaryOperator _op = null;

        private CpmmMultiplyFunction() {
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<Long, Tuple2<IndexedMatrixValue, IndexedMatrixValue>> arg0) throws Exception {
            if (this._op == null) {
                this._op = InstructionUtils.getMatMultOperator(1);
            }
            IndexedMatrixValue left = arg0._2()._1();
            IndexedMatrixValue right = arg0._2()._2();
            MatrixBlock blkIn1 = (MatrixBlock) left.getValue();
            MatrixBlock blkIn2 = (MatrixBlock) right.getValue();
            MatrixBlock blkOut = OperationsOnMatrixValues.matMult(blkIn1, blkIn2, new MatrixBlock(), this._op);
            MatrixIndexes ixOut = new MatrixIndexes();
            ixOut.setIndexes(left.getIndexes().getRowIndex(), right.getIndexes().getColumnIndex());
            return new Tuple2<>(ixOut, blkOut);
        }
    }

    /**
     * Re-keys a block by its join dimension.
     * <p>
     * Left-input blocks are keyed by column-block index, right-input blocks by row-block
     * index. The original indexes are kept inside the {@link IndexedMatrixValue}.
     */
    private static class CpmmIndexFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, IndexedMatrixValue> {
        private static final long serialVersionUID = -1187183128301671162L;
        private final boolean _left;

        public CpmmIndexFunction(boolean left) {
            this._left = left;
        }

        @Override
        public Tuple2<Long, IndexedMatrixValue> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            IndexedMatrixValue value = new IndexedMatrixValue(ix, (MatrixValue) arg0._2());
            Long key = this._left ? ix.getColumnIndex() : ix.getRowIndex();
            return new Tuple2<>(key, value);
        }
    }
}

