package org.deeplearning4j.example.mnist;

import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.dbn.DBN;
import org.deeplearning4j.eval.Evaluation;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/example/mnist/MnistExample.class */
public class MnistExample {
    private static final Logger log = LoggerFactory.getLogger(MnistExample.class);

    /** Mini-batch size passed to the MNIST iterator. */
    private static final int BATCH_SIZE = 10;
    /** Total number of examples requested from the MNIST iterator. */
    private static final int NUM_EXAMPLES = 60000;
    /** Input width of the network (784 matches a flattened 28x28 MNIST image). */
    private static final int NUM_INPUTS = 784;
    /** Output width of the network (one unit per digit class 0-9). */
    private static final int NUM_OUTPUTS = 10;
    /** Learning rate used for both pretraining and finetuning. */
    private static final double LEARNING_RATE = 0.01d;
    /** Epoch count used for both pretraining and finetuning. */
    private static final int EPOCHS = 1000;
    /**
     * Second argument of {@code DBN.pretrain}; presumably the contrastive-divergence
     * step count k. NOTE(review): confirm against the DBN.pretrain signature.
     */
    private static final int PRETRAIN_K = 1;
    /** File the trained network is serialized to. */
    private static final String MODEL_PATH = "mnist-dbn.bin";

    /**
     * Trains a 3-hidden-layer DBN on MNIST, writes it to {@value #MODEL_PATH}, then
     * logs evaluation statistics.
     *
     * @param args unused
     * @throws IOException if the data iterator cannot be created or the model file
     *     cannot be written
     */
    public static void main(String[] args) throws IOException {
        MnistDataSetIterator iterator = new MnistDataSetIterator(BATCH_SIZE, NUM_EXAMPLES);
        DBN dbn = new DBN.Builder()
                .hiddenLayerSizes(new int[]{500, 400, 250})
                .numberOfInputs(NUM_INPUTS)
                .numberOfOutPuts(NUM_OUTPUTS)
                .useRegularization(false)
                .build();

        train(dbn, iterator);

        save(dbn, MODEL_PATH);
        log.info("Saved dbn");

        // NOTE(review): the same iterator is reset and reused, so this evaluates on the
        // training examples; it measures training fit, not generalization to unseen data.
        iterator.reset();
        Evaluation evaluation = evaluate(dbn, iterator);
        log.info("Prediction f scores and accuracy");
        log.info(evaluation.stats());
    }

    /**
     * Runs one pretrain + finetune pass per mini-batch until the iterator is exhausted.
     *
     * @param dbn network to train in place
     * @param iterator source of mini-batches; consumed by this call
     */
    private static void train(DBN dbn, MnistDataSetIterator iterator) {
        while (iterator.hasNext()) {
            DataSet batch = (DataSet) iterator.next();
            dbn.pretrain((DoubleMatrix) batch.getFirst(), PRETRAIN_K, LEARNING_RATE, EPOCHS);
            // Only the labels are passed here; presumably finetune reuses the inputs from the
            // preceding pretrain call. NOTE(review): confirm against DBN.finetune.
            dbn.finetune((DoubleMatrix) batch.getSecond(), LEARNING_RATE, EPOCHS);
        }
    }

    /**
     * Serializes the network to the given file path. The stream is closed even if
     * writing fails.
     *
     * @param dbn network to serialize
     * @param path destination file path
     * @throws IOException if the file cannot be opened, written, or flushed
     */
    private static void save(DBN dbn, String path) throws IOException {
        try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(path))) {
            dbn.write(out);
            out.flush();
        }
    }

    /**
     * Compares the network's predictions against the labels of every remaining batch.
     *
     * @param dbn trained network
     * @param iterator source of batches; consumed by this call
     * @return accumulated evaluation over all batches
     */
    private static Evaluation evaluate(DBN dbn, MnistDataSetIterator iterator) {
        Evaluation evaluation = new Evaluation();
        while (iterator.hasNext()) {
            DataSet batch = (DataSet) iterator.next();
            evaluation.eval((DoubleMatrix) batch.getSecond(), dbn.predict((DoubleMatrix) batch.getFirst()));
        }
        return evaluation;
    }
}
