/*
 * Decompiled with CFR 0.152.
 */
package ml.comet.examples.mnist;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import java.io.IOException;
import ml.comet.experiment.OnlineExperiment;
import ml.comet.experiment.OnlineExperimentImpl;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class MnistExperimentExample {
    private static final Logger log = LoggerFactory.getLogger(MnistExperimentExample.class);
    @Parameter(names={"--epochs", "-e"}, description="number of epochs to perform")
    int numEpochs = 2;

    public static void main(String[] args) throws Exception {
        MnistExperimentExample main = new MnistExperimentExample();
        JCommander.newBuilder().addObject((Object)main).build().parse(args);
        main.runMnistExperiment();
    }

    public void runMnistExperiment() throws IOException {
        OnlineExperimentImpl experiment = OnlineExperimentImpl.builder().build();
        experiment.setInterceptStdout();
        System.out.println("experiment live at: " + experiment.getExperimentLink());
        int numRows = 28;
        int numColumns = 28;
        int outputNum = 10;
        int batchSize = 128;
        int rngSeed = 123;
        experiment.logParameter("numRows", (Object)28);
        experiment.logParameter("numColumns", (Object)28);
        experiment.logParameter("outputNum", (Object)outputNum);
        experiment.logParameter("batchSize", (Object)batchSize);
        experiment.logParameter("rngSeed", (Object)rngSeed);
        experiment.logParameter("numEpochs", (Object)this.numEpochs);
        double lr = 0.006;
        experiment.logParameter("lr", (Object)lr);
        OptimizationAlgorithm optimizationAlgorithm = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
        experiment.logParameter("optimizationAlgorithm", (Object)OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(rngSeed).optimizationAlgo(optimizationAlgorithm).iterations(1).learningRate(lr).updater(Updater.NESTEROVS).momentum(0.9).regularization(true).l2(1.0E-4).list().layer(0, (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(784)).nOut(1000)).activation(Activation.RELU)).weightInit(WeightInit.XAVIER)).build()).layer(1, (Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1000)).nOut(outputNum)).activation(Activation.SOFTMAX)).weightInit(WeightInit.XAVIER)).build()).pretrain(false).backprop(true).build();
        experiment.logGraph(conf.toJson());
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new IterationListener[]{new StepScoreListener((OnlineExperiment)experiment, 1, log)});
        MnistDataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
        log.info("Train model....");
        for (int i = 0; i < this.numEpochs; ++i) {
            experiment.setEpoch((long)i);
            model.fit((DataSetIterator)mnistTrain);
        }
        MnistDataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
        log.info("Evaluate model....");
        Evaluation eval = new Evaluation(outputNum);
        while (mnistTest.hasNext()) {
            DataSet next = (DataSet)mnistTest.next();
            INDArray output = model.output(next.getFeatureMatrix());
            eval.eval(next.getLabels(), output);
        }
        log.info(eval.stats());
        experiment.logHtml(eval.getConfusionMatrix().toHTML(), false);
        experiment.end();
        log.info("****************MNIST Experiment Example finished********************");
    }

    static class StepScoreListener
    implements IterationListener {
        private boolean invoked;
        private final OnlineExperiment experiment;
        private int printIterations;
        private final Logger log;
        private long iterCount;

        StepScoreListener(OnlineExperiment experiment, int printIterations, Logger log) {
            this.experiment = experiment;
            this.printIterations = printIterations;
            this.log = log;
        }

        public boolean invoked() {
            return this.invoked;
        }

        public void invoke() {
            this.invoked = true;
        }

        public void iterationDone(Model model, int iteration) {
            if (this.printIterations <= 0) {
                this.printIterations = 1;
            }
            if (this.iterCount % (long)this.printIterations == 0L) {
                this.invoke();
                double result = model.score();
                this.log.info("Score at iteration " + this.iterCount + " is " + result);
                this.experiment.logMetric("score", (Object)model.score(), this.iterCount);
            }
            ++this.iterCount;
        }
    }
}

