/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
import org.apache.sysds.runtime.compress.CompressionStatistics;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.cocode.CoCoderFactory;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder;
import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
import org.apache.sysds.runtime.compress.cost.ICostEstimate;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorFactory;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.compress.utils.DblArrayIntListHashMap;
import org.apache.sysds.runtime.compress.utils.DoubleCountHashMap;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.DMLCompressionStatistics;

/**
 * Factory that compresses an uncompressed {@link MatrixBlock} into a {@link CompressedMatrixBlock}.
 * <p>
 * Compression runs as a sequence of numbered phases whose timings are reported through
 * {@link DMLCompressionStatistics#addCompressionTime}: 0 classify, 1 co-code (grouping), 2 transpose,
 * 3 compress and 4 cleanup. Whenever the process concludes that compression does not pay off, the input
 * block (transposed back if it was transposed) is returned together with the statistics gathered so far.
 */
public class CompressedMatrixBlockFactory {
    private static final Log LOG = LogFactory.getLog(CompressedMatrixBlockFactory.class.getName());

    /** Number of consecutive columns placed in one group when the input looks one-hot encoded. */
    private static final int ONE_HOT_COL_GROUP_SIZE = 100;

    /** Maximum number of leading dense values inspected by the one-hot heuristic. */
    private static final int ONE_HOT_DENSE_SAMPLE_SIZE = 1000;

    /** Number of the last phase (cleanup) that is reported through {@link #logPhase()}. */
    private static final int LAST_PHASE = 4;

    private final Timing time = new Timing(true);
    private final CompressionStatistics _stats = new CompressionStatistics();
    private final int k;
    private final CompressionSettings compSettings;
    private final ICostEstimate costEstimator;
    private double lastPhase;
    private MatrixBlock mb;
    private CompressedMatrixBlock res;
    private int phase = 0;
    private CompressedSizeInfo compressionGroups;

    private CompressedMatrixBlockFactory(MatrixBlock mb, int k, CompressionSettingsBuilder compSettings, ICostEstimate costEstimator) {
        this(mb, k, compSettings.create(), costEstimator);
    }

    private CompressedMatrixBlockFactory(MatrixBlock mb, int k, CompressionSettings compSettings, ICostEstimate costEstimator) {
        this.mb = mb;
        this.k = k;
        this.compSettings = compSettings;
        this.costEstimator = costEstimator;
    }

    /** Compresses {@code mb} single-threaded with default settings and no workload tree. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb) {
        return compress(mb, 1, new CompressionSettingsBuilder(), (WTreeRoot) null);
    }

    /** Compresses {@code mb} single-threaded with default settings, costed against the given workload tree. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, WTreeRoot root) {
        return compress(mb, 1, new CompressionSettingsBuilder(), root);
    }

    /** Compresses {@code mb} single-threaded with the given settings and no workload tree. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, CompressionSettingsBuilder customSettings) {
        return compress(mb, 1, customSettings, (WTreeRoot) null);
    }

    /** Compresses {@code mb} with parallelism {@code k}, default settings and no workload tree. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k) {
        return compress(mb, k, new CompressionSettingsBuilder(), (WTreeRoot) null);
    }

    /** Compresses {@code mb} with parallelism {@code k} and default settings, costed against the workload tree. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, WTreeRoot root) {
        return compress(mb, k, new CompressionSettingsBuilder(), root);
    }

    /** Compresses {@code mb} with parallelism {@code k}, default settings and an explicit cost estimator. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, ICostEstimate costEstimator) {
        return compress(mb, k, new CompressionSettingsBuilder(), costEstimator);
    }

    /** Compresses {@code mb} with parallelism {@code k}, the given settings and no workload tree. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CompressionSettingsBuilder compSettings) {
        return compress(mb, k, compSettings, (WTreeRoot) null);
    }

    /**
     * Compresses {@code mb}, building the cost estimator from the optional workload tree.
     *
     * @param mb           the block to compress
     * @param k            degree of parallelism
     * @param compSettings builder for the compression settings
     * @param root         workload tree, or {@code null} to let the cost estimator factory work without one
     * @return the (possibly) compressed block and the collected statistics
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CompressionSettingsBuilder compSettings, WTreeRoot root) {
        final CostEstimatorBuilder csb = root == null ? null : new CostEstimatorBuilder(root);
        return compress(mb, k, compSettings, csb);
    }

    /**
     * Compresses {@code mb}, creating the cost estimator from the given (possibly {@code null}) builder.
     *
     * @param mb           the block to compress
     * @param k            degree of parallelism
     * @param compSettings builder for the compression settings
     * @param csb          cost estimator builder, may be {@code null}
     * @return the (possibly) compressed block and the collected statistics
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CompressionSettingsBuilder compSettings, CostEstimatorBuilder csb) {
        final CompressionSettings cs = compSettings.create();
        final ICostEstimate ice = CostEstimatorFactory.create(cs, csb, mb.getNumRows(), mb.getNumColumns(), mb.getSparsity());
        return new CompressedMatrixBlockFactory(mb, k, cs, ice).compressMatrix();
    }

    /**
     * Compresses {@code mb} using an already constructed cost estimator.
     *
     * @param mb            the block to compress
     * @param k             degree of parallelism
     * @param compSettings  builder for the compression settings
     * @param costEstimator the cost estimator guiding grouping and abort decisions
     * @return the (possibly) compressed block and the collected statistics
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CompressionSettingsBuilder compSettings, ICostEstimate costEstimator) {
        return new CompressedMatrixBlockFactory(mb, k, compSettings, costEstimator).compressMatrix();
    }

    /**
     * Wraps {@code mb} in a compressed block holding a single uncompressed column group.
     *
     * @param mb the block to wrap
     * @return a compressed block with the same dimensions and non-zero count as {@code mb}
     */
    public static CompressedMatrixBlock genUncompressedCompressedMatrixBlock(MatrixBlock mb) {
        final CompressedMatrixBlock ret = new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns());
        final ColGroupUncompressed cg = new ColGroupUncompressed(mb);
        ret.allocateColGroup(cg);
        ret.setNonZeros(mb.getNonZeros());
        return ret;
    }

    /**
     * Creates a compressed block in which every cell holds {@code value}.
     *
     * @param numRows number of rows, must be positive
     * @param numCols number of columns, must be positive
     * @param value   the constant cell value
     * @return the constant compressed block
     * @throws DMLCompressionException if either dimension is not positive
     */
    public static CompressedMatrixBlock createConstant(int numRows, int numCols, double value) {
        // Validate before allocating so an invalid size fails fast with a compression-specific error.
        if (numRows <= 0 || numCols <= 0) {
            throw new DMLCompressionException("Invalid size of allocated constant compressed matrix block");
        }
        final CompressedMatrixBlock block = new CompressedMatrixBlock(numRows, numCols);
        final AColGroup cg = ColGroupFactory.genColGroupConst(numCols, value);
        block.allocateColGroup(cg);
        block.recomputeNonZeros();
        return block;
    }

    /**
     * Runs all compression phases on {@link #mb}.
     *
     * @return the compressed block and statistics; an already compressed or empty input is returned with
     *         {@code null} statistics, and an aborted compression returns the input with partial statistics
     */
    private Pair<MatrixBlock, CompressionStatistics> compressMatrix() {
        if (this.mb instanceof CompressedMatrixBlock) {
            LOG.info("MatrixBlock already compressed or is Empty");
            return new ImmutablePair<MatrixBlock, CompressionStatistics>(this.mb, null);
        }
        if (this.mb.isEmpty()) {
            LOG.info("Empty input to compress, returning a compressed Matrix block with empty column group");
            final CompressedMatrixBlock ret = new CompressedMatrixBlock(this.mb.getNumRows(), this.mb.getNumColumns());
            final ColGroupEmpty cg = ColGroupEmpty.generate(this.mb.getNumColumns());
            ret.allocateColGroup(cg);
            ret.setNonZeros(0L);
            return new ImmutablePair<MatrixBlock, CompressionStatistics>(ret, null);
        }
        this._stats.denseSize = MatrixBlock.estimateSizeInMemory(this.mb.getNumRows(), this.mb.getNumColumns(), 1.0);
        this._stats.originalSize = this.mb.getInMemorySize();
        this.res = new CompressedMatrixBlock(this.mb);

        // The one-hot shortcut may pre-populate the column groups and skip classification/co-coding.
        this.looksLikeOneHot();
        if (this.compressionGroups == null) {
            this.classifyPhase();
            if (this.compressionGroups == null) {
                return this.abortCompression();
            }
        }
        this.transposePhase();
        this.compressPhase();
        this.finalizePhase();
        // finalizePhase clears res when the achieved compression ratio is too poor.
        if (this.res == null) {
            return this.abortCompression();
        }
        if (this.compSettings.isInSparkInstruction) {
            this.res.clearSoftReferenceToDecompressed();
        }
        return new ImmutablePair<MatrixBlock, CompressionStatistics>(this.res, this._stats);
    }

    /**
     * Phase 0: estimates the compressed size of every column in isolation and, if compression looks
     * worthwhile, continues with the co-coding phase (phase 1).
     */
    private void classifyPhase() {
        final CompressedSizeEstimator sizeEstimator = CompressedSizeEstimatorFactory.getSizeEstimator(this.mb, this.compSettings, this.k);
        // When the settings request transposed processing, continue with the estimator's copy of the data.
        if (this.compSettings.transposed) {
            this.mb = sizeEstimator.getData();
        }
        this.compressionGroups = sizeEstimator.computeCompressedSizeInfos(this.k);
        this._stats.estimatedSizeCols = this.compressionGroups.memoryEstimate();
        this.logPhase();

        // A dense computation cost model compares against the dense size, otherwise the actual input size.
        final boolean denseComputeModel = this.costEstimator instanceof ComputationCostEstimator
            && ((ComputationCostEstimator) this.costEstimator).isDense();
        final double sizeToCompare = denseComputeModel ? (double) this._stats.denseSize : (double) this._stats.originalSize;
        final boolean estimateBeatsInput = (double) this._stats.estimatedSizeCols * this.compSettings.minimumCompressionRatio < sizeToCompare;
        // Only a compute-based estimator with a minimum ratio other than 1.0 has to pass the size test here;
        // in every other configuration co-coding always proceeds (later phases may still abort).
        final boolean mustMeetRatio = this.isComputeBasedCompression() && this.compSettings.minimumCompressionRatio != 1.0;

        if (!mustMeetRatio || estimateBeatsInput) {
            final int nCols = this.compSettings.transposed ? this.mb.getNumRows() : this.mb.getNumColumns();
            if (nCols > 1) {
                this.coCodePhase(sizeEstimator, this.costEstimator, sizeToCompare);
            } else {
                // A single column cannot be co-coded; close phase 1 so later phases keep their numbers.
                this.logPhase();
            }
        } else {
            LOG.info("Estimated Size of singleColGroups: " + this._stats.estimatedSizeCols);
            LOG.info("Original size                    : " + this._stats.originalSize);
            // Co-coding is skipped; still close phase 1 so transpose/compress/cleanup are timed and logged
            // under their own phase numbers.
            this.logPhase();
        }
    }

    /** @return true if the cost estimator models computation cost rather than memory size */
    private boolean isComputeBasedCompression() {
        return this.costEstimator instanceof ComputationCostEstimator;
    }

    /**
     * Phase 1: groups columns that should be compressed together.
     * <p>
     * For compute-based estimation, compression is aborted (column groups set to {@code null}) when the
     * co-coded estimate scaled by the minimum compression ratio exceeds {@code sizeToCompare}.
     *
     * @param sizeEstimator  estimator used to evaluate candidate groupings
     * @param costEstimator  cost model guiding the grouping
     * @param sizeToCompare  reference size the estimate must beat
     */
    private void coCodePhase(CompressedSizeEstimator sizeEstimator, ICostEstimate costEstimator, double sizeToCompare) {
        this.compressionGroups = CoCoderFactory.findCoCodesByPartitioning(sizeEstimator, this.compressionGroups, this.k, costEstimator, this.compSettings);
        this._stats.estimatedSizeCoCoded = this.compressionGroups.memoryEstimate();
        this.logPhase();
        if (this.isComputeBasedCompression()
            && (double) this._stats.estimatedSizeCoCoded * this.compSettings.minimumCompressionRatio > sizeToCompare) {
            this.compressionGroups = null;
            LOG.info("Aborting compression because the cocoded size : " + this._stats.estimatedSizeCoCoded);
            LOG.info("Vs original size                              : " + this._stats.originalSize);
        }
    }

    /**
     * Heuristic shortcut for inputs that look one-hot encoded: one non-zero per row on average and the first
     * row (sparse) or the first values (dense) consisting only of 1s (and 0s for dense). If it matches, the
     * columns are split into fixed-size consecutive groups and the classify and co-code phases are skipped.
     */
    private void looksLikeOneHot() {
        final int numColumns = this.mb.getNumColumns();
        final int numRows = this.mb.getNumRows();
        final long nnz = this.mb.getNonZeros();
        if (nnz != numRows || numColumns == 1) {
            return;
        }
        LOG.debug("Looks like one hot encoded.");
        if (!this.leadingValuesLookOneHot()) {
            return;
        }
        final ArrayList<CompressedSizeInfoColGroup> groups = new ArrayList<CompressedSizeInfoColGroup>(numColumns / ONE_HOT_COL_GROUP_SIZE + 1);
        for (int start = 0; start < numColumns; start += ONE_HOT_COL_GROUP_SIZE) {
            final int[] columnIds = new int[Math.min(ONE_HOT_COL_GROUP_SIZE, numColumns - start)];
            for (int j = 0; j < columnIds.length; ++j) {
                columnIds[j] = start + j;
            }
            // NOTE(review): the second argument is identical for every group, including a shorter trailing
            // group — confirm that this is intended.
            groups.add(new CompressedSizeInfoColGroup(columnIds, Math.min(numColumns, ONE_HOT_COL_GROUP_SIZE), numRows));
        }
        this.compressionGroups = new CompressedSizeInfo(groups);
        LOG.debug("Concluded that it probably is one hot encoded skipping analysis");
        // Skip phase numbers 0 (classify) and 1 (co-code) so later phases keep their numbers.
        this.phase += 2;
    }

    /**
     * Inspects the start of {@link #mb} for one-hot-like values.
     *
     * @return for sparse input, true if the first row is non-empty and all its stored values are 1; for dense
     *         input, true if the first (up to {@value #ONE_HOT_DENSE_SAMPLE_SIZE}) values are all 0 or 1
     */
    private boolean leadingValuesLookOneHot() {
        if (this.mb.isInSparseFormat()) {
            final SparseBlock sb = this.mb.getSparseBlock();
            // With nnz == numRows an empty first row means another row holds several non-zeros, so the
            // input cannot be one-hot; checking first also avoids dereferencing a missing row.
            if (sb == null || sb.isEmpty(0)) {
                return false;
            }
            // Only the [pos, pos + size) slice of the values array belongs to row 0.
            final int pos = sb.pos(0);
            final int end = pos + sb.size(0);
            final double[] rowValues = sb.values(0);
            for (int i = pos; i < end; ++i) {
                if (rowValues[i] != 1.0) {
                    return false;
                }
            }
            return true;
        }
        final double[] vals = this.mb.getDenseBlock().values(0);
        final int limit = Math.min(vals.length, ONE_HOT_DENSE_SAMPLE_SIZE);
        for (int i = 0; i < limit; ++i) {
            final double v = vals[i];
            if (v != 1.0 && v != 0.0) {
                return false;
            }
        }
        return true;
    }

    /** Phase 2: decides whether to transpose the input (if not already transposed) and does so. */
    private void transposePhase() {
        if (!this.compSettings.transposed) {
            this.transposeHeuristics();
            if (this.compSettings.transposed) {
                final boolean sparse = this.mb.isInSparseFormat();
                this.mb = LibMatrixReorg.transpose(this.mb, new MatrixBlock(this.mb.getNumColumns(), this.mb.getNumRows(), sparse), this.k, true);
            }
        }
        this.logPhase();
    }

    /**
     * Sets {@code compSettings.transposed} from the {@code transposeInput} setting: {@code "true"} and
     * {@code "false"} force the decision; any other value transposes only sparse inputs that have many
     * columns, are very sparse with few non-zeros, or have many rows combined with many column groups.
     */
    private void transposeHeuristics() {
        switch (this.compSettings.transposeInput) {
            case "true":
                this.compSettings.transposed = true;
                break;
            case "false":
                this.compSettings.transposed = false;
                break;
            default:
                if (this.mb.isInSparseFormat()) {
                    final boolean haveManyColumns = this.mb.getNumColumns() > 10000;
                    final boolean isNnzLowAndVerySparse = this.mb.getNonZeros() < 1000L && this.mb.getSparsity() < 0.4;
                    final boolean isAboveRowNumbers = this.mb.getNumRows() > 500000;
                    final boolean isAboveThreadToColumnRatio = this.compressionGroups.getNumberColGroups() > this.mb.getNumColumns() / 4;
                    this.compSettings.transposed = haveManyColumns || isNnzLowAndVerySparse || isAboveRowNumbers && isAboveThreadToColumnRatio;
                } else {
                    this.compSettings.transposed = false;
                }
                break;
        }
    }

    /** Phase 3: compresses the planned column groups into {@link #res}. */
    private void compressPhase() {
        this.res.allocateColGroupList(ColGroupFactory.compressColGroups(this.mb, this.compressionGroups, this.compSettings, this.k));
        this._stats.compressedInitialSize = this.res.getInMemorySize();
        this.logPhase();
    }

    /**
     * Phase 4: merges constant columns, cleans up the block and checks the achieved ratio. If the ratio is
     * below 1 while the dense ratio is below 100, {@link #res} is cleared to signal an abort.
     */
    private void finalizePhase() {
        CLALibUtils.combineConstColumns(this.res);
        this.res.cleanupBlock(true, true);
        this._stats.size = this.res.getInMemorySize();
        final double ratio = this._stats.getRatio();
        final double denseRatio = this._stats.getDenseRatio();
        if (ratio < 1.0 && denseRatio < 100.0) {
            LOG.info("--dense size:        " + this._stats.denseSize);
            LOG.info("--original size:     " + this._stats.originalSize);
            LOG.info("--compressed size:   " + this._stats.size);
            LOG.info("--compression ratio: " + ratio);
            LOG.info("Abort block compression because compression ratio is less than 1.");
            this.res = null;
            this.setNextTimePhase(this.time.stop());
            DMLCompressionStatistics.addCompressionTime(this.getLastTimePhase(), this.phase);
            return;
        }
        this._stats.setColGroupsCounts(this.res.getColGroups());
        final long oldNNZ = this.mb.getNonZeros();
        // A non-positive input count is carried over; otherwise the count is recomputed on the result.
        if (oldNNZ <= 0L) {
            this.res.setNonZeros(oldNNZ);
        } else {
            this.res.recomputeNonZeros();
        }
        this.logPhase();
    }

    /**
     * Gives up on compression, transposing the input back if it was transposed.
     *
     * @return the input block and the statistics gathered so far
     */
    private Pair<MatrixBlock, CompressionStatistics> abortCompression() {
        LOG.warn("Compression aborted at phase: " + this.phase);
        if (this.compSettings.transposed) {
            LibMatrixReorg.transposeInPlace(this.mb, this.k);
        }
        return new ImmutablePair<MatrixBlock, CompressionStatistics>(this.mb, this._stats);
    }

    /**
     * Records the elapsed time of the current phase, emits phase-specific debug output and advances the
     * phase counter.
     */
    private void logPhase() {
        this.setNextTimePhase(this.time.stop());
        DMLCompressionStatistics.addCompressionTime(this.getLastTimePhase(), this.phase);
        if (LOG.isDebugEnabled()) {
            if (this.compSettings.isInSparkInstruction) {
                // Inside Spark only a single summary is logged, after the last phase.
                if (this.phase == LAST_PHASE) {
                    LOG.debug(this._stats);
                }
            } else {
                this.logPhaseDetails();
            }
        }
        ++this.phase;
    }

    /** Emits the detailed debug output for the current phase (non-Spark execution). */
    private void logPhaseDetails() {
        switch (this.phase) {
            case 0:
                LOG.debug("--compression phase " + this.phase + " Classify  : " + this.getLastTimePhase());
                LOG.debug("--Individual Columns Estimated Compression: " + this._stats.estimatedSizeCols);
                break;
            case 1:
                LOG.debug("--compression phase " + this.phase + " Grouping  : " + this.getLastTimePhase());
                LOG.debug("Grouping using: " + this.compSettings.columnPartitioner);
                LOG.debug("Cost Calculated using: " + this.costEstimator);
                LOG.debug("--Cocoded Columns estimated Compression:" + this._stats.estimatedSizeCoCoded);
                if (this.compressionGroups.getInfo().size() < 1000) {
                    LOG.debug("--Cocoded Columns estimated nr distinct:" + this.compressionGroups.getEstimatedDistinct());
                    LOG.debug("--Cocoded Columns nr columns           :" + this.compressionGroups.getNrColumnsString());
                } else {
                    LOG.debug("--CoCoded produce many columns but the first says:\n" + this.compressionGroups.getInfo().get(0));
                }
                break;
            case 2:
                LOG.debug("--compression phase " + this.phase + " Transpose : " + this.getLastTimePhase());
                LOG.debug("Did transpose: " + this.compSettings.transposed);
                break;
            case 3:
                LOG.debug("--compression phase " + this.phase + " Compress  : " + this.getLastTimePhase());
                LOG.debug("--compression Hash collisions:(" + DblArrayIntListHashMap.hashMissCount + "," + DoubleCountHashMap.hashMissCount + ")");
                // Reset the global miss counters so the next compression reports its own collisions.
                DblArrayIntListHashMap.hashMissCount = 0;
                DoubleCountHashMap.hashMissCount = 0;
                LOG.debug("--compressed initial actual size:" + this._stats.compressedInitialSize);
                break;
            case 4:
                this.logCleanupPhase();
                break;
            default:
                break;
        }
    }

    /** Emits the debug (and optional trace) output of the cleanup phase. */
    private void logCleanupPhase() {
        final List<AColGroup> groups = this.res.getColGroups();
        LOG.debug("--num col groups: " + groups.size());
        LOG.debug("--compression phase " + this.phase + " Cleanup   : " + this.getLastTimePhase());
        LOG.debug("--col groups types " + this._stats.getGroupsTypesString());
        LOG.debug("--col groups sizes " + this._stats.getGroupsSizesString());
        LOG.debug("--dense size:        " + this._stats.denseSize);
        LOG.debug("--original size:     " + this._stats.originalSize);
        LOG.debug("--compressed size:   " + this._stats.size);
        LOG.debug("--compression ratio: " + this._stats.getRatio());
        LOG.debug("--Dense       ratio: " + this._stats.getDenseRatio());
        if (this.compressionGroups.getInfo().size() < 1000) {
            final int[] lengths = new int[groups.size()];
            int i = 0;
            for (AColGroup colGroup : groups) {
                lengths[i++] = colGroup.getNumValues();
            }
            LOG.debug("--compressed colGroup dictionary sizes: " + Arrays.toString(lengths));
            LOG.debug("--compressed colGroup nr columns      : " + constructNrColumnString(groups));
        }
        if (LOG.isTraceEnabled()) {
            for (AColGroup colGroup : groups) {
                // Small groups are printed in full; large ones only as a one-line summary.
                if (colGroup.estimateInMemorySize() < 1000L) {
                    LOG.trace(colGroup);
                } else {
                    LOG.trace("--colGroups type       : " + colGroup.getClass().getSimpleName() + " size: " + colGroup.estimateInMemorySize()
                        + (colGroup instanceof ColGroupValue ? "  numValues :" + ((ColGroupValue) colGroup).getNumValues() : "")
                        + "  colIndexes : " + Arrays.toString(colGroup.getColIndices()));
                }
            }
        }
    }

    private void setNextTimePhase(double time) {
        this.lastPhase = time;
    }

    private double getLastTimePhase() {
        return this.lastPhase;
    }

    /**
     * Formats the column counts of the given groups as {@code [a, b, c]}.
     *
     * @param cg column groups to describe
     * @return the bracketed, comma separated column counts; {@code []} for an empty list
     */
    private static String constructNrColumnString(List<AColGroup> cg) {
        final StringBuilder sb = new StringBuilder("[");
        for (int id = 0; id < cg.size(); ++id) {
            if (id > 0) {
                sb.append(", ");
            }
            sb.append(cg.get(id).getNumCols());
        }
        return sb.append(']').toString();
    }
}

