/*
 * Decompiled with CFR 0.152.
 */
package ml.comet.examples.mnist;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import java.io.IOException;
import ml.comet.experiment.ExperimentBuilder;
import ml.comet.experiment.OnlineExperiment;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class MnistExperimentExample {
    private static final Logger log = LoggerFactory.getLogger(MnistExperimentExample.class);
    @Parameter(names={"--epochs", "-e"}, description="number of epochs to perform")
    final int numEpochs = 2;

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public static void main(String[] args) {
        MnistExperimentExample main = new MnistExperimentExample();
        JCommander.newBuilder().addObject((Object)main).build().parse(args);
        OnlineExperiment experiment = (OnlineExperiment)ExperimentBuilder.OnlineExperiment().interceptStdout().build();
        try {
            main.runMnistExperiment(experiment);
        }
        catch (Exception e) {
            System.out.println("--- Failed to run experiment ---");
            e.printStackTrace();
        }
        finally {
            experiment.end();
        }
    }

    public void runMnistExperiment(OnlineExperiment experiment) throws IOException {
        log.info("****************MNIST Experiment Example Started********************");
        int numRows = 28;
        int numColumns = 28;
        int outputNum = 10;
        int batchSize = 128;
        int rngSeed = 123;
        experiment.logParameter("numRows", (Object)28);
        experiment.logParameter("numColumns", (Object)28);
        experiment.logParameter("outputNum", (Object)outputNum);
        experiment.logParameter("batchSize", (Object)batchSize);
        experiment.logParameter("rngSeed", (Object)rngSeed);
        experiment.logParameter("numEpochs", (Object)2);
        double lr = 0.006;
        double nesterovsMomentum = 0.9;
        double l2Regularization = 1.0E-4;
        experiment.logParameter("learningRate", (Object)lr);
        experiment.logParameter("nesterovsMomentum", (Object)nesterovsMomentum);
        experiment.logParameter("l2Regularization", (Object)l2Regularization);
        OptimizationAlgorithm optimizationAlgorithm = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
        experiment.logParameter("optimizationAlgorithm", (Object)OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed((long)rngSeed).updater((IUpdater)new Nesterovs(lr, nesterovsMomentum)).optimizationAlgo(optimizationAlgorithm).l2(l2Regularization).list().layer((Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(784)).nOut(1000)).activation(Activation.RELU)).weightInit(WeightInit.XAVIER)).build()).layer((Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1000)).nOut(outputNum)).activation(Activation.SOFTMAX)).weightInit(WeightInit.XAVIER)).build()).build();
        experiment.logGraph(conf.toJson());
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new TrainingListener[]{new StepScoreListener(experiment, 1, log)});
        MnistDataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
        log.info("Train model....");
        model.fit((DataSetIterator)mnistTrain, 2);
        MnistDataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
        log.info("Evaluate model....");
        Evaluation eval = model.evaluate((DataSetIterator)mnistTest);
        log.info(eval.stats());
        experiment.logHtml(eval.getConfusionMatrix().toHTML(), false);
        log.info("****************MNIST Experiment Example finished********************");
    }

    static class StepScoreListener
    extends BaseTrainingListener {
        private final OnlineExperiment experiment;
        private int printIterations;
        private final Logger log;

        StepScoreListener(OnlineExperiment experiment, int printIterations, Logger log) {
            this.experiment = experiment;
            this.printIterations = printIterations;
            this.log = log;
        }

        public void iterationDone(Model model, int iteration, int epoch) {
            if (this.printIterations <= 0) {
                this.printIterations = 1;
            }
            if (iteration % this.printIterations == 0) {
                double result = model.score();
                this.log.info("Score at step/epoch {}/{}  is {} ", new Object[]{iteration, epoch, result});
                this.experiment.setEpoch((long)epoch);
                this.experiment.logMetric("score", (Object)model.score(), (long)iteration);
            }
        }
    }
}

