/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.compile.linearization;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Stack;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.linearization.IDagLinearizer;

/**
 * Linearizes a lop DAG by enumerating candidate topological orders and selecting one of them.
 * <p>
 * Before enumeration the DAG is simplified:
 * <ul>
 * <li>input-less lops that are transient reads or scalars are detached as leaves;</li>
 * <li>output-less transient writes are detached as roots.</li>
 * </ul>
 * After an order for the remaining lops is chosen, the detached leaves are prepended, the
 * detached roots are appended, and the original input/output lists are restored.
 * <p>
 * NOTE(review): {@link #generateOrders} enumerates every topological order of the simplified
 * DAG. The number of orders can grow factorially with the number of independent lops (e.g. n
 * unrelated leaves yield n! orders). Also, the per-order cost fields are tracked but not yet
 * used for selection.
 */
public class LinearizerCostBased
extends IDagLinearizer {
    /**
     * Returns a linearized order of the given lops.
     * <p>
     * Side effect: {@code v} is modified in place. The detached leaves and roots are removed
     * from it during simplification.
     *
     * @param v the lops of the DAG to linearize
     * @return a list containing the remaining lops in the selected order. Detached leaves are
     *         placed in front and detached roots at the end.
     */
    @Override
    public List<Lop> linearize(List<Lop> v) {
        ArrayList<Lop> removedLeaves = new ArrayList<>();
        ArrayList<Lop> removedRoots = new ArrayList<>();
        HashMap<Long, ArrayList<Lop>> removedInputs = new HashMap<>();
        HashMap<Long, ArrayList<Lop>> removedOutputs = new HashMap<>();
        simplifyDag(v, removedLeaves, removedRoots, removedInputs, removedOutputs);

        // Leaves of the *simplified* DAG: lops whose only inputs were detached become leaves too.
        List<Lop> leafNodes = v.stream().filter(l -> l.getInputs().isEmpty()).collect(Collectors.toList());

        List<Order> finalOrders = new ArrayList<>();
        for (Lop leaf : leafNodes) {
            generateOrders(leaf, leafNodes, finalOrders, v.size());
        }

        List<Lop> best;
        if (finalOrders.isEmpty()) {
            // Nothing remained after simplification (e.g. a block consisting only of transient
            // reads/scalars feeding transient writes), or no complete order was found.
            // Keep the remaining lops in their given order instead of indexing an empty list.
            best = new ArrayList<>(v);
        }
        else {
            // TODO: select by cost (evictions / compute cost) instead of uniformly at random.
            int randInd = (int) (Math.random() * finalOrders.size());
            best = finalOrders.get(randInd).getOrder();
        }
        addRemovedNodes(best, removedLeaves, removedRoots, removedInputs, removedOutputs);
        return best;
    }

    /**
     * Enumerates all complete orders that start with {@code leaf} and appends them to
     * {@code finalOrders}.
     * <p>
     * From each partial order, the search branches on two kinds of extension:
     * <ul>
     * <li>every not-yet-linearized output whose inputs are all linearized;</li>
     * <li>every leaf not yet in the order.</li>
     * </ul>
     *
     * @param leaf        the first lop of every generated order
     * @param leafNodes   all leaves of the simplified DAG
     * @param finalOrders sink for complete orders
     * @param count       number of lops a complete order must contain
     */
    private static void generateOrders(Lop leaf, List<Lop> leafNodes, List<Order> finalOrders, int count) {
        Deque<Order> stack = new ArrayDeque<>();
        stack.push(new Order(leaf));
        while (!stack.isEmpty()) {
            Order partialOrder = stack.pop();
            if (partialOrder.size() == count) {
                finalOrders.add(partialOrder);
                continue;
            }
            // Collect each ready output once. The same output can be reached from several of its
            // linearized inputs. A local list is used instead of the Lop visit flags so this
            // search neither depends on nor clobbers traversal state owned by other passes.
            List<Lop> distinctOutputs = new ArrayList<>();
            for (Lop lop : partialOrder.getOrder()) {
                for (Lop out : lop.getOutputs()) {
                    if (!distinctOutputs.contains(out) && !partialOrder.contains(out)
                        && allInputsLinearized(out, partialOrder)) {
                        distinctOutputs.add(out);
                    }
                }
            }
            for (Lop out : distinctOutputs) {
                stack.push(copyAndAdd(partialOrder, out, true));
            }
            for (Lop otherLeaf : leafNodes) {
                if (!partialOrder.contains(otherLeaf)) {
                    stack.push(copyAndAdd(partialOrder, otherLeaf, false));
                }
            }
        }
    }

    /** Returns true if every input of {@code lop} already appears in {@code partialOrder}. */
    private static boolean allInputsLinearized(Lop lop, Order partialOrder) {
        List<Lop> order = partialOrder.getOrder();
        for (Lop input : lop.getInputs()) {
            if (!order.contains(input)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns a copy of {@code partialOrder} extended by {@code node}. The original order is
     * left unchanged.
     */
    private static Order copyAndAdd(Order partialOrder, Lop node, boolean allInputsLinearized) {
        Order newEntry = new Order(partialOrder);
        newEntry.addOperator(node, allInputsLinearized);
        return newEntry;
    }

    /**
     * Detaches leaves and roots from the DAG and removes them from {@code lops}.
     * <ul>
     * <li>Leaves: input-less lops that are transient reads or scalars.</li>
     * <li>Roots: output-less transient writes.</li>
     * </ul>
     * Before an adjacent lop's input (or output) list is modified for the first time, a
     * snapshot of it is stored in {@code removedInputs} (or {@code removedOutputs}), keyed by
     * lop ID.
     */
    private static void simplifyDag(List<Lop> lops, List<Lop> removedLeaves, List<Lop> removedRoots, HashMap<Long, ArrayList<Lop>> removedInputs, HashMap<Long, ArrayList<Lop>> removedOutputs) {
        for (Lop lop : lops) {
            boolean isTransientRead = lop instanceof Data && ((Data) lop).isTransientRead();
            boolean isScalar = lop.getDataType() == Types.DataType.SCALAR;
            if (lop.getInputs().isEmpty() && (isTransientRead || isScalar)) {
                removedLeaves.add(lop);
                for (Lop out : lop.getOutputs()) {
                    // Snapshot only the first time, i.e. before any input of 'out' was removed.
                    removedInputs.computeIfAbsent(out.getID(), id -> new ArrayList<>(out.getInputs()));
                    out.removeInput(lop);
                }
            }
            boolean isTransientWrite = lop instanceof Data && ((Data) lop).isTransientWrite();
            if (lop.getOutputs().isEmpty() && isTransientWrite) {
                removedRoots.add(lop);
                for (Lop in : lop.getInputs()) {
                    removedOutputs.computeIfAbsent(in.getID(), id -> new ArrayList<>(in.getOutputs()));
                    in.removeOutput(lop);
                }
            }
        }
        lops.removeAll(removedLeaves);
        lops.removeAll(removedRoots);
    }

    /**
     * Re-attaches the lops detached by {@link #simplifyDag}.
     * <p>
     * Restores the snapshotted input/output lists, prepends the removed leaves to {@code lops},
     * and appends the removed roots.
     */
    private static void addRemovedNodes(List<Lop> lops, List<Lop> removedLeaves, List<Lop> removedRoots, HashMap<Long, ArrayList<Lop>> removedInputs, HashMap<Long, ArrayList<Lop>> removedOutputs) {
        for (Lop leaf : removedLeaves) {
            for (Lop out : leaf.getOutputs()) {
                ArrayList<Lop> savedInputs = removedInputs.get(out.getID());
                // No snapshot means this input list was never modified; keep it as is.
                if (savedInputs != null) {
                    out.replaceAllInputs(savedInputs);
                }
            }
        }
        lops.addAll(0, removedLeaves);
        for (Lop root : removedRoots) {
            for (Lop in : root.getInputs()) {
                ArrayList<Lop> savedOutputs = removedOutputs.get(in.getID());
                // A removed leaf feeding a removed root directly has no output snapshot. The root
                // had already lost that input when it was detached, so the leaf's output list was
                // never modified. Restoring from a missing entry would install a null list.
                if (savedOutputs != null) {
                    in.replaceAllOutputs(savedOutputs);
                }
            }
        }
        lops.addAll(removedRoots);
    }

    /**
     * A (partial) linearization together with running cost estimates that are updated as lops
     * are appended.
     */
    private static class Order {
        /** The lops of this order, in order. Owned by this instance; never shared between orders. */
        private final List<Lop> _order;
        /** Total memory estimate of the most recently appended lop. */
        private double _pinnedMemEstimate;
        /** Running buffer-pool estimate. Starts at 0 for the first lop and is clamped at 0. */
        private double _bufferpoolEstimate;
        /** Number of appends after which the buffer-pool estimate exceeded the buffer-pool limit. */
        private int _numEvictions;
        /** Sum of the compute estimates of all lops in this order. */
        private double _computeCost;

        public Order(List<Lop> lops, double pin, double bp, double comp) {
            this._order = new ArrayList<>(lops);
            this._pinnedMemEstimate = pin;
            this._bufferpoolEstimate = bp;
            this._numEvictions = 0;
            this._computeCost = comp;
        }

        /** Creates a single-lop order. */
        public Order(Lop lop) {
            this(Arrays.asList(lop), lop.getOutputMemoryEstimate(), 0.0, lop.getComputeEstimate());
        }

        /**
         * Copy constructor. The lop list is copied so that extending the copy does not mutate
         * {@code that}.
         */
        public Order(Order that) {
            this._order = new ArrayList<>(that._order);
            this._pinnedMemEstimate = that._pinnedMemEstimate;
            this._bufferpoolEstimate = that._bufferpoolEstimate;
            this._numEvictions = that._numEvictions;
            this._computeCost = that._computeCost;
        }

        /**
         * Appends {@code lop} and updates the running estimates.
         *
         * @param lop                 the lop to append
         * @param allInputsLinearized if true, the output memory estimates of {@code lop}'s inputs
         *                            are subtracted from the buffer-pool estimate
         */
        public void addOperator(Lop lop, boolean allInputsLinearized) {
            _order.add(lop);
            _computeCost += lop.getComputeEstimate();
            _bufferpoolEstimate += lop.getOutputMemoryEstimate();
            if (allInputsLinearized) {
                for (Lop in : lop.getInputs()) {
                    _bufferpoolEstimate -= in.getOutputMemoryEstimate();
                }
                _bufferpoolEstimate = Math.max(0.0, _bufferpoolEstimate);
            }
            if (_bufferpoolEstimate > (double) OptimizerUtils.getBufferPoolLimit()) {
                ++_numEvictions;
            }
            _pinnedMemEstimate = lop.getTotalMemoryEstimate();
        }

        /** Returns the live (mutable) lop list of this order. */
        protected List<Lop> getOrder() {
            return _order;
        }

        protected double getComputeCost() {
            return _computeCost;
        }

        protected boolean contains(Lop lop) {
            return _order.contains(lop);
        }

        protected int size() {
            return _order.size();
        }
    }
}

