/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.data.BasicTensorBlock;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.functions.AggregateDropCorrectionFunction;
import org.apache.sysds.runtime.instructions.spark.functions.FilterDiagMatrixBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;

/**
 * Spark instruction that applies an {@link AggregateUnaryOperator} to every block of a
 * distributed matrix or tensor.
 *
 * <p>The {@link AggBinaryOp.SparkAggType} parsed from the instruction string decides what
 * happens with the per-block partial results:
 * <ul>
 *   <li>{@code SINGLE_BLOCK}: all partials are combined into one local block, which is
 *       written as the output variable.</li>
 *   <li>{@code NONE}: every block is aggregated independently and the result stays an RDD.</li>
 *   <li>{@code MULTI_BLOCK}: partials are combined per output key (matrices only).</li>
 * </ul>
 */
public class AggregateUnarySPInstruction
extends UnarySPInstruction {
    /** How (and whether) per-block partial results are combined across the RDD. */
    private final AggBinaryOp.SparkAggType _aggtype;
    /** Operator used to combine partial aggregates, including its correction location. */
    private final AggregateOperator _aop;

    protected AggregateUnarySPInstruction(SPInstruction.SPType type, AggregateUnaryOperator auop, AggregateOperator aop, CPOperand in, CPOperand out, AggBinaryOp.SparkAggType aggtype, String opcode, String istr) {
        super(type, auop, in, out, opcode, istr);
        this._aggtype = aggtype;
        this._aop = aop;
    }

    /**
     * Parses an instruction string of the form {@code opcode, input, output, aggtype}.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     */
    public static AggregateUnarySPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3);
        String opcode = parts[0];
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);
        AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[3]);
        // The operator that merges partial results is derived from the same opcode,
        // together with the location of its correction rows/columns.
        String aopcode = InstructionUtils.deriveAggregateOperatorOpcode(opcode);
        Types.CorrectionLocationType corrLoc = InstructionUtils.deriveAggregateOperatorCorrectionLocation(opcode);
        AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
        AggregateOperator aop = InstructionUtils.parseAggregateOperator(aopcode, corrLoc.toString());
        return new AggregateUnarySPInstruction(SPInstruction.SPType.AggregateUnary, aggun, aop, in1, out, aggtype, opcode, str);
    }

    /** Dispatches to the matrix path for MATRIX inputs and to the tensor path otherwise. */
    @Override
    public void processInstruction(ExecutionContext ec) {
        if (this.input1.getDataType() == Types.DataType.MATRIX) {
            this.processMatrixAggregate(ec);
        } else {
            this.processTensorAggregate(ec);
        }
    }

    /** Executes the unary aggregate over a blocked matrix RDD. */
    private void processMatrixAggregate(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        DataCharacteristics mc = sec.getDataCharacteristics(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> out =
            sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());

        // trace only depends on the diagonal, so non-diagonal blocks are filtered out first
        if (this.getOpcode().equalsIgnoreCase("uaktrace")) {
            out = out.filter(new FilterDiagMatrixBlocksFunction());
        }

        AggregateUnaryOperator auop = (AggregateUnaryOperator) this._optr;
        AggregateOperator aggop = this._aop;

        if (this._aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            // NOTE(review): sparse-safe aggregates are presumably unaffected by empty blocks,
            // which is why they can be skipped here.
            if (auop.sparseSafe) {
                out = out.filter(new FilterNonEmptyBlocksFunction());
            }
            JavaRDD<MatrixBlock> partials = out.map(new RDDUAggFunction2(auop, mc.getBlocksize()));
            MatrixBlock result = RDDAggregateUtils.aggStable(partials, aggop);
            // the correction is only needed while aggregating; strip it before output
            result.dropLastRowsOrColumns(aggop.correction);
            // the result is a single local block, so no RDD lineage is recorded
            sec.setMatrixOutput(this.output.getName(), result);
        } else {
            if (this._aggtype == AggBinaryOp.SparkAggType.NONE) {
                out = out.mapValues(new RDDUAggValueFunction(auop, mc.getBlocksize()));
            } else if (this._aggtype == AggBinaryOp.SparkAggType.MULTI_BLOCK) {
                // corrections are kept through the keyed aggregation ...
                out = out.mapToPair(new RDDUAggFunction(auop, mc.getBlocksize()));
                out = RDDAggregateUtils.aggByKeyStable(out, aggop, false);
                // ... and dropped afterwards via a value-only map
                if (auop.aggOp.existsCorrection()) {
                    out = out.mapValues(new AggregateDropCorrectionFunction(aggop));
                }
            }
            this.updateUnaryAggOutputDataCharacteristics(sec, auop.indexFn);
            sec.setRDDHandleForVariable(this.output.getName(), out);
            sec.addLineageRDD(this.output.getName(), this.input1.getName());
        }
    }

    /** Executes the unary aggregate over a blocked tensor RDD (MULTI_BLOCK unsupported). */
    private void processTensorAggregate(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaPairRDD<TensorIndexes, TensorBlock> out =
            sec.getBinaryTensorBlockRDDHandleForVariable(this.input1.getName());
        AggregateUnaryOperator auop = (AggregateUnaryOperator) this._optr;
        AggregateOperator aggop = this._aop;

        if (this._aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            JavaRDD<TensorBlock> partials = out.map(new RDDUTensorAggFunction2(auop));
            TensorBlock aggregated = RDDAggregateUtils.aggStableTensor(partials, aggop);
            // Only cell (0,0) of the aggregate is copied into a 1x1 output.
            // NOTE(review): this assumes a scalar result; any trailing correction cells
            // are discarded.
            TensorBlock result = new TensorBlock(aggregated.getValueType(), new int[]{1, 1});
            result.set(0, 0, aggregated.get(0, 0));
            sec.setTensorOutput(this.output.getName(), result);
        } else {
            if (this._aggtype == AggBinaryOp.SparkAggType.NONE) {
                out = out.mapValues(new RDDUTensorAggValueFunction(auop));
            } else if (this._aggtype == AggBinaryOp.SparkAggType.MULTI_BLOCK) {
                throw new DMLRuntimeException("Multi block spark aggregations are not supported for tensors yet.");
            }
            this.updateUnaryAggOutputDataCharacteristics(sec, auop.indexFn);
            sec.setRDDHandleForVariable(this.output.getName(), out);
            sec.addLineageRDD(this.output.getName(), this.input1.getName());
        }
    }

    /** Aggregates a single tensor block into a 1x1 tensor holding cell (0,0) of the result. */
    private static class RDDUTensorAggValueFunction
    implements Function<TensorBlock, TensorBlock> {
        private static final long serialVersionUID = -968274963539513423L;
        private final AggregateUnaryOperator _op;

        public RDDUTensorAggValueFunction(AggregateUnaryOperator op) {
            this._op = op;
        }

        @Override
        public TensorBlock call(TensorBlock arg0) throws Exception {
            BasicTensorBlock blkOut = new BasicTensorBlock();
            arg0.getBasicTensor().aggregateUnaryOperations(this._op, blkOut);
            TensorBlock out = new TensorBlock(blkOut.getValueType(), new int[]{1, 1});
            out.set(0, 0, blkOut.get(0, 0));
            return out;
        }
    }

    /** Aggregates each matrix block independently, using a fixed (1,1) block index. */
    private static class RDDUAggValueFunction
    implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 5352374590399929673L;
        private final AggregateUnaryOperator _op;
        private final int _blen;
        private final MatrixIndexes _ix;

        public RDDUAggValueFunction(AggregateUnaryOperator op, int blen) {
            this._op = op;
            this._blen = blen;
            this._ix = new MatrixIndexes(1L, 1L);
        }

        @Override
        public MatrixBlock call(MatrixBlock arg0) throws Exception {
            MatrixBlock blkOut = new MatrixBlock();
            // NOTE(review): the trailing `true` flag presumably lets the operation drop the
            // correction itself, since no further aggregation follows on this path.
            arg0.aggregateUnaryOperations(this._op, blkOut, this._blen, this._ix, true);
            return blkOut;
        }
    }

    /** Maps a (index, tensor block) pair to the block's unary-aggregate partial result. */
    public static class RDDUTensorAggFunction2
    implements Function<Tuple2<TensorIndexes, TensorBlock>, TensorBlock> {
        private static final long serialVersionUID = -6258769067791011763L;
        private final AggregateUnaryOperator _op;

        public RDDUTensorAggFunction2(AggregateUnaryOperator op) {
            this._op = op;
        }

        @Override
        public TensorBlock call(Tuple2<TensorIndexes, TensorBlock> arg0) throws Exception {
            return new TensorBlock(arg0._2().getBasicTensor().aggregateUnaryOperations(this._op, new BasicTensorBlock()));
        }
    }

    /** Maps a (index, matrix block) pair to the block's unary-aggregate partial result. */
    public static class RDDUAggFunction2
    implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 2672082409287856038L;
        private final AggregateUnaryOperator _op;
        private final int _blen;

        public RDDUAggFunction2(AggregateUnaryOperator op, int blen) {
            this._op = op;
            this._blen = blen;
        }

        @Override
        public MatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            return arg0._2().aggregateUnaryOperations(this._op, new MatrixBlock(), this._blen, arg0._1());
        }
    }

    /** Maps each (index, block) pair to its aggregated block keyed by the output index. */
    private static class RDDUAggFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 2672082409287856038L;
        private final AggregateUnaryOperator _op;
        private final int _blen;

        public RDDUAggFunction(AggregateUnaryOperator op, int blen) {
            this._op = op;
            this._blen = blen;
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();
            MatrixIndexes ixOut = new MatrixIndexes();
            MatrixBlock blkOut = new MatrixBlock();
            OperationsOnMatrixValues.performAggregateUnary(ixIn, blkIn, ixOut, blkOut, this._op, this._blen);
            return new Tuple2<>(ixOut, blkOut);
        }
    }
}

