/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.api.dl;

import caffe.Caffe;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.Message;
import com.google.protobuf.TextFormat;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.sysml.api.dl.AdaGrad;
import org.apache.sysml.api.dl.Adam;
import org.apache.sysml.api.dl.BatchNorm;
import org.apache.sysml.api.dl.Caffe2DML$;
import org.apache.sysml.api.dl.CaffeLayer;
import org.apache.sysml.api.dl.CaffeNetwork;
import org.apache.sysml.api.dl.CaffeSolver;
import org.apache.sysml.api.dl.Convolution;
import org.apache.sysml.api.dl.Data;
import org.apache.sysml.api.dl.DeConvolution;
import org.apache.sysml.api.dl.InnerProduct;
import org.apache.sysml.api.dl.Nesterov;
import org.apache.sysml.api.dl.SGD;
import org.apache.sysml.api.dl.Utils;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.api.mlcontext.ScriptFactory;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import scala.Function1;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.JavaConversions$;
import scala.collection.Map;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.Buffer$;
import scala.collection.mutable.StringBuilder;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/**
 * Singleton helpers for the Caffe2DML bridge: parsing Caffe network/solver definitions, loading
 * Caffe weights into SystemML {@link MatrixBlock}s, and a few I/O and memory-estimation utilities.
 *
 * <p>NOTE(review): this class appears to be the Java encoding of a Scala {@code object}
 * ({@code Utils$} / {@code MODULE$}) decompiled with CFR (see the file header). As in the
 * original, checked exceptions such as {@code IOException} and {@code InterruptedException} are
 * not declared on these methods.
 */
public final class Utils$ {
    /** The singleton instance. */
    public static final Utils$ MODULE$ = new Utils$();

    /**
     * Infers the number of classes from the first output dimension of the network's last layer.
     *
     * @param net the parsed network
     * @return that dimension as a decimal string, or the placeholder {@code "$num_classes"} when
     *     it cannot be determined (a warning is logged in that case)
     */
    public String numClasses(CaffeNetwork net) {
        try {
            String lastLayerName = (String) net.getLayers().last();
            String numOutput = (String) net.getCaffeLayer(lastLayerName).outputShape()._1();
            return String.valueOf(Long.parseLong(numOutput));
        } catch (Exception e) {
            // Best effort: fall back to a placeholder the user must fill in. Fatal Errors
            // (e.g. OutOfMemoryError) are deliberately not swallowed here.
            Caffe2DML$.MODULE$.LOG().warn("Cannot infer the number of classes from network definition. User needs to pass it via set(num_classes=...) method.");
            return "$num_classes";
        }
    }

    /**
     * Prints {@code script} to standard output, prefixing every line with its 1-based line number
     * zero-padded to three digits and a {@code '|'} separator.
     *
     * @param script the DML script text
     */
    public void prettyPrintDMLScript(String script) {
        BufferedReader bufReader = new BufferedReader(new StringReader(script));
        int lineNum = 1;
        for (String line = bufReader.readLine(); line != null; line = bufReader.readLine()) {
            System.out.println(String.format("%03d", lineNum) + "|" + line);
            lineNum++;
        }
    }

    /**
     * Reads the solver definition at {@code solverFilePath} and converts it to a {@link CaffeSolver}.
     *
     * @param solverFilePath local or hdfs:/gpfs: path to a text-format solver file
     * @return the corresponding solver
     */
    public CaffeSolver parseSolver(String solverFilePath) {
        return parseSolver(readCaffeSolver(solverFilePath));
    }

    /**
     * Converts a parsed Caffe solver definition to a {@link CaffeSolver}.
     *
     * <p>Unset hyper-parameters default to 0.0; the regularization type defaults to {@code "L2"}.
     * The solver type is matched case-insensitively.
     *
     * @param solver the parsed solver parameters
     * @return an {@link SGD}, {@link AdaGrad}, {@link Nesterov} or {@link Adam} solver
     * @throws DMLRuntimeException if the solver type is not one of the supported ones
     */
    public CaffeSolver parseSolver(Caffe.SolverParameter solver) {
        double momentum = solver.hasMomentum() ? (double) solver.getMomentum() : 0.0;
        double lambda = solver.hasWeightDecay() ? (double) solver.getWeightDecay() : 0.0;
        double delta = solver.hasDelta() ? (double) solver.getDelta() : 0.0;
        String regularizationType = solver.hasRegularizationType() ? solver.getRegularizationType() : "L2";
        // Locale.ROOT: solver names must not depend on the JVM's default locale.
        String type = solver.getType().toLowerCase(Locale.ROOT);
        if ("sgd".equals(type)) {
            return new SGD(regularizationType, lambda, momentum);
        }
        if ("adagrad".equals(type)) {
            return new AdaGrad(regularizationType, lambda, delta);
        }
        if ("nesterov".equals(type)) {
            return new Nesterov(regularizationType, lambda, momentum);
        }
        if ("adam".equals(type)) {
            double momentum2 = solver.hasMomentum2() ? (double) solver.getMomentum2() : 0.0;
            return new Adam(regularizationType, lambda, momentum, momentum2, delta);
        }
        throw new DMLRuntimeException("The solver type is not supported: " + solver.getType() + ". Try: SGD, AdaGrad or Nesterov or Adam.");
    }

    /**
     * Parses a text-format Caffe network definition (prototxt).
     *
     * @param netFilePath local or hdfs:/gpfs: path to the network definition
     * @return the parsed network (without weights)
     */
    public Caffe.NetParameter readCaffeNet(String netFilePath) {
        Caffe.NetParameter.Builder builder = Caffe.NetParameter.newBuilder();
        try (InputStreamReader reader = getInputStreamReader(netFilePath)) {
            TextFormat.merge((Readable) reader, (Message.Builder) builder);
        }
        return builder.build();
    }

    /**
     * Allocates a dense {@code C x (F*H*W)} matrix for deconvolution weights and starts a
     * background copier thread that writes {@code data} into it.
     *
     * <p>The caller must {@code join()} the returned thread before reading the matrix.
     *
     * @return the (still being filled) matrix and the started copier thread
     */
    public Tuple2<MatrixBlock, Utils.CopyFloatToDoubleArray> allocateDeconvolutionWeight(List<Float> data, int F, int C, int H, int W) {
        MatrixBlock mb = new MatrixBlock(C, F * H * W, false);
        mb.allocateDenseBlock();
        double[] arr = mb.getDenseBlockValues();
        Utils.CopyCaffeDeconvFloatToSystemMLDeconvDoubleArray thread = new Utils.CopyCaffeDeconvFloatToSystemMLDeconvDoubleArray(data, F, C, H, W, arr);
        thread.start();
        return new Tuple2<MatrixBlock, Utils.CopyFloatToDoubleArray>(mb, thread);
    }

    /**
     * Allocates a dense {@code rows x cols} matrix and starts a background copier thread that
     * writes {@code data} into it (optionally transposed).
     *
     * <p>The caller must {@code join()} the returned thread before reading the matrix.
     *
     * @return the (still being filled) matrix and the started copier thread
     */
    public Tuple2<MatrixBlock, Utils.CopyFloatToDoubleArray> allocateMatrixBlock(List<Float> data, int rows, int cols, boolean transpose) {
        MatrixBlock mb = new MatrixBlock(rows, cols, false);
        mb.allocateDenseBlock();
        double[] arr = mb.getDenseBlockValues();
        Utils.CopyFloatToDoubleArray thread = new Utils.CopyFloatToDoubleArray(data, rows, cols, transpose, arr);
        thread.start();
        return new Tuple2<MatrixBlock, Utils.CopyFloatToDoubleArray>(mb, thread);
    }

    /**
     * Checks that a 2-D {@code shape} matches the number of values in {@code data}.
     *
     * @throws DMLRuntimeException if {@code shape} is null, not of length 2, or its element
     *     count differs from {@code data.size()}
     */
    public void validateShape(int[] shape, List<Float> data, String layerName) {
        if (shape == null) {
            throw new DMLRuntimeException("Unexpected weight for layer: " + layerName);
        }
        if (shape.length != 2) {
            throw new DMLRuntimeException("Expected shape to be of length 2:" + layerName);
        }
        // Multiply in long so that large layers cannot overflow int and spuriously (mis)match.
        long expectedSize = (long) shape[0] * (long) shape[1];
        if (expectedSize != data.size()) {
            throw new DMLRuntimeException("Incorrect size of blob from caffemodel for the layer " + layerName + ". Expected of size " + expectedSize + ", but found " + data.size());
        }
    }

    /** Java-friendly overload of {@link #saveCaffeModelFile(SparkContext, String, String, String, String)}. */
    public void saveCaffeModelFile(JavaSparkContext sc, String deployFilePath, String caffeModelFilePath, String outputDirectory, String format) {
        saveCaffeModelFile(sc.sc(), deployFilePath, caffeModelFilePath, outputDirectory, format);
    }

    /**
     * Loads the weights of a Caffe model and writes every weight/bias matrix to
     * {@code outputDirectory/<variable>.mtx} in the given SystemML {@code format}, by generating
     * and executing a DML script of {@code write(...)} statements.
     *
     * @throws DMLRuntimeException if no weights were found in {@code caffeModelFilePath}
     */
    public void saveCaffeModelFile(SparkContext sc, String deployFilePath, String caffeModelFilePath, String outputDirectory, String format) {
        HashMap<String, MatrixBlock> inputVariables = new HashMap<String, MatrixBlock>();
        readCaffeNet(new CaffeNetwork(deployFilePath), deployFilePath, caffeModelFilePath, inputVariables);
        // Fail fast, before creating an MLContext we would not use.
        if (inputVariables.isEmpty()) {
            throw new DMLRuntimeException("No weights found in the file " + caffeModelFilePath);
        }
        MLContext ml = new MLContext(sc);
        // NOTE: StringBuilder is scala.collection.mutable.StringBuilder via this file's imports.
        StringBuilder dmlScript = new StringBuilder();
        for (String input : inputVariables.keySet()) {
            dmlScript.append("write(" + input + ", \"" + outputDirectory + "/" + input + ".mtx\", format=\"" + format + "\");\n");
        }
        if (Caffe2DML$.MODULE$.LOG().isDebugEnabled()) {
            Caffe2DML$.MODULE$.LOG().debug("Executing the script:" + dmlScript.toString());
        }
        Script script = ScriptFactory.dml(dmlScript.toString()).in((Map) JavaConversions$.MODULE$.mapAsScalaMap(inputVariables));
        ml.execute(script);
    }

    /**
     * Parses the network definition at {@code netFilePath}, merges the binary weights from
     * {@code weightsFilePath}, and stores one {@link MatrixBlock} per weight/bias variable of each
     * layer in {@code inputVariables} (keyed by the Caffe2DML layer's weight/bias names).
     *
     * <p>All background copies are joined and non-zero counts recomputed before returning.
     *
     * @param net the Caffe2DML view of the same network, used to map layers to shapes/names
     * @param inputVariables output map that receives the weight and bias matrices
     * @return the network definition re-read from {@code netFilePath} alone (without weights)
     * @throws DMLRuntimeException for layers whose blob count/type combination is unsupported
     */
    public Caffe.NetParameter readCaffeNet(CaffeNetwork net, String netFilePath, String weightsFilePath, HashMap<String, MatrixBlock> inputVariables) {
        Caffe.NetParameter.Builder builder = Caffe.NetParameter.newBuilder();
        try (InputStreamReader reader = getInputStreamReader(netFilePath)) {
            TextFormat.merge((Readable) reader, (Message.Builder) builder);
        }
        try (FileInputStream weightsStream = new FileInputStream(weightsFilePath)) {
            CodedInputStream inputStream = CodedInputStream.newInstance((InputStream) weightsStream);
            // Caffe models can exceed protobuf's default message size limit.
            inputStream.setSizeLimit(Integer.MAX_VALUE);
            builder.mergeFrom(inputStream);
        }
        Caffe.NetParameter net1 = builder.build();

        // Legacy (V1) layers keyed by name; on duplicate names the last one wins.
        java.util.Map<String, Caffe.V1LayerParameter> v1Layers = new HashMap<String, Caffe.V1LayerParameter>();
        for (Caffe.V1LayerParameter v1Layer : net1.getLayersList()) {
            v1Layers.put(v1Layer.getName(), v1Layer);
        }

        ArrayList<Utils.CopyFloatToDoubleArray> asyncThreads = new ArrayList<Utils.CopyFloatToDoubleArray>();
        for (Caffe.LayerParameter layer : net1.getLayerList()) {
            List<Caffe.BlobProto> blobs = getBlobsFromV1Map(layer, v1Layers);
            if (blobs == null || blobs.isEmpty()) {
                Caffe2DML$.MODULE$.LOG().debug("The layer:" + layer.getName() + " has no blobs");
                continue;
            }
            int blobCount = blobs.size();
            if (blobCount == 2 || (blobCount == 3 && net.getCaffeLayer(layer.getName()) instanceof BatchNorm)) {
                readWeightAndBias(net, layer, blobs, inputVariables, asyncThreads);
            } else if (blobCount == 1) {
                readWeightOnly(net, layer, blobs, inputVariables, asyncThreads);
            } else {
                throw new DMLRuntimeException("Layer with blob count " + blobCount + " is not supported for the layer " + layer.getName());
            }
        }

        // Wait for all background copies before the matrices are used.
        for (Utils.CopyFloatToDoubleArray t : asyncThreads) {
            t.join();
        }
        for (MatrixBlock mb : inputVariables.values()) {
            mb.recomputeNonZeros();
        }
        return readCaffeNet(netFilePath);
    }

    /** Loads blob 0 as the layer's weight and blob 1 as its bias (any third blob is ignored). */
    private void readWeightAndBias(CaffeNetwork net, Caffe.LayerParameter layer, List<Caffe.BlobProto> blobs, HashMap<String, MatrixBlock> inputVariables, ArrayList<Utils.CopyFloatToDoubleArray> asyncThreads) {
        CaffeLayer caffe2DMLLayer = net.getCaffeLayer(layer.getName());
        boolean transpose = caffe2DMLLayer instanceof InnerProduct;
        int[] shape = caffe2DMLLayer.weightShape();
        if (shape == null) {
            throw new DMLRuntimeException("Didnot expect weights for the layer " + layer.getName());
        }
        if (caffe2DMLLayer instanceof DeConvolution) {
            // Deconvolution weights need a layout-specific copy, parameterized by the kernel geometry.
            List<Float> data = blobs.get(0).getDataList();
            validateShape(shape, data, layer.getName());
            DeConvolution deconvLayer = (DeConvolution) caffe2DMLLayer;
            int C = shape[0];
            int F = Integer.parseInt(deconvLayer.numKernels());
            int Hf = Integer.parseInt(deconvLayer.kernel_h());
            int Wf = Integer.parseInt(deconvLayer.kernel_w());
            Tuple2<MatrixBlock, Utils.CopyFloatToDoubleArray> ret1 = allocateDeconvolutionWeight(data, F, C, Hf, Wf);
            asyncThreads.add(ret1._2());
            inputVariables.put(caffe2DMLLayer.weight(), ret1._1());
        } else {
            inputVariables.put(caffe2DMLLayer.weight(), getMBFromBlob(blobs.get(0), shape, layer.getName(), transpose, asyncThreads));
        }
        int[] biasShape = caffe2DMLLayer.biasShape();
        if (biasShape == null) {
            throw new DMLRuntimeException("Didnot expect bias for the layer " + layer.getName());
        }
        inputVariables.put(caffe2DMLLayer.bias(), getMBFromBlob(blobs.get(1), biasShape, layer.getName(), transpose, asyncThreads));
        Caffe2DML$.MODULE$.LOG().debug("Read weights/bias for layer:" + layer.getName());
    }

    /**
     * Loads a single-blob layer. Only (de)convolution layers without a bias term are supported;
     * their bias is set to an empty {@code numOutput x 1} matrix.
     */
    private void readWeightOnly(CaffeNetwork net, Caffe.LayerParameter layer, List<Caffe.BlobProto> blobs, HashMap<String, MatrixBlock> inputVariables, ArrayList<Utils.CopyFloatToDoubleArray> asyncThreads) {
        CaffeLayer caffe2DMLLayer = net.getCaffeLayer(layer.getName());
        boolean isConvLike = caffe2DMLLayer instanceof Convolution || caffe2DMLLayer instanceof DeConvolution;
        Caffe.ConvolutionParameter convParam = isConvLike && caffe2DMLLayer.param().hasConvolutionParam() ? caffe2DMLLayer.param().getConvolutionParam() : null;
        if (convParam == null) {
            throw new DMLRuntimeException("Layer with blob count " + blobs.size() + " is not supported for the layer " + layer.getName());
        }
        if (convParam.hasBiasTerm() && convParam.getBiasTerm()) {
            throw new DMLRuntimeException("Layer with blob count " + blobs.size() + " and with bias term is not supported for the layer " + layer.getName());
        }
        inputVariables.put(caffe2DMLLayer.weight(), getMBFromBlob(blobs.get(0), caffe2DMLLayer.weightShape(), layer.getName(), false, asyncThreads));
        inputVariables.put(caffe2DMLLayer.bias(), new MatrixBlock(convParam.getNumOutput(), 1, false));
        Caffe2DML$.MODULE$.LOG().debug("Read only weight for layer:" + layer.getName());
    }

    /**
     * Returns the layer's own blobs if it has any, otherwise the blobs of the legacy (V1) layer of
     * the same name, or {@code null} if there is no such V1 layer.
     */
    public List<Caffe.BlobProto> getBlobs(Caffe.LayerParameter layer, scala.collection.immutable.Map<String, Caffe.V1LayerParameter> v1Layers) {
        if (layer.getBlobsCount() != 0) {
            return layer.getBlobsList();
        }
        if (v1Layers.contains(layer.getName())) {
            return ((Caffe.V1LayerParameter) v1Layers.get(layer.getName()).get()).getBlobsList();
        }
        return null;
    }

    /** Same lookup as the public {@code getBlobs}, but over a {@code java.util.Map} of V1 layers. */
    private List<Caffe.BlobProto> getBlobsFromV1Map(Caffe.LayerParameter layer, java.util.Map<String, Caffe.V1LayerParameter> v1Layers) {
        if (layer.getBlobsCount() != 0) {
            return layer.getBlobsList();
        }
        Caffe.V1LayerParameter v1Layer = v1Layers.get(layer.getName());
        return v1Layer != null ? v1Layer.getBlobsList() : null;
    }

    /**
     * Validates {@code blob} against the 2-D {@code shape} and allocates a matrix for it, filled
     * asynchronously; the copier thread is appended to {@code asyncThreads}.
     *
     * @return the matrix, which must not be read until the copier thread has been joined
     */
    public MatrixBlock getMBFromBlob(Caffe.BlobProto blob, int[] shape, String layerName, boolean transpose, ArrayList<Utils.CopyFloatToDoubleArray> asyncThreads) {
        List<Float> data = blob.getDataList();
        validateShape(shape, data, layerName);
        Tuple2<MatrixBlock, Utils.CopyFloatToDoubleArray> ret1 = allocateMatrixBlock(data, shape[0], shape[1], transpose);
        asyncThreads.add(ret1._2());
        return ret1._1();
    }

    /**
     * Parses a text-format Caffe solver definition.
     *
     * @param solverFilePath local or hdfs:/gpfs: path to the solver file
     */
    public Caffe.SolverParameter readCaffeSolver(String solverFilePath) {
        Caffe.SolverParameter.Builder builder = Caffe.SolverParameter.newBuilder();
        try (InputStreamReader reader = getInputStreamReader(solverFilePath)) {
            TextFormat.merge((Readable) reader, (Message.Builder) builder);
        }
        return builder.build();
    }

    /** Writes {@code content} to the local file {@code filePath}, replacing any existing content. */
    public void writeToFile(String content, String filePath) {
        try (PrintWriter pw = new PrintWriter(new File(filePath))) {
            pw.write(content);
        }
    }

    /**
     * Opens {@code filePath} for reading. Paths starting with {@code hdfs:} or {@code gpfs:} are
     * opened through the Hadoop {@link FileSystem} of the cached job configuration; all other
     * paths are opened from the local file system and decoded as ASCII.
     *
     * <p>NOTE(review): the Hadoop branch decodes with the platform default charset while the local
     * branch uses ASCII; confirm whether both should use the same charset.
     *
     * @throws LanguageException if {@code filePath} is null
     */
    public InputStreamReader getInputStreamReader(String filePath) {
        if (filePath == null) {
            throw new LanguageException("file path was not specified!");
        }
        if (filePath.startsWith("hdfs:") || filePath.startsWith("gpfs:")) {
            FileSystem fs = FileSystem.get((Configuration) ConfigurationManager.getCachedJobConf());
            return new InputStreamReader((InputStream) fs.open(new Path(filePath)));
        }
        return new InputStreamReader((InputStream) new FileInputStream(new File(filePath)), "ASCII");
    }

    /**
     * Rough memory estimate for one layer, counting 8 bytes per value.
     *
     * <p>Counts input activations (0 for {@link Data} layers), output activations, weights
     * (including extra weights) and bias. In training it also counts an error buffer the size of
     * the output and gradients of {@code (weights + bias) * batchSize}.
     *
     * @param l the layer
     * @param batchSize number of examples per batch
     * @param isTraining whether to include error and gradient buffers
     * @return the estimated number of bytes
     */
    public long getMemInBytes(CaffeLayer l, int batchSize, boolean isTraining) {
        long numLayerInput = l instanceof Data ? 0L
                : Long.parseLong((String) l.bottomLayerOutputShape()._1())
                        * Long.parseLong((String) l.bottomLayerOutputShape()._2())
                        * Long.parseLong((String) l.bottomLayerOutputShape()._3())
                        * (long) batchSize;
        long numLayerOutput = Long.parseLong((String) l.outputShape()._1())
                * Long.parseLong((String) l.outputShape()._2())
                * Long.parseLong((String) l.outputShape()._3())
                * (long) batchSize;
        long numLayerError = numLayerOutput;

        long numLayerWeights = 0L;
        int[] weightShape = l.weightShape();
        if (weightShape != null) {
            numLayerWeights = (long) weightShape[0] * (long) weightShape[1];
            int[] extraWeightShape = l.extraWeightShape();
            if (extraWeightShape != null) {
                numLayerWeights += (long) extraWeightShape[0] * (long) extraWeightShape[1];
            }
        }
        int[] biasShape = l.biasShape();
        long numLayerBias = biasShape != null ? (long) biasShape[0] * (long) biasShape[1] : 0L;
        long numLayerGradients = (numLayerWeights + numLayerBias) * (long) batchSize;

        long numValues = isTraining
                ? numLayerInput + numLayerOutput + numLayerError + numLayerWeights + numLayerBias + numLayerGradients
                : numLayerInput + numLayerOutput + numLayerWeights + numLayerBias;
        return numValues * 8L;
    }

    private Utils$() {
    }
}

