package org.deeplearning4j.plot;

import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.mnist.draw.DrawReconstruction;
import org.deeplearning4j.models.featuredetectors.autoencoder.SemanticHashing;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.transformation.MatrixTransform;

/* loaded from: input_file:org/deeplearning4j/plot/DeepAutoEncoderDataSetReconstructionRender.class */
public class DeepAutoEncoderDataSetReconstructionRender {
    /** Default time, in milliseconds, that each REAL/TEST window pair stays on screen. */
    private static final long DEFAULT_PAUSE_MILLIS = 10000L;

    /**
     * Factor every drawn matrix is multiplied by. Presumably maps unit-range activations to
     * 0-255 pixel intensities — TODO confirm against DrawReconstruction.
     */
    private static final int PIXEL_SCALE = 255;

    /** Both size arguments passed to the DrawReconstruction used for the "TEST" window. */
    private static final int RECONSTRUCTION_FRAME_SIZE = 1000;

    private final DataSetIterator iter;
    private final SemanticHashing encoder;
    private final int rows;
    private final int columns;
    private MatrixTransform picDraw;
    private MatrixTransform reconDraw;
    private long pauseMillis = DEFAULT_PAUSE_MILLIS;

    /**
     * Creates a renderer that draws each example of {@code dataSetIterator} next to the
     * encoder's output for it.
     *
     * @param dataSetIterator source of the data sets to render; consumed by {@link #draw()}
     * @param semanticHashing model whose {@code output} is drawn as the "TEST" image
     * @param rows number of rows each feature/output vector is reshaped to before drawing
     * @param columns number of columns each feature/output vector is reshaped to before drawing
     */
    public DeepAutoEncoderDataSetReconstructionRender(DataSetIterator dataSetIterator, SemanticHashing semanticHashing, int rows, int columns) {
        this.iter = dataSetIterator;
        this.encoder = semanticHashing;
        this.rows = rows;
        this.columns = columns;
    }

    /**
     * Consumes the iterator and, for every example, shows two windows: "REAL" (the example's
     * features) and "TEST" (the matching row of the encoder's output). Each pair is shown for
     * {@link #getPauseMillis()} milliseconds and then disposed.
     *
     * <p>Note the asymmetry: {@code picDraw} is applied to the features AFTER scaling by 255,
     * while {@code reconDraw} is applied to the output row BEFORE scaling by 255.
     *
     * @throws InterruptedException if interrupted while a window pair is displayed; the
     *     windows are still disposed before the exception propagates
     */
    public void draw() throws InterruptedException {
        while (this.iter.hasNext()) {
            DataSet batch = this.iter.next();
            INDArray outputs = this.encoder.output(batch.getFeatureMatrix());
            for (int i = 0; i < batch.numExamples(); i++) {
                INDArray real = batch.get(i).getFeatureMatrix().mul(PIXEL_SCALE);
                if (this.picDraw != null) {
                    real = (INDArray) this.picDraw.apply(real);
                }
                INDArray recon = outputs.getRow(i);
                if (this.reconDraw != null) {
                    recon = (INDArray) this.reconDraw.apply(recon);
                }
                recon = recon.mul(PIXEL_SCALE);
                // Reshape both before opening any window so a shape error cannot orphan a frame.
                showPair(real.reshape(this.rows, this.columns), recon.reshape(this.rows, this.columns));
            }
        }
    }

    /**
     * Displays the "REAL" and "TEST" windows, waits, and always disposes whatever was opened.
     */
    private void showPair(INDArray real, INDArray reconstruction) throws InterruptedException {
        DrawReconstruction realView = null;
        DrawReconstruction reconView = null;
        try {
            realView = new DrawReconstruction(real);
            realView.title = "REAL";
            realView.draw();
            reconView = new DrawReconstruction(reconstruction, RECONSTRUCTION_FRAME_SIZE, RECONSTRUCTION_FRAME_SIZE);
            reconView.title = "TEST";
            reconView.draw();
            Thread.sleep(this.pauseMillis);
        } finally {
            // Previously skipped on interruption or failure, leaking on-screen frames.
            disposeFrame(realView);
            disposeFrame(reconView);
        }
    }

    /** Disposes the view's frame if the view and its frame were actually created. */
    private static void disposeFrame(DrawReconstruction view) {
        if (view != null && view.frame != null) {
            view.frame.dispose();
        }
    }

    /** @return transform applied to each scaled feature matrix before drawing, or {@code null} */
    public MatrixTransform getPicDraw() {
        return this.picDraw;
    }

    /** @param matrixTransform transform for the "REAL" image (applied after scaling); may be {@code null} */
    public void setPicDraw(MatrixTransform matrixTransform) {
        this.picDraw = matrixTransform;
    }

    /** @return transform applied to each unscaled output row before drawing, or {@code null} */
    public MatrixTransform getReconDraw() {
        return this.reconDraw;
    }

    /** @param matrixTransform transform for the "TEST" image (applied before scaling); may be {@code null} */
    public void setReconDraw(MatrixTransform matrixTransform) {
        this.reconDraw = matrixTransform;
    }

    /** @return milliseconds each window pair stays visible (default 10000) */
    public long getPauseMillis() {
        return this.pauseMillis;
    }

    /**
     * Sets how long each window pair stays visible.
     *
     * @param pauseMillis display time in milliseconds
     * @throws IllegalArgumentException if {@code pauseMillis} is negative
     */
    public void setPauseMillis(long pauseMillis) {
        if (pauseMillis < 0) {
            throw new IllegalArgumentException("pauseMillis must be >= 0: " + pauseMillis);
        }
        this.pauseMillis = pauseMillis;
    }
}
