/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.api.dl;

import java.util.HashSet;
import org.apache.sysml.api.dl.BatchNorm;
import org.apache.sysml.api.dl.Caffe2DML$;
import org.apache.sysml.api.dl.CaffeLayer;
import org.apache.sysml.api.dl.CaffeNetwork;
import org.apache.sysml.api.dl.CaffeSolver;
import org.apache.sysml.api.dl.DMLGenerator;
import org.apache.sysml.api.dl.DMLGenerator$;
import org.apache.sysml.api.dl.IsLossLayer;
import org.apache.sysml.api.dl.Utils$;
import org.apache.sysml.runtime.DMLRuntimeException;
import scala.Function0;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.collection.Seq;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.mutable.StringBuilder;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Static method bodies for the Scala trait {@link DMLGenerator}.
 *
 * <p>Every method receives the trait instance as {@code $this}; the {@code DMLGenerator$class} /
 * {@code $init$} naming matches the trait-implementation classes emitted by pre-2.12 Scala
 * compilers. Most methods write generated DML text through {@code $this.tabDMLScript()}, which
 * forwards {@code dmlScript()} and the current {@code numTabs()} depth to the three-argument
 * {@code tabDMLScript} overload defined elsewhere. Block-emitting methods raise and lower
 * {@code numTabs()} around the body they emit.
 *
 * <p>NOTE(review): this file was produced by the CFR decompiler. The anonymous
 * {@code new Serializable(...) {...}} closures are CFR's rendering of Scala lambdas and are not
 * valid Java as written; the Scala trait is presumably the authoritative source.
 */
public abstract class DMLGenerator$class {
    /**
     * Discards all generated script text and the import bookkeeping, and resets the
     * indentation depth to zero.
     */
    public static void reset(DMLGenerator $this) {
        $this.dmlScript().clear();
        $this.alreadyImported().clear();
        $this.numTabs_$eq(0);
    }

    /**
     * Returns the script builder positioned at the current indentation depth, without a
     * leading newline.
     */
    public static StringBuilder tabDMLScript(DMLGenerator $this) {
        return $this.tabDMLScript($this.dmlScript(), $this.numTabs(), false);
    }

    /**
     * Returns the script builder positioned at the current indentation depth.
     *
     * @param prependNewLine forwarded to the three-argument overload
     */
    public static StringBuilder tabDMLScript(DMLGenerator $this, boolean prependNewLine) {
        return $this.tabDMLScript($this.dmlScript(), $this.numTabs(), prependNewLine);
    }

    /**
     * Emits the source/import statements for the network and solver into this generator's
     * script at the current indentation depth.
     *
     * @param otherFiles extra files forwarded to the five-argument overload; may be null
     */
    public static void source(DMLGenerator $this, CaffeNetwork net, CaffeSolver solver, String[] otherFiles) {
        $this.source($this.dmlScript(), $this.numTabs(), net, solver, otherFiles);
    }

    /**
     * Shared skeleton for every generated control-flow block. It writes {@code header}, runs
     * {@code body} one indentation level deeper, then writes the closing brace at the original
     * depth.
     *
     * @param header the opening line, including the trailing {@code " {\n"}
     * @param body   callback that emits the block's contents
     */
    private static void emitBlock(DMLGenerator $this, String header, Function0 body) {
        $this.tabDMLScript().append(header);
        $this.numTabs_$eq($this.numTabs() + 1);
        body.apply$mcV$sp();
        $this.numTabs_$eq($this.numTabs() - 1);
        $this.tabDMLScript().append("}\n");
    }

    /**
     * Emits {@code if(cond) { ... }} with the body produced by {@code op}.
     */
    public static void ifBlock(DMLGenerator $this, String cond, Function0 op) {
        emitBlock($this, "if(" + cond + ") {\n", op);
    }

    /**
     * Emits {@code while(cond) { ... }} with the body produced by {@code op}.
     */
    public static void whileBlock(DMLGenerator $this, String cond, Function0 op) {
        emitBlock($this, "while(" + cond + ") {\n", op);
    }

    /**
     * Emits {@code for(iterVarName in seq(startVal,endVal,step)) { ... }}.
     */
    public static void forBlock(DMLGenerator $this, String iterVarName, String startVal, String endVal, String step, Function0 op) {
        emitBlock($this, "for(" + iterVarName + " in seq(" + startVal + "," + endVal + "," + step + ")) {\n", op);
    }

    /**
     * Emits {@code for(iterVarName in startVal:endVal) { ... }} (unit step range).
     */
    public static void forBlock(DMLGenerator $this, String iterVarName, String startVal, String endVal, Function0 op) {
        emitBlock($this, "for(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n", op);
    }

    /**
     * Emits a {@code parfor} block.
     *
     * <p>When {@code step} is exactly {@code "1"}, the compact {@code start:end} range is used;
     * otherwise {@code seq(start,end,step)} is used.
     *
     * @param parforParameters text appended verbatim after the range (for example, parfor options)
     */
    public static void parForBlock(DMLGenerator $this, String iterVarName, String startVal, String endVal, String step, String parforParameters, Function0 op) {
        String range = step.equals("1")
                ? startVal + ":" + endVal
                : "seq(" + startVal + "," + endVal + "," + step + ")";
        emitBlock($this, "parfor(" + iterVarName + " in " + range + parforParameters + ") {\n", op);
    }

    /**
     * Emits DML that, when {@code debug} is true at script run time, prints a per-class
     * precision / recall / f1-score table.
     *
     * <p>The emitted table is read from the script variables {@code yb}, {@code true_yb} and
     * {@code predicted_yb}, which must already be defined by previously generated code.
     *
     * <p>NOTE(review): {@code precision = tp / tp_plus_fp} and {@code recall = tp / tp_plus_fn}
     * divide by zero for a class that is never predicted or never present. Confirm how DML
     * treats that (presumably NaN).
     */
    public static void printClassificationReport(DMLGenerator $this) {
        $this.ifBlock("debug", (Function0<BoxedUnit>)new Serializable($this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ DMLGenerator $outer;

            public final void apply() {
                this.apply$mcV$sp();
            }

            public void apply$mcV$sp() {
                // At most 10 classes are reported (presumably emits min(10, ncol(yb))).
                this.$outer.assign(this.$outer.tabDMLScript(), "num_rows_error_measures", this.$outer.min("10", this.$outer.ncol("yb")));
                // Columns: class, precision, recall, f1-score, num_true_labels.
                this.$outer.assign(this.$outer.tabDMLScript(), "error_measures", this.$outer.matrix("0", "num_rows_error_measures", "5"));
                this.$outer.forBlock("class_i", "1", "num_rows_error_measures", (Function0<BoxedUnit>)new Serializable(this){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ DMLGenerator$.anonfun.printClassificationReport.1 $outer;

                    public final void apply() {
                        this.apply$mcV$sp();
                    }

                    public void apply$mcV$sp() {
                        DMLGenerator gen = this.$outer.org$apache$sysml$api$dl$DMLGenerator$$anonfun$$$outer();
                        gen.assign(gen.tabDMLScript(), "tp", "sum( (true_yb == predicted_yb) * (true_yb == class_i) )");
                        gen.assign(gen.tabDMLScript(), "tp_plus_fp", "sum( (predicted_yb == class_i) )");
                        gen.assign(gen.tabDMLScript(), "tp_plus_fn", "sum( (true_yb == class_i) )");
                        gen.assign(gen.tabDMLScript(), "precision", "tp / tp_plus_fp");
                        gen.assign(gen.tabDMLScript(), "recall", "tp / tp_plus_fn");
                        gen.assign(gen.tabDMLScript(), "f1Score", "2*precision*recall / (precision+recall)");
                        gen.assign(gen.tabDMLScript(), "error_measures[class_i,1]", "class_i");
                        gen.assign(gen.tabDMLScript(), "error_measures[class_i,2]", "precision");
                        gen.assign(gen.tabDMLScript(), "error_measures[class_i,3]", "recall");
                        gen.assign(gen.tabDMLScript(), "error_measures[class_i,4]", "f1Score");
                        gen.assign(gen.tabDMLScript(), "error_measures[class_i,5]", "tp_plus_fn");
                    }
                    {
                        if ($outer == null) {
                            throw null;
                        }
                        this.$outer = $outer;
                    }
                });
                // The backslash is escaped so the generated DML contains the escape sequence \t,
                // not a literal tab character.
                String dmlTab = "\\t";
                String header = "class    " + dmlTab + "precision" + dmlTab + "recall  " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n";
                String errorMeasures = "toString(error_measures, decimal=7, sep=" + this.$outer.asDMLString(dmlTab) + ")";
                this.$outer.tabDMLScript().append(this.$outer.print(this.$outer.dmlConcat((Seq<String>)Predef$.MODULE$.wrapRefArray((Object[])new String[]{this.$outer.asDMLString(header), errorMeasures}))));
            }

            public /* synthetic */ DMLGenerator org$apache$sysml$api$dl$DMLGenerator$$anonfun$$$outer() {
                return this.$outer;
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
            }
        });
    }

    /**
     * Emits the script preamble: two {@code source} calls and, optionally, the Nesterov UDF
     * declaration.
     *
     * <p>During training, the extra files {@code l1_reg} and {@code l2_reg} are passed to the
     * two calls; otherwise {@code null} is passed. When training with
     * {@code Caffe2DML.USE_NESTEROV_UDF()} enabled, the {@code externalFunction} declaration
     * for {@code update_nesterov} is also emitted.
     */
    public static void appendHeaders(DMLGenerator $this, CaffeNetwork net, CaffeSolver solver, boolean isTraining) {
        $this.source(net, solver, isTraining ? new String[]{"l1_reg"} : null);
        $this.source(net, solver, isTraining ? new String[]{"l2_reg"} : null);
        if (isTraining && Caffe2DML$.MODULE$.USE_NESTEROV_UDF()) {
            $this.tabDMLScript().append("update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\");  \n");
        }
    }

    /**
     * Emits DML that reads a matrix whose path comes from a command-line variable.
     *
     * <p>Two assignments are generated: {@code <varName>_path} is bound to
     * {@code ifdef(cmdLineVar)}, then {@code varName} is bound to {@code read(<varName>_path)}.
     */
    public static void readMatrix(DMLGenerator $this, String varName, String cmdLineVar) {
        String pathVar = varName + "_path";
        $this.assign($this.tabDMLScript(), pathVar, $this.ifdef(cmdLineVar));
        $this.assign($this.tabDMLScript(), varName, "read(" + pathVar + ")");
    }

    /**
     * Emits DML that loads the input data.
     *
     * <p>{@code X_full} is always read from {@code $X}. In training mode, {@code y_full} is also
     * read from {@code $y} and the image count is taken from {@code nrow(y_full)}. If
     * {@code performOneHotEncoding} is set, the labels are converted to one-hot form with
     * {@code table(...)}, assuming 1-based labels. In scoring mode, the image count is
     * {@code nrow(X_full)}.
     */
    public static void readInputData(DMLGenerator $this, CaffeNetwork net, boolean isTraining, boolean performOneHotEncoding) {
        $this.readMatrix("X_full", "$X");
        if (isTraining) {
            $this.readMatrix("y_full", "$y");
            $this.tabDMLScript().append(Caffe2DML$.MODULE$.numImages() + " = nrow(y_full)\n");
            if (performOneHotEncoding) {
                $this.tabDMLScript().append("# Convert to one-hot encoding (Assumption: 1-based labels) \n");
                $this.tabDMLScript().append("y_full = table(seq(1," + Caffe2DML$.MODULE$.numImages() + ",1), y_full, " + Caffe2DML$.MODULE$.numImages() + ", " + Utils$.MODULE$.numClasses(net) + ")\n");
            }
        } else {
            $this.tabDMLScript().append(Caffe2DML$.MODULE$.numImages() + " = nrow(X_full)\n");
        }
    }

    /**
     * Equivalent to {@link #initWeights(DMLGenerator, CaffeNetwork, CaffeSolver, boolean, HashSet)}
     * with no layers ignored.
     */
    public static void initWeights(DMLGenerator $this, CaffeNetwork net, CaffeSolver solver, boolean readWeights) {
        $this.initWeights(net, solver, readWeights, new HashSet<String>());
    }

    /**
     * Emits the layer and solver initialization code, optionally followed by weight-loading code.
     *
     * <p>Order of emission:
     * <ol>
     *   <li>{@code weights = ifdef($weights, " ")} and each layer's {@code init} code
     *       (all layers).</li>
     *   <li>If {@code readWeights} is set, {@code readWeight} statements for the weight, extra
     *       weight and bias of every layer not in {@code layersToIgnore} that has that field.
     *       The file names are {@code <layer>_weight.mtx}, {@code <layer>_extra_weight.mtx} and
     *       {@code <layer>_bias.mtx}.</li>
     *   <li>The solver's {@code init} code for every layer. This includes ignored layers, because
     *       {@code layersToIgnore} only filters the weight loading.</li>
     * </ol>
     *
     * @param layersToIgnore layer names (presumably {@code String}s) whose weights are not read
     */
    public static void initWeights(DMLGenerator $this, CaffeNetwork net, CaffeSolver solver, boolean readWeights, HashSet layersToIgnore) {
        $this.tabDMLScript().append("weights = ifdef($weights, \" \")\n");
        $this.tabDMLScript().append("# Initialize the layers and solvers\n");
        net.getLayers().map((Function1)new Serializable($this, net){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ DMLGenerator $outer;
            private final CaffeNetwork net$2;

            public final void apply(String layer) {
                this.net$2.getCaffeLayer(layer).init(this.$outer.tabDMLScript());
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
                this.net$2 = net$2;
            }
        }, List$.MODULE$.canBuildFrom());
        if (readWeights) {
            $this.tabDMLScript().append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n");
            // Layers whose weights should be loaded, resolved to their CaffeLayer objects.
            List allLayers = (List)((List)net.getLayers().filter((Function1)new Serializable($this, layersToIgnore){
                public static final long serialVersionUID = 0L;
                private final HashSet layersToIgnore$1;

                public final boolean apply(String l) {
                    return !this.layersToIgnore$1.contains(l);
                }
                {
                    this.layersToIgnore$1 = layersToIgnore$1;
                }
            })).map((Function1)new Serializable($this, net){
                public static final long serialVersionUID = 0L;
                private final CaffeNetwork net$2;

                public final CaffeLayer apply(String x$1) {
                    return this.net$2.getCaffeLayer(x$1);
                }
                {
                    this.net$2 = net$2;
                }
            }, List$.MODULE$.canBuildFrom());
            // Weights.
            ((List)allLayers.filter((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final boolean apply(CaffeLayer x$2) {
                    return x$2.weight() != null;
                }
            })).map((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ DMLGenerator $outer;

                public final StringBuilder apply(CaffeLayer l) {
                    return this.$outer.tabDMLScript().append(this.$outer.readWeight(l.weight(), new StringBuilder().append((Object)l.param().getName()).append((Object)"_weight.mtx").toString(), this.$outer.readWeight$default$3()));
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                }
            }, List$.MODULE$.canBuildFrom());
            // Extra weights.
            ((List)allLayers.filter((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final boolean apply(CaffeLayer x$3) {
                    return x$3.extraWeight() != null;
                }
            })).map((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ DMLGenerator $outer;

                public final StringBuilder apply(CaffeLayer l) {
                    return this.$outer.tabDMLScript().append(this.$outer.readWeight(l.extraWeight(), new StringBuilder().append((Object)l.param().getName()).append((Object)"_extra_weight.mtx").toString(), this.$outer.readWeight$default$3()));
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                }
            }, List$.MODULE$.canBuildFrom());
            // Biases.
            ((List)allLayers.filter((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final boolean apply(CaffeLayer x$4) {
                    return x$4.bias() != null;
                }
            })).map((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ DMLGenerator $outer;

                public final StringBuilder apply(CaffeLayer l) {
                    return this.$outer.tabDMLScript().append(this.$outer.readWeight(l.bias(), new StringBuilder().append((Object)l.param().getName()).append((Object)"_bias.mtx").toString(), this.$outer.readWeight$default$3()));
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                }
            }, List$.MODULE$.canBuildFrom());
        }
        net.getLayers().map((Function1)new Serializable($this, net, solver){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ DMLGenerator $outer;
            private final CaffeNetwork net$2;
            private final CaffeSolver solver$1;

            public final void apply(String layer) {
                this.solver$1.init(this.$outer.tabDMLScript(), this.net$2.getCaffeLayer(layer));
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
                this.net$2 = net$2;
                this.solver$1 = solver$1;
            }
        }, List$.MODULE$.canBuildFrom());
    }

    /**
     * Returns the network's loss layers, which must number exactly one.
     *
     * @return a one-element list of the {@link IsLossLayer} in {@code net}
     * @throws DMLRuntimeException if the network has zero or more than one loss layer; the
     *         message lists the names of the loss layers found
     */
    public static List getLossLayers(DMLGenerator $this, CaffeNetwork net) {
        // Names of every layer whose CaffeLayer implementation is a loss layer.
        List lossLayerNames = (List)net.getLayers().filter((Function1)new Serializable($this, net){
            public static final long serialVersionUID = 0L;
            private final CaffeNetwork net$3;

            public final boolean apply(String layer) {
                return this.net$3.getCaffeLayer(layer) instanceof IsLossLayer;
            }
            {
                this.net$3 = net$3;
            }
        });
        if (lossLayerNames.length() != 1) {
            throw new DMLRuntimeException("Expected exactly one loss layer, but found " + lossLayerNames.length() + ":" + lossLayerNames.toString());
        }
        // The decompiled original returned an uninitialized `void` local here; the
        // computed list is what the List return type calls for.
        return (List)lossLayerNames.map((Function1)new Serializable($this, net){
            public static final long serialVersionUID = 0L;
            private final CaffeNetwork net$3;

            public final IsLossLayer apply(String layer) {
                return (IsLossLayer)this.net$3.getCaffeLayer(layer);
            }
            {
                this.net$3 = net$3;
            }
        }, List$.MODULE$.canBuildFrom());
    }

    /**
     * Sets the {@code update_mean_var} flag on every {@link BatchNorm} layer in {@code net}.
     *
     * @param value the flag value to store on each batch-norm layer
     */
    public static void updateMeanVarianceForBatchNorm(DMLGenerator $this, CaffeNetwork net, boolean value) {
        ((List)net.getLayers().filter((Function1)new Serializable($this, net){
            public static final long serialVersionUID = 0L;
            private final CaffeNetwork net$4;

            public final boolean apply(String x$5) {
                return this.net$4.getCaffeLayer(x$5) instanceof BatchNorm;
            }
            {
                this.net$4 = net$4;
            }
        })).map((Function1)new Serializable($this, net, value){
            public static final long serialVersionUID = 0L;
            private final CaffeNetwork net$4;
            private final boolean value$1;

            public final void apply(String x$6) {
                ((BatchNorm)this.net$4.getCaffeLayer(x$6)).update_mean_var_$eq(this.value$1);
            }
            {
                this.net$4 = net$4;
                this.value$1 = value$1;
            }
        }, List$.MODULE$.canBuildFrom());
    }

    /**
     * Trait field initializer: starts with an empty script and zero indentation.
     */
    public static void $init$(DMLGenerator $this) {
        $this.dmlScript_$eq(new StringBuilder());
        $this.numTabs_$eq(0);
    }
}

