/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark.utils;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysds.runtime.instructions.spark.data.RowMatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;

public class RDDAggregateUtils {
    private static final boolean TREE_AGGREGATION = false;

    /**
     * Sums all value blocks of a keyed RDD into a single block.
     * The keys are dropped and the values are passed to the {@code JavaRDD} overload.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @return the block obtained by folding all values of {@code in}
     */
    public static MatrixBlock sumStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        // Only the values matter for a global sum; the overload does the actual fold.
        return sumStable(in.values());
    }

    /**
     * Folds all blocks of the RDD into one block with a {@code SumSingleBlockFunction}.
     * The fold starts from a freshly constructed {@code MatrixBlock}, and the function is
     * created with {@code deep = false}, so it accumulates into its left operand
     * rather than into a copy of it.
     *
     * @param in RDD of matrix blocks
     * @return the folded block
     */
    public static MatrixBlock sumStable(JavaRDD<MatrixBlock> in) {
        MatrixBlock zero = new MatrixBlock();
        SumSingleBlockFunction sumFn = new SumSingleBlockFunction(false);
        return in.fold(zero, sumFn);
    }

    /**
     * Sums blocks per key, keeping the partition count of the input and
     * requesting a deep copy when the per-key combiner is created.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @return pair RDD with one summed block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        int numPartitions = in.getNumPartitions();
        return sumByKeyStable(in, numPartitions, true);
    }

    /**
     * Sums blocks per key, keeping the partition count of the input.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @param deepCopyCombiner forwarded unchanged to the three-argument overload
     * @return pair RDD with one summed block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, boolean deepCopyCombiner) {
        int numPartitions = in.getNumPartitions();
        return sumByKeyStable(in, numPartitions, deepCopyCombiner);
    }

    /**
     * Sums blocks per key via {@code combineByKey}. Each key is combined into a
     * {@code CorrMatrixBlock} (value plus correction), and the value block is
     * extracted afterwards.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @param numPartitions number of partitions passed to {@code combineByKey}
     * @param deepCopyCombiner passed to the create-combiner and both merge functions
     * @return pair RDD with one summed block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, int numPartitions, boolean deepCopyCombiner) {
        CreateCorrBlockCombinerFunction createFn = new CreateCorrBlockCombinerFunction(deepCopyCombiner);
        MergeSumBlockValueFunction mergeValueFn = new MergeSumBlockValueFunction(deepCopyCombiner);
        MergeSumBlockCombinerFunction mergeCombinerFn = new MergeSumBlockCombinerFunction(deepCopyCombiner);
        JavaPairRDD<MatrixIndexes, CorrMatrixBlock> combined = in.combineByKey(createFn, mergeValueFn, mergeCombinerFn, numPartitions);
        // Strip the correction and keep only the aggregated value block.
        return combined.mapValues(new ExtractMatrixBlock());
    }

    /**
     * Sums cell values per key, keeping the partition count of the input.
     *
     * @param in pair RDD of matrix indexes and cell values
     * @return pair RDD with one summed value per key
     */
    public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable(JavaPairRDD<MatrixIndexes, Double> in) {
        int numParts = in.getNumPartitions();
        return sumCellsByKeyStable(in, numParts);
    }

    /**
     * Sums cell values per key via {@code combineByKey}. Each key is accumulated in a
     * {@code KahanObject}, and its {@code _sum} field is extracted afterwards.
     *
     * @param in pair RDD of matrix indexes and cell values
     * @param numParts number of partitions passed to {@code combineByKey}
     * @return pair RDD with one summed value per key
     */
    public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable(JavaPairRDD<MatrixIndexes, Double> in, int numParts) {
        CreateCellCombinerFunction createFn = new CreateCellCombinerFunction();
        MergeSumCellValueFunction mergeValueFn = new MergeSumCellValueFunction();
        MergeSumCellCombinerFunction mergeCombinerFn = new MergeSumCellCombinerFunction();
        JavaPairRDD<MatrixIndexes, KahanObject> accumulated = in.combineByKey(createFn, mergeValueFn, mergeCombinerFn, numParts);
        return accumulated.mapValues(new ExtractDoubleCell());
    }

    /**
     * Aggregates all value blocks of a keyed RDD into a single block.
     * The keys are dropped and the values are passed to the {@code JavaRDD} overload.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @param aop aggregate operator to apply
     * @return the aggregated block
     */
    public static MatrixBlock aggStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator aop) {
        // Keys are irrelevant for a global aggregate.
        return aggStable(in.values(), aop);
    }

    /**
     * Folds all blocks of the RDD into one block with an
     * {@code AggregateSingleBlockFunction} built from {@code aop}.
     * The fold starts from a freshly constructed {@code MatrixBlock}.
     *
     * @param in RDD of matrix blocks
     * @param aop aggregate operator to apply
     * @return the aggregated block
     */
    public static MatrixBlock aggStable(JavaRDD<MatrixBlock> in, AggregateOperator aop) {
        MatrixBlock zero = new MatrixBlock();
        AggregateSingleBlockFunction aggFn = new AggregateSingleBlockFunction(aop);
        return in.fold(zero, aggFn);
    }

    /**
     * Aggregates all value tensor blocks of a keyed RDD into a single tensor block.
     * The keys are dropped and the values are passed to the {@code JavaRDD} overload.
     *
     * @param in pair RDD of tensor indexes and tensor blocks
     * @param aop aggregate operator to apply
     * @return the aggregated tensor block
     */
    public static TensorBlock aggStableTensor(JavaPairRDD<TensorIndexes, TensorBlock> in, AggregateOperator aop) {
        // Keys are irrelevant for a global aggregate.
        return aggStableTensor(in.values(), aop);
    }

    /**
     * Folds all tensor blocks of the RDD into one tensor block with an
     * {@code AggregateSingleTensorBlockFunction} built from {@code aop}.
     * The fold starts from a freshly constructed {@code TensorBlock}.
     *
     * @param in RDD of tensor blocks
     * @param aop aggregate operator to apply
     * @return the aggregated tensor block
     */
    public static TensorBlock aggStableTensor(JavaRDD<TensorBlock> in, AggregateOperator aop) {
        TensorBlock zero = new TensorBlock();
        AggregateSingleTensorBlockFunction aggFn = new AggregateSingleTensorBlockFunction(aop);
        return in.fold(zero, aggFn);
    }

    /**
     * Aggregates blocks per key, keeping the partition count of the input and
     * requesting a deep copy when the per-key combiner is created.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @param aop aggregate operator to apply
     * @return pair RDD with one aggregated block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator aop) {
        int numPartitions = in.getNumPartitions();
        return aggByKeyStable(in, aop, numPartitions, true);
    }

    /**
     * Aggregates blocks per key, keeping the partition count of the input.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @param aop aggregate operator to apply
     * @param deepCopyCombiner forwarded unchanged to the four-argument overload
     * @return pair RDD with one aggregated block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator aop, boolean deepCopyCombiner) {
        int numPartitions = in.getNumPartitions();
        return aggByKeyStable(in, aop, numPartitions, deepCopyCombiner);
    }

    /**
     * Aggregates blocks per key via {@code combineByKey}. Each key is combined into a
     * {@code CorrMatrixBlock} (value plus correction), and the value block is
     * extracted afterwards.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @param aop aggregate operator used by both merge functions
     * @param numPartitions number of partitions passed to {@code combineByKey}
     * @param deepCopyCombiner only affects how the initial per-key combiner is created
     * @return pair RDD with one aggregated block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator aop, int numPartitions, boolean deepCopyCombiner) {
        CreateCorrBlockCombinerFunction createFn = new CreateCorrBlockCombinerFunction(deepCopyCombiner);
        MergeAggBlockValueFunction mergeValueFn = new MergeAggBlockValueFunction(aop);
        MergeAggBlockCombinerFunction mergeCombinerFn = new MergeAggBlockCombinerFunction(aop);
        JavaPairRDD<MatrixIndexes, CorrMatrixBlock> combined = in.combineByKey(createFn, mergeValueFn, mergeCombinerFn, numPartitions);
        // Strip the correction and keep only the aggregated value block.
        return combined.mapValues(new ExtractMatrixBlock());
    }

    /**
     * Computes a scalar via the {@code "uamax"} unary aggregate: every block is mapped
     * with an {@code RDDUAggFunction2}, the mapped blocks are aggregated with the
     * operator's {@code aggOp}, and cell (0,0) of the result is returned.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @return value at position (0,0) of the aggregated block
     */
    public static double max(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        AggregateUnaryOperator maxOp = InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
        return aggStable(in.map(new AggregateUnarySPInstruction.RDDUAggFunction2(maxOp, -1)), maxOp.aggOp)
            .quickGetValue(0, 0);
    }

    /**
     * Merges blocks per key, keeping the partition count of the input and
     * requesting a deep copy when the per-key combiner is created.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @return pair RDD with one merged block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        int numPartitions = in.getNumPartitions();
        return mergeByKey(in, numPartitions, true);
    }

    /**
     * Merges blocks per key, keeping the partition count of the input.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @param deepCopyCombiner forwarded unchanged to the three-argument overload
     * @return pair RDD with one merged block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in, boolean deepCopyCombiner) {
        int numPartitions = in.getNumPartitions();
        return mergeByKey(in, numPartitions, deepCopyCombiner);
    }

    /**
     * Merges blocks per key via {@code combineByKey}.
     * {@code deepCopyCombiner} only controls whether the first block of a key is
     * copied when the combiner is created; both merge functions are constructed
     * with {@code deep = false}.
     *
     * @param in pair RDD of matrix indexes and matrix blocks
     * @param numPartitions number of partitions passed to {@code combineByKey}
     * @param deepCopyCombiner whether the initial combiner block is a copy of the input block
     * @return pair RDD with one merged block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in, int numPartitions, boolean deepCopyCombiner) {
        CreateBlockCombinerFunction createFn = new CreateBlockCombinerFunction(deepCopyCombiner);
        MergeBlocksFunction mergeValueFn = new MergeBlocksFunction(false);
        MergeBlocksFunction mergeCombinerFn = new MergeBlocksFunction(false);
        return in.combineByKey(createFn, mergeValueFn, mergeCombinerFn, numPartitions);
    }

    /**
     * Assembles row blocks into matrix blocks per key via {@code combineByKey}
     * (without an explicit partition count). The first row of a key creates the
     * target block, further rows are copied into it, and partial blocks are merged
     * with a {@code MergeBlocksFunction} constructed with {@code deep = false}.
     *
     * @param in pair RDD of matrix indexes and row blocks
     * @return pair RDD with one assembled block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeRowsByKey(JavaPairRDD<MatrixIndexes, RowMatrixBlock> in) {
        CreateRowBlockCombinerFunction createFn = new CreateRowBlockCombinerFunction();
        MergeRowBlockValueFunction mergeRowFn = new MergeRowBlockValueFunction();
        MergeBlocksFunction mergeBlocksFn = new MergeBlocksFunction(false);
        return in.combineByKey(createFn, mergeRowFn, mergeBlocksFn);
    }

    /**
     * Merges the second block into the first one after checking that both blocks
     * have the same number of rows and columns.
     */
    private static class MergeBlocksFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -8881019027250258850L;

        /** When set, the left block is copied before merging; the flag is also passed to {@code merge}. */
        private final boolean _deep;

        /** Creates a merge function with {@code deep = true}. */
        public MergeBlocksFunction() {
            this(true);
        }

        public MergeBlocksFunction(boolean deep) {
            this._deep = deep;
        }

        /**
         * @param b1 left block (merge target, or the source of its copy when deep)
         * @param b2 right block merged into the target
         * @return the block returned by {@code merge}, after {@code examSparsity()}
         * @throws DMLRuntimeException if the dimensions of {@code b1} and {@code b2} differ
         */
        public MatrixBlock call(MatrixBlock b1, MatrixBlock b2) throws Exception {
            boolean sameDims = b1.getNumRows() == b2.getNumRows()
                && b1.getNumColumns() == b2.getNumColumns();
            if (!sameDims) {
                throw new DMLRuntimeException("Mismatched block sizes for: " + b1.getNumRows() + " " + b1.getNumColumns() + " " + b2.getNumRows() + " " + b2.getNumColumns());
            }
            MatrixBlock target;
            if (this._deep) {
                target = new MatrixBlock(b1);
            } else {
                target = b1;
            }
            MatrixBlock merged = target.merge(b2, false, false, this._deep);
            merged.examSparsity();
            return merged;
        }
    }

    /**
     * Fold function that aggregates the right-hand tensor block into the left-hand one
     * with the given aggregate operator.
     */
    private static class AggregateSingleTensorBlockFunction implements Function2<TensorBlock, TensorBlock, TensorBlock> {
        private static final long serialVersionUID = 5665180309149919945L;

        // Not final: replaced by a plain Plus operator on first use if it was KahanPlus-based.
        private AggregateOperator _op;

        public AggregateSingleTensorBlockFunction(AggregateOperator op) {
            this._op = op;
        }

        /**
         * @param arg0 left tensor block, updated through its basic tensor when both sides are non-empty
         * @param arg1 right tensor block
         * @return {@code arg1} if {@code arg0} is empty, otherwise {@code arg0}
         */
        public TensorBlock call(TensorBlock arg0, TensorBlock arg1) throws Exception {
            // An empty operand contributes nothing: return the other operand as is.
            if (arg0.isEmpty()) {
                return arg1;
            }
            if (arg1.isEmpty()) {
                return arg0;
            }
            boolean kahanBased = this._op.increOp.fn instanceof KahanPlus;
            if (kahanBased) {
                // Swap the KahanPlus-based operator for a plain Plus operator before aggregating.
                this._op = new AggregateOperator(0.0, Plus.getPlusFnObject());
            }
            arg0.getBasicTensor().incrementalAggregate(this._op, arg1.getBasicTensor());
            return arg0;
        }
    }

    /**
     * Fold function that aggregates the right-hand block into the left-hand block with a
     * caller-supplied {@link AggregateOperator}.
     *
     * <p>If the operator reports a correction, one correction block is created on the first
     * aggregation of this instance and reused for all later calls on the same instance.
     */
    private static class AggregateSingleBlockFunction
    implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -3672377410407066396L;
        private AggregateOperator _op = null;
        private MatrixBlock _corr = null;

        public AggregateSingleBlockFunction(AggregateOperator op) {
            this._op = op;
        }

        /**
         * @param arg0 left block; receives a copy of {@code arg1} when it is 0x0, otherwise it is the aggregation target
         * @param arg1 right block
         * @return {@code arg0}, except in the sparse-safe shortcut where {@code arg1} is
         *         returned if only {@code arg0} is empty
         */
        public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1) throws Exception {
            // A 0x0 left block (presumably the fold's start value) simply takes over the right block.
            if (arg0.getNumRows() == 0 && arg0.getNumColumns() == 0) {
                arg0.copy(arg1);
                return arg0;
            }
            // A 0x0 right block adds nothing.
            if (arg1.getNumRows() == 0 && arg1.getNumColumns() == 0) {
                return arg0;
            }
            // Sparse-safe operators: if either side is empty, return the other side unchanged.
            // Short-circuit || with explicit parentheses replaces the former bitwise | so the
            // grouping relative to && is visible and the second test is skipped when not needed.
            // NOTE(review): this may return arg1 by reference; confirm callers tolerate that.
            if (this._op.sparseSafe && (arg0.isEmpty() || arg1.isEmpty())) {
                return arg1.isEmpty() ? arg0 : arg1;
            }
            // Create the correction block lazily, sized like the left block.
            if (this._op.existsCorrection() && this._corr == null) {
                this._corr = new MatrixBlock(arg0.getNumRows(), arg0.getNumColumns(), false);
            }
            OperationsOnMatrixValues.incrementalAggregation(arg0, this._op.existsCorrection() ? this._corr : null, arg1, this._op, true);
            return arg0;
        }
    }

    /**
     * Fold function that adds the right-hand block onto the left-hand block using a
     * {@code KahanPlus} aggregate operator with correction location {@code NONE}.
     * A single correction block is created lazily per instance and reused across calls.
     */
    private static class SumSingleBlockFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1737038715965862222L;

        private final AggregateOperator _op = new AggregateOperator(0.0, KahanPlus.getKahanPlusFnObject(), Types.CorrectionLocationType.NONE);
        private MatrixBlock _corr = null;
        /** When set, aggregation goes into a copy of the left block instead of the left block itself. */
        private final boolean _deep;

        public SumSingleBlockFunction(boolean deep) {
            this._deep = deep;
        }

        /**
         * @param arg0 left block; receives a copy of {@code arg1} when it has no rows or no columns
         * @param arg1 right block
         * @return {@code arg0}, or a copy of it holding the sum when {@code deep} is set
         */
        public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1) throws Exception {
            boolean leftHasCells = arg0.getNumRows() > 0 && arg0.getNumColumns() > 0;
            if (!leftHasCells) {
                // Left side has no dimensions yet: take over the right block.
                arg0.copy(arg1);
                return arg0;
            }
            boolean rightHasCells = arg1.getNumRows() > 0 && arg1.getNumColumns() > 0;
            if (!rightHasCells) {
                return arg0;
            }
            // Lazily create the correction block, sized like the left block.
            if (this._corr == null) {
                this._corr = new MatrixBlock(arg0.getNumRows(), arg0.getNumColumns(), false);
            }
            MatrixBlock out;
            if (this._deep) {
                out = new MatrixBlock(arg0);
            } else {
                out = arg0;
            }
            OperationsOnMatrixValues.incrementalAggregation(out, this._corr, arg1, this._op, false);
            return out;
        }
    }

    /** Maps a {@code KahanObject} to its {@code _sum} field. */
    private static class ExtractDoubleCell implements Function<KahanObject, Double> {
        private static final long serialVersionUID = -2873241816558275742L;

        private ExtractDoubleCell() {
            // stateless
        }

        /**
         * @param arg0 accumulator holding a sum
         * @return the value of {@code arg0._sum}
         */
        public Double call(KahanObject arg0) throws Exception {
            return arg0._sum;
        }
    }

    /** Maps a {@code CorrMatrixBlock} to its value block, calling {@code examSparsity()} on it first. */
    private static class ExtractMatrixBlock implements Function<CorrMatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 5242158678070843495L;

        private ExtractMatrixBlock() {
            // stateless
        }

        /**
         * @param arg0 value/correction pair
         * @return the value block of {@code arg0}
         */
        public MatrixBlock call(CorrMatrixBlock arg0) throws Exception {
            MatrixBlock value = arg0.getValue();
            value.examSparsity();
            return value;
        }
    }

    /**
     * Merge-combiners step of the per-key aggregation: aggregates the value block of the
     * second combiner into the value block of the first one using the configured operator.
     */
    private static class MergeAggBlockCombinerFunction
    implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 4803711632648880797L;
        private AggregateOperator _op = null;

        public MergeAggBlockCombinerFunction(AggregateOperator aop) {
            this._op = aop;
        }

        /**
         * @param arg0 left combiner; its value block is the aggregation target
         * @param arg1 right combiner; its value block is aggregated into the target
         * @return a new {@code CorrMatrixBlock} wrapping the target value block and the
         *         correction block that was used (possibly {@code null})
         */
        public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
            MatrixBlock value1 = arg0.getValue();
            MatrixBlock value2 = arg1.getValue();
            MatrixBlock corr = arg0.getCorrection();
            boolean withCorrection = this._op.existsCorrection();
            // If the left side has no correction yet, reuse the right side's correction
            // or allocate one sized like the left value block.
            if (corr == null && withCorrection) {
                corr = arg1.getCorrection() != null
                    ? arg1.getCorrection()
                    : new MatrixBlock(value1.getNumRows(), value1.getNumColumns(), false);
            }
            // Operators without correction always aggregate with a null correction block,
            // even if the left combiner happened to carry one.
            OperationsOnMatrixValues.incrementalAggregation(value1, withCorrection ? corr : null, value2, this._op, true);
            return new CorrMatrixBlock(value1, corr);
        }
    }

    /**
     * Merge-value step of the per-key aggregation: aggregates an incoming block into the
     * value block of the existing combiner using the configured operator.
     */
    private static class MergeAggBlockValueFunction implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 389422125491172011L;

        private AggregateOperator _op;

        public MergeAggBlockValueFunction(AggregateOperator aop) {
            this._op = aop;
        }

        /**
         * @param arg0 existing combiner; its value block is the aggregation target
         * @param arg1 incoming block
         * @return a new {@code CorrMatrixBlock} wrapping the target value block and the
         *         correction block that was used (possibly {@code null})
         */
        public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1) throws Exception {
            MatrixBlock target = arg0.getValue();
            MatrixBlock correction = arg0.getCorrection();
            boolean withCorrection = this._op.existsCorrection();
            if (withCorrection && correction == null) {
                // Allocate the correction on demand, sized like the target block.
                correction = new MatrixBlock(target.getNumRows(), target.getNumColumns(), false);
            }
            // Without correction support the aggregation is always run with a null correction.
            OperationsOnMatrixValues.incrementalAggregation(target, withCorrection ? correction : null, arg1, this._op, true);
            return new CorrMatrixBlock(target, correction);
        }
    }

    /** Merge-combiners step of the per-key cell sum: adds the sum of the second accumulator to the first. */
    private static class MergeSumCellCombinerFunction implements Function2<KahanObject, KahanObject, KahanObject> {
        private static final long serialVersionUID = 8726716909849119657L;

        private MergeSumCellCombinerFunction() {
            // stateless
        }

        /**
         * @param arg0 accumulator that is updated
         * @param arg1 accumulator whose {@code _sum} is added
         * @return {@code arg0}
         */
        public KahanObject call(KahanObject arg0, KahanObject arg1) throws Exception {
            // Only arg1's _sum field is passed on to the KahanPlus function object.
            KahanPlus.getKahanPlusFnObject().execute2(arg0, arg1._sum);
            return arg0;
        }
    }

    /** Merge-value step of the per-key cell sum: adds an incoming cell value to the accumulator. */
    private static class MergeSumCellValueFunction implements Function2<KahanObject, Double, KahanObject> {
        private static final long serialVersionUID = 468335171573184825L;

        private MergeSumCellValueFunction() {
            // stateless
        }

        /**
         * @param arg0 accumulator that is updated
         * @param arg1 cell value to add
         * @return {@code arg0}
         */
        public KahanObject call(KahanObject arg0, Double arg1) throws Exception {
            KahanPlus.getKahanPlusFnObject().execute2(arg0, arg1);
            return arg0;
        }
    }

    /** Create-combiner step of the per-key cell sum: wraps the first cell value of a key in a {@code KahanObject}. */
    private static class CreateCellCombinerFunction implements Function<Double, KahanObject> {
        private static final long serialVersionUID = 3697505233057172994L;

        private CreateCellCombinerFunction() {
            // stateless
        }

        /**
         * @param arg0 first cell value seen for a key
         * @return a new {@code KahanObject} built from {@code arg0} and {@code 0.0}
         */
        public KahanObject call(Double arg0) throws Exception {
            KahanObject accumulator = new KahanObject(arg0, 0.0);
            return accumulator;
        }
    }

    /** Merge-value step of row assembly: copies an incoming row into the block assembled so far. */
    private static class MergeRowBlockValueFunction implements Function2<MatrixBlock, RowMatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -803689998683298516L;

        private MergeRowBlockValueFunction() {
            // stateless
        }

        /**
         * @param arg0 block assembled so far; updated in place
         * @param arg1 row holder providing the row index and the row data
         * @return {@code arg0} after the copy and {@code examSparsity()}
         */
        public MatrixBlock call(MatrixBlock arg0, RowMatrixBlock arg1) throws Exception {
            MatrixBlock rowData = arg1.getValue();
            int lastCol = rowData.getNumColumns() - 1;
            // Write the row at its own index (row range [r, r]) across all of its columns.
            arg0.copy(arg1.getRow(), arg1.getRow(), 0, lastCol, rowData, true);
            arg0.examSparsity();
            return arg0;
        }
    }

    /**
     * Create-combiner step of row assembly: allocates the target block for a key and
     * copies the first row into it.
     */
    private static class CreateRowBlockCombinerFunction implements Function<RowMatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 2866598914232118425L;

        private CreateRowBlockCombinerFunction() {
            // stateless
        }

        /**
         * @param arg0 row holder providing the block length, the row index and the row data
         * @return a new block with {@code arg0.getLen()} rows and as many columns as the
         *         row data, containing that row at its index
         */
        public MatrixBlock call(RowMatrixBlock arg0) throws Exception {
            MatrixBlock rowData = arg0.getValue();
            int numCols = rowData.getNumColumns();
            MatrixBlock target = new MatrixBlock(arg0.getLen(), numCols, true);
            target.copy(arg0.getRow(), arg0.getRow(), 0, numCols - 1, rowData, false);
            // The copy above is done with the last flag set to false; the non-zero
            // count is therefore taken over from the row explicitly.
            target.setNonZeros(rowData.getNonZeros());
            target.examSparsity();
            return target;
        }
    }

    /** Create-combiner step of block merging: returns the first block of a key, optionally as a copy. */
    private static class CreateBlockCombinerFunction implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1987501624176848292L;

        /** When set, the combiner is a copy of the input block rather than the input block itself. */
        private final boolean _deep;

        public CreateBlockCombinerFunction(boolean deep) {
            this._deep = deep;
        }

        /**
         * @param arg0 first block seen for a key
         * @return a copy of {@code arg0} if deep, otherwise {@code arg0} itself
         */
        public MatrixBlock call(MatrixBlock arg0) throws Exception {
            if (this._deep) {
                return new MatrixBlock(arg0);
            }
            return arg0;
        }
    }

    /**
     * Merge-combiners step of the per-key sum: adds the value block of the second
     * combiner onto the value block of the first one using a {@code KahanPlus}
     * aggregate operator with correction location {@code NONE}.
     */
    private static class MergeSumBlockCombinerFunction implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 7664941774566119853L;

        private AggregateOperator _op = new AggregateOperator(0.0, KahanPlus.getKahanPlusFnObject(), Types.CorrectionLocationType.NONE);
        /** Passed on to the aggregation call; also influences whether a correction block is allocated. */
        private final boolean _deep;

        public MergeSumBlockCombinerFunction(boolean deep) {
            this._deep = deep;
        }

        /**
         * @param arg0 left combiner; updated via {@code set} and returned
         * @param arg1 right combiner
         * @return the result of {@code arg0.set(value, correction)}
         */
        public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
            MatrixBlock value1 = arg0.getValue();
            MatrixBlock value2 = arg1.getValue();
            MatrixBlock corr = arg0.getCorrection();
            if (corr == null) {
                // Prefer the right side's correction if it has one.
                corr = arg1.getCorrection();
                if (corr == null) {
                    // No allocation if the right value is empty, or if (in the non-deep
                    // case) the left value is empty; the correction then stays null.
                    boolean skipAllocation = value2.isEmptyBlock(false)
                        || (!this._deep && value1.isEmptyBlock(false));
                    if (!skipAllocation) {
                        corr = new MatrixBlock(value1.getNumRows(), value1.getNumColumns(), false);
                    }
                }
            }
            OperationsOnMatrixValues.incrementalAggregation(value1, corr, value2, this._op, false, this._deep);
            return arg0.set(value1, corr);
        }
    }

    /**
     * Merge-value step of the per-key sum: adds an incoming block onto the value block of
     * the existing combiner using a {@code KahanPlus} aggregate operator with correction
     * location {@code NONE}. Empty incoming blocks are skipped.
     */
    private static class MergeSumBlockValueFunction implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 3703543699467085539L;

        private AggregateOperator _op = new AggregateOperator(0.0, KahanPlus.getKahanPlusFnObject(), Types.CorrectionLocationType.NONE);
        /** Passed on to the aggregation call. */
        private final boolean _deep;

        public MergeSumBlockValueFunction(boolean deep) {
            this._deep = deep;
        }

        /**
         * @param arg0 existing combiner; updated via {@code set} when {@code arg1} is non-empty
         * @param arg1 incoming block
         * @return {@code arg0} unchanged if {@code arg1} is empty, otherwise the result of
         *         {@code arg0.set(value, correction)}
         */
        public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1) throws Exception {
            if (!arg1.isEmptyBlock(false)) {
                MatrixBlock target = arg0.getValue();
                MatrixBlock correction = arg0.getCorrection();
                // Allocate the correction on demand. The emptiness re-check mirrors the
                // guard above and is expected to hold here.
                if (correction == null && !arg1.isEmptyBlock(false)) {
                    correction = new MatrixBlock(target.getNumRows(), target.getNumColumns(), false);
                }
                OperationsOnMatrixValues.incrementalAggregation(target, correction, arg1, this._op, false, this._deep);
                return arg0.set(target, correction);
            }
            // Nothing to add for an empty incoming block.
            return arg0;
        }
    }

    /**
     * Create-combiner step of the per-key sum/aggregation: wraps the first block of a key,
     * optionally as a copy, in a {@code CorrMatrixBlock}.
     */
    private static class CreateCorrBlockCombinerFunction implements Function<MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = -3666451526776017343L;

        /** When set, the wrapped value is a copy of the input block rather than the input block itself. */
        private final boolean _deep;

        public CreateCorrBlockCombinerFunction(boolean deep) {
            this._deep = deep;
        }

        /**
         * @param arg0 first block seen for a key
         * @return a new {@code CorrMatrixBlock} holding {@code arg0} or a copy of it
         */
        public CorrMatrixBlock call(MatrixBlock arg0) throws Exception {
            MatrixBlock initial;
            if (this._deep) {
                initial = new MatrixBlock(arg0);
            } else {
                initial = arg0;
            }
            return new CorrMatrixBlock(initial);
        }
    }
}

