/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.functionobjects.Equals;
import org.apache.sysds.runtime.functionobjects.GreaterThan;
import org.apache.sysds.runtime.functionobjects.GreaterThanEquals;
import org.apache.sysds.runtime.functionobjects.LessThan;
import org.apache.sysds.runtime.functionobjects.LessThanEquals;
import org.apache.sysds.runtime.functionobjects.NotEquals;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class CLALibRelationalOp {
    /** Target number of cells (rows * cols) decompressed per block in the single-threaded path. */
    private static final int SERIAL_BLOCK_CELLS = 65535;

    /** Number of rows decompressed per task in the multi-threaded path. */
    private static final int PARALLEL_BLOCK_ROWS = Short.MAX_VALUE;

    /**
     * Per-thread dense scratch block reused by {@link RelationalTask} so that consecutive tasks on the
     * same thread do not reallocate the decompression buffer.
     */
    private static final ThreadLocal<MatrixBlock> memPool = new ThreadLocal<>();

    /**
     * Checks whether the overlapping relational fast path applies.
     *
     * @param sop the scalar operator to apply
     * @param m1  the compressed input matrix
     * @return true if {@code m1} is overlapping and {@code sop} is one of {@code <, <=, >, >=, ==, !=}
     */
    protected static boolean isValidForRelationalOperation(ScalarOperator sop, CompressedMatrixBlock m1) {
        return m1.isOverlapping() && (sop.fn instanceof LessThan || sop.fn instanceof LessThanEquals || sop.fn instanceof GreaterThan || sop.fn instanceof GreaterThanEquals || sop.fn instanceof Equals || sop.fn instanceof NotEquals);
    }

    /**
     * Applies a relational scalar operation to an overlapping compressed matrix.
     *
     * <p>The per-group minima and maxima are summed into {@code [minS, maxS]}. When the scalar constant lies
     * outside that interval, the result is an all-zero or all-one matrix and is built without decompressing
     * {@code m1}. Otherwise {@code m1} is decompressed block-wise and the operator is evaluated cell by cell.
     *
     * <p>NOTE(review): the shortcut assumes every cell is bounded by the sum of the group minima and maxima,
     * i.e. each group's getMin()/getMax() is a valid bound for its contribution to every cell, including
     * columns it does not cover. Confirm this holds for all overlapping inputs.
     *
     * @param sop the relational scalar operator
     * @param m1  the overlapping compressed input
     * @return a 0/1 matrix with the same dimensions as {@code m1}
     */
    public static MatrixBlock overlappingRelativeRelationalOperation(ScalarOperator sop, CompressedMatrixBlock m1) {
        List<AColGroup> colGroups = m1.getColGroups();
        // Presumably LeftScalarOperator evaluates (constant op cell) and RightScalarOperator (cell op constant),
        // so 'less' marks the operations that are true when the cell is large relative to the constant.
        boolean less = (sop.fn instanceof LessThan || sop.fn instanceof LessThanEquals) && sop instanceof LeftScalarOperator || sop instanceof RightScalarOperator && (sop.fn instanceof GreaterThan || sop.fn instanceof GreaterThanEquals);
        double v = sop.getConstant();
        MinMaxGroup[] minMax = new MinMaxGroup[colGroups.size()];
        double maxS = 0.0;
        double minS = 0.0;
        int id = 0;
        for (AColGroup grp : colGroups) {
            double minG = grp.getMin();
            double maxG = grp.getMax();
            minS += minG;
            maxS += maxG;
            minMax[id++] = new MinMaxGroup(minG, maxG, grp);
        }
        if (v < minS || v > maxS) {
            // The constant lies outside every achievable cell value, so the result is constant.
            if (sop.fn instanceof Equals) {
                return CLALibRelationalOp.makeConstZero(m1.getNumRows(), m1.getNumColumns());
            }
            if (sop.fn instanceof NotEquals) {
                return CLALibRelationalOp.makeConstOne(m1.getNumRows(), m1.getNumColumns());
            }
            if (less) {
                // All cells exceed the constant iff it lies below the lower bound.
                if (v < minS || (sop.fn instanceof LessThanEquals || sop.fn instanceof GreaterThan) && v <= minS) {
                    return CLALibRelationalOp.makeConstOne(m1.getNumRows(), m1.getNumColumns());
                }
                return CLALibRelationalOp.makeConstZero(m1.getNumRows(), m1.getNumColumns());
            }
            // Inside this branch (and with minS <= maxS), 'v > minS' is equivalent to 'v > maxS':
            // all cells are below the constant.
            if (v > minS || (sop.fn instanceof LessThanEquals || sop.fn instanceof GreaterThan) && v >= minS) {
                return CLALibRelationalOp.makeConstOne(m1.getNumRows(), m1.getNumColumns());
            }
            return CLALibRelationalOp.makeConstZero(m1.getNumRows(), m1.getNumColumns());
        }
        return CLALibRelationalOp.processNonConstant(sop, minMax, minS, maxS, m1.getNumRows(), m1.getNumColumns(), less);
    }

    /**
     * Builds a {@code rows x cols} compressed matrix filled with 1.0, stored as one constant column group.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @return a non-overlapping compressed all-one matrix
     */
    private static MatrixBlock makeConstOne(int rows, int cols) {
        ArrayList<AColGroup> newColGroups = new ArrayList<>();
        int[] colIndexes = new int[cols];
        for (int i = 0; i < colIndexes.length; ++i) {
            colIndexes[i] = i;
        }
        double[] values = new double[cols];
        Arrays.fill(values, 1.0);
        newColGroups.add(new ColGroupConst(colIndexes, rows, new Dictionary(values)));
        CompressedMatrixBlock ret = new CompressedMatrixBlock(rows, cols);
        ret.allocateColGroupList(newColGroups);
        // Widen before multiplying: rows * cols can exceed Integer.MAX_VALUE.
        ret.setNonZeros((long) rows * cols);
        ret.setOverlapping(false);
        return ret;
    }

    /**
     * Builds an empty (all-zero) sparse {@code rows x cols} matrix.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @return an all-zero matrix with zero non-zeros
     */
    private static MatrixBlock makeConstZero(int rows, int cols) {
        return new MatrixBlock(rows, cols, true, 0L);
    }

    /**
     * Decompresses the column groups block-wise and applies {@code sop} to every cell.
     *
     * <p>{@code minS}, {@code maxS} and {@code less} are currently unused. They are kept for signature
     * compatibility.
     *
     * @param sop    the scalar operator applied to each decompressed cell
     * @param minMax the column groups (with bounds) to decompress and sum
     * @param minS   sum of the group minima (unused)
     * @param maxS   sum of the group maxima (unused)
     * @param rows   number of output rows
     * @param cols   number of output columns
     * @param less   direction flag computed by the caller (unused)
     * @return the sparse-allocated result matrix with its non-zero count set
     * @throws DMLRuntimeException if a parallel task fails or the calling thread is interrupted
     */
    private static MatrixBlock processNonConstant(ScalarOperator sop, MinMaxGroup[] minMax, double minS, double maxS, int rows, int cols, boolean less) {
        MatrixBlock res = new MatrixBlock(rows, cols, true, 0L).allocateBlock();
        if (rows == 0 || cols == 0) {
            // Nothing to evaluate. This also avoids a division by zero when sizing serial blocks.
            return res;
        }
        int k = OptimizerUtils.getConstrainedNumThreads(-1);
        if (k == 1) {
            processSerial(sop, minMax, res, rows, cols);
        } else {
            processParallel(sop, minMax, res, rows, cols, k);
        }
        // NOTE(review): this clears only the calling thread's entry. Scratch blocks cached by pool worker
        // threads in RelationalTask stay referenced until those threads terminate. Confirm whether
        // CommonThreadPool's shutdown actually ends its threads.
        memPool.remove();
        return res;
    }

    /** Single-threaded evaluation: decompresses row blocks into one reused scratch block. */
    private static void processSerial(ScalarOperator sop, MinMaxGroup[] minMax, MatrixBlock res, int rows, int cols) {
        // Clamp to >= 1: with cols > SERIAL_BLOCK_CELLS the integer division yields 0, which would never advance.
        final int blkz = Math.max(1, Math.min(rows, SERIAL_BLOCK_CELLS / cols));
        MatrixBlock tmp = new MatrixBlock(blkz, cols, false, -1L).allocateBlock();
        long nnz = 0L;
        for (int rl = 0; rl < rows; rl += blkz) {
            final int ru = Math.min(rl + blkz, rows);
            if (rl > 0) {
                // All groups decompress into the same tmp without clearing in between, so their contributions
                // are summed there. Clear the previous block's values first, as RelationalTask does.
                tmp.reset(blkz, cols, false, -1L);
            }
            for (MinMaxGroup mmg : minMax) {
                mmg.g.decompressToBlockUnSafe(tmp, rl, ru, 0);
            }
            for (int off = rl; off < ru; ++off) {
                final int row = off - rl;
                for (int col = 0; col < cols; ++col) {
                    final double val = sop.executeScalar(tmp.quickGetValue(row, col));
                    if (val != 0.0) {
                        // Rows and columns are visited in increasing order, so appending keeps rows sorted.
                        res.appendValue(off, col, val);
                        ++nnz;
                    }
                }
            }
        }
        tmp.reset();
        res.setNonZeros(nnz);
    }

    /** Multi-threaded evaluation: one {@link RelationalTask} per block of PARALLEL_BLOCK_ROWS rows. */
    private static void processParallel(ScalarOperator sop, MinMaxGroup[] minMax, MatrixBlock res, int rows, int cols, int k) {
        ExecutorService pool = CommonThreadPool.get(k);
        try {
            ArrayList<RelationalTask> tasks = new ArrayList<>();
            for (int i = 0; i * PARALLEL_BLOCK_ROWS < rows; ++i) {
                tasks.add(new RelationalTask(minMax, i, PARALLEL_BLOCK_ROWS, res, rows, cols, sop));
            }
            List<Future<Object>> futures = pool.invokeAll(tasks);
            for (Future<Object> f : futures) {
                f.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        } catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        } finally {
            pool.shutdown();
        }
        // Tasks call appendValue on the shared res from several threads. Its internal nnz bookkeeping (if any)
        // is presumably not thread-safe, and it was never set explicitly, so recount once all rows are written.
        res.recomputeNonZeros();
    }

    /** Decompresses one row block of all groups and writes sop(cell) into the shared result. */
    private static class RelationalTask
    implements Callable<Object> {
        private final MinMaxGroup[] _minMax;
        private final int _i;
        private final int _blkz;
        private final MatrixBlock _res;
        private final int _rows;
        private final int _cols;
        private final ScalarOperator _sop;

        protected RelationalTask(MinMaxGroup[] minMax, int i, int blkz, MatrixBlock res, int rows, int cols, ScalarOperator sop) {
            this._minMax = minMax;
            this._i = i;
            this._blkz = blkz;
            this._res = res;
            this._rows = rows;
            this._cols = cols;
            this._sop = sop;
        }

        @Override
        public Object call() {
            MatrixBlock tmp = memPool.get();
            if (tmp == null) {
                tmp = new MatrixBlock(this._blkz, this._cols, false, -1L).allocateBlock();
                memPool.set(tmp);
            } else {
                // Decompression sums into tmp, so clear values left by a previous task on this thread.
                tmp.reset(this._blkz, this._cols, false, -1L);
            }
            final int rl = this._i * this._blkz;
            for (MinMaxGroup mmg : this._minMax) {
                if (mmg.g.getNumberNonZeros() == 0L) {
                    continue; // an all-zero group contributes nothing
                }
                mmg.g.decompressToBlockUnSafe(tmp, rl, Math.min(rl + this._blkz, mmg.g.getNumRows()), 0);
            }
            final int ru = Math.min(rl + this._blkz, this._rows);
            for (int off = rl; off < ru; ++off) {
                final int row = off - rl;
                for (int col = 0; col < this._cols; ++col) {
                    this._res.appendValue(off, col, this._sop.executeScalar(tmp.quickGetValue(row, col)));
                }
            }
            return null;
        }
    }

    /** A column group paired with its minimum and maximum value. Ordered by the width (max - min). */
    protected static class MinMaxGroup
    implements Comparable<MinMaxGroup> {
        double min;
        double max;
        AColGroup g;
        // Cached from g.getValues(); not read anywhere in this file.
        double[] values;

        public MinMaxGroup(double min, double max, AColGroup g) {
            this.min = min;
            this.max = max;
            this.g = g;
            this.values = g.getValues();
        }

        @Override
        public int compareTo(MinMaxGroup o) {
            double t = this.max - this.min;
            double ot = o.max - o.min;
            return Double.compare(t, ot);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("MMG: ");
            sb.append("[" + this.min + "," + this.max + "]");
            sb.append(" " + this.g.getClass().getSimpleName());
            return sb.toString();
        }
    }
}

