/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen.template;

import java.util.ArrayList;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysds.hops.codegen.cplan.CNodeOuterProduct;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.template.TemplateUtils;

/**
 * Applies simplification rewrites to code generation plans (cplans).
 *
 * <p>The rewrites replace generic operator patterns with specialized unary
 * operators, e.g. a binary power with literal exponent {@code 2} becomes a
 * {@code POW2} unary node. All rewrites match literals by comparing the
 * literal node's variable name against a fixed string.
 */
public class CPlanOpRewriter {
    /**
     * Simplifies the given template in place.
     *
     * <p>First applies the template-specific rewrite for outer-product
     * templates, then simplifies every output expression: all outputs of a
     * {@link CNodeMultiAgg}, or the single output of any other template.
     *
     * @param tpl the template to simplify; its outputs are replaced in place
     * @return the same template instance that was passed in
     */
    public CNodeTpl simplifyCPlan(CNodeTpl tpl) {
        tpl = rewriteRemoveOuterNeq0(tpl);
        if (tpl instanceof CNodeMultiAgg) {
            ArrayList<CNode> outputs = ((CNodeMultiAgg) tpl).getOutputs();
            for (int i = 0; i < outputs.size(); i++) {
                outputs.set(i, rSimplifyCNode(outputs.get(i)));
            }
        } else {
            tpl.setOutput(rSimplifyCNode(tpl.getOutput()));
        }
        return tpl;
    }

    /**
     * Recursively simplifies the expression rooted at {@code node}.
     *
     * <p>Inputs are simplified first and replaced in the node's input list;
     * afterwards the node-local rewrites are applied to the node itself, each
     * one operating on the result of the previous one.
     *
     * @param node root of the expression to simplify
     * @return the simplified root, which may be a newly created node
     */
    private static CNode rSimplifyCNode(CNode node) {
        // bottom-up: children are rewritten before their parent is inspected
        for (int i = 0; i < node.getInput().size(); i++) {
            node.getInput().set(i, rSimplifyCNode(node.getInput().get(i)));
        }
        node = rewriteRowCountNnz(node);
        node = rewriteRowSumSq(node);
        node = rewriteBinaryPow2(node);
        node = rewriteBinaryPow2Vect(node);
        node = rewriteBinaryMult2(node);
        node = rewriteBinaryMult2Vect(node);
        return node;
    }

    /**
     * Checks for a binary node of the given type whose second input is a
     * literal with the given variable name.
     *
     * @param node node to inspect
     * @param type required binary operator type
     * @param lit  required variable name of the literal second input
     * @return true if all three conditions hold
     */
    private static boolean isBinaryWithLiteralRhs(CNode node, CNodeBinary.BinType type, String lit) {
        if (!TemplateUtils.isBinary(node, type)) {
            return false;
        }
        CNode rhs = node.getInput().get(1);
        return rhs.isLiteral() && rhs.getVarname().equals(lit);
    }

    /**
     * Rewrites {@code ROW_SUMS(VECT_NOTEQUAL_SCALAR(X, 0))} to
     * {@code ROW_COUNTNNZS(X)}; any other node is returned unchanged.
     */
    private static CNode rewriteRowCountNnz(CNode node) {
        if (!TemplateUtils.isUnary(node, CNodeUnary.UnaryType.ROW_SUMS)) {
            return node;
        }
        CNode cmp = node.getInput().get(0);
        return isBinaryWithLiteralRhs(cmp, CNodeBinary.BinType.VECT_NOTEQUAL_SCALAR, "0")
            ? new CNodeUnary(cmp.getInput().get(0), CNodeUnary.UnaryType.ROW_COUNTNNZS)
            : node;
    }

    /**
     * Rewrites a {@code ROW_SUMS} over a squared vector to
     * {@code ROW_SUMSQS(X)}; any other node is returned unchanged.
     *
     * <p>Two forms of the squared input are recognized: the binary
     * {@code VECT_POW_SCALAR(X, 2)} and the unary {@code VECT_POW2(X)}. The
     * unary form is required because {@link #rSimplifyCNode(CNode)} simplifies
     * inputs first, so by the time the {@code ROW_SUMS} parent is inspected
     * its {@code VECT_POW_SCALAR(X, 2)} input has already been replaced by
     * {@code VECT_POW2(X)}; matching only the binary form would leave this
     * rewrite without effect.
     */
    private static CNode rewriteRowSumSq(CNode node) {
        if (!TemplateUtils.isUnary(node, CNodeUnary.UnaryType.ROW_SUMS)) {
            return node;
        }
        CNode squared = node.getInput().get(0);
        boolean isSquare = TemplateUtils.isUnary(squared, CNodeUnary.UnaryType.VECT_POW2)
            || isBinaryWithLiteralRhs(squared, CNodeBinary.BinType.VECT_POW_SCALAR, "2");
        return isSquare
            ? new CNodeUnary(squared.getInput().get(0), CNodeUnary.UnaryType.ROW_SUMSQS)
            : node;
    }

    /**
     * Rewrites {@code POW(x, 2)} to the unary {@code POW2(x)}; any other node
     * is returned unchanged.
     */
    private static CNode rewriteBinaryPow2(CNode node) {
        return isBinaryWithLiteralRhs(node, CNodeBinary.BinType.POW, "2")
            ? new CNodeUnary(node.getInput().get(0), CNodeUnary.UnaryType.POW2)
            : node;
    }

    /**
     * Rewrites {@code VECT_POW_SCALAR(X, 2)} to the unary {@code VECT_POW2(X)};
     * any other node is returned unchanged.
     */
    private static CNode rewriteBinaryPow2Vect(CNode node) {
        return isBinaryWithLiteralRhs(node, CNodeBinary.BinType.VECT_POW_SCALAR, "2")
            ? new CNodeUnary(node.getInput().get(0), CNodeUnary.UnaryType.VECT_POW2)
            : node;
    }

    /**
     * Rewrites {@code MULT(x, 2)} to the unary {@code MULT2(x)}; any other
     * node is returned unchanged. Only a literal in second position matches.
     */
    private static CNode rewriteBinaryMult2(CNode node) {
        return isBinaryWithLiteralRhs(node, CNodeBinary.BinType.MULT, "2")
            ? new CNodeUnary(node.getInput().get(0), CNodeUnary.UnaryType.MULT2)
            : node;
    }

    /**
     * Rewrites {@code VECT_MULT(X, 2)} to the unary {@code VECT_MULT2(X)}; any
     * other node is returned unchanged.
     *
     * <p>NOTE(review): this matches {@code VECT_MULT}, whereas the power
     * counterpart matches the {@code VECT_POW_SCALAR} type - confirm that
     * {@code VECT_MULT} is the intended operator type for a literal operand.
     */
    private static CNode rewriteBinaryMult2Vect(CNode node) {
        return isBinaryWithLiteralRhs(node, CNodeBinary.BinType.VECT_MULT, "2")
            ? new CNodeUnary(node.getInput().get(0), CNodeUnary.UnaryType.VECT_MULT2)
            : node;
    }

    /**
     * For outer-product templates, replaces every {@code NOTEQUAL(main, 0)}
     * below the template output with the literal {@code 1}, where
     * {@code main} is the template's first input. Other templates are
     * returned untouched.
     *
     * @param tpl the template to rewrite in place
     * @return the same template instance
     */
    private static CNodeTpl rewriteRemoveOuterNeq0(CNodeTpl tpl) {
        if (tpl instanceof CNodeOuterProduct) {
            // the first template input is cast to CNodeData without a type check
            CNodeData mainInput = (CNodeData) tpl.getInput().get(0);
            rFindAndRemoveBinaryMS(tpl.getOutput(), mainInput, CNodeBinary.BinType.NOTEQUAL, "0", "1");
        }
        return tpl;
    }

    /**
     * Recursively searches the inputs of {@code node} for binary operations
     * of the given type that combine the main input (matched by hop ID) with
     * the given literal, and replaces each match with a new literal node.
     *
     * <p>Only inputs are inspected and replaced, so {@code node} itself is
     * never substituted. Matched inputs are not descended into.
     *
     * @param node      parent whose inputs are inspected
     * @param mainInput data node identifying the main input by its hop ID
     * @param type      binary operator type to match
     * @param lit       variable name of the literal expected as second operand
     * @param replace   value passed to {@link LiteralOp} for the replacement
     */
    private static void rFindAndRemoveBinaryMS(CNode node, CNodeData mainInput, CNodeBinary.BinType type, String lit, String replace) {
        for (int i = 0; i < node.getInput().size(); i++) {
            CNode tmp = node.getInput().get(i);
            boolean matches = isBinaryWithLiteralRhs(tmp, type, lit)
                && tmp.getInput().get(0) instanceof CNodeData
                && ((CNodeData) tmp.getInput().get(0)).getHopID() == mainInput.getHopID();
            if (matches) {
                CNodeData cnode = new CNodeData(new LiteralOp(replace));
                cnode.setLiteral(true);
                node.getInput().set(i, cnode);
            } else {
                rFindAndRemoveBinaryMS(tmp, mainInput, type, lit, replace);
            }
        }
    }
}

