package org.deeplearning4j.nn.layers;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/BasePretrainNetwork.class */
/**
 * Abstract base class for layers with a visible/hidden structure.
 *
 * <p>Provides shared scoring, in-place parameter updates, gradient assembly, dropout masking,
 * optional bias-column input preprocessing and optimizer-driven fitting. Subclasses implement the
 * two sampling directions ({@link #sampleHiddenGivenVisible} and
 * {@link #sampleVisibleGivenHidden}).
 */
public abstract class BasePretrainNetwork extends BaseLayer {
    private static final long serialVersionUID = -7074102204433996574L;

    /** Mask produced by the most recent call to {@link #applyDropOutIfNecessary(INDArray)}. */
    protected INDArray doMask;

    // Loggers should never be reassignable; declared final per SLF4J convention.
    private static final Logger log = LoggerFactory.getLogger(BasePretrainNetwork.class);

    /**
     * Creates the layer from a configuration only.
     *
     * @param neuralNetConfiguration the layer configuration
     */
    public BasePretrainNetwork(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /**
     * Creates the layer from a configuration and an initial input.
     *
     * @param neuralNetConfiguration the layer configuration
     * @param input the initial input passed through to {@code BaseLayer}
     */
    public BasePretrainNetwork(NeuralNetConfiguration neuralNetConfiguration, INDArray input) {
        super(neuralNetConfiguration, input);
    }

    /**
     * Adjusts {@code gradient} in place by adding
     * {@code gradient * sparsity * (-learningRate * sparsity)}.
     *
     * <p>NOTE(review): the sparsity value enters the correction twice (effectively scaling the
     * gradient by {@code 1 - lr * sparsity^2}); confirm this is the intended sparsity penalty.
     *
     * @param gradient the gradient to adjust; modified in place
     */
    protected void applySparsity(INDArray gradient) {
        double sparsity = this.conf.getSparsity();
        INDArray correction = gradient
                .mul(Double.valueOf(sparsity))
                .mul(Double.valueOf((-this.conf.getLr()) * sparsity));
        gradient.addi(correction);
    }

    /**
     * Returns the negated loss of this layer on its current input.
     *
     * <p>For {@code RECONSTRUCTION_CROSSENTROPY} the loss is computed by
     * {@code LossFunctions.reconEntropy} from the {@code "b"}, visible-bias and weight parameters
     * and the configured activation function. For any other loss function it is computed by
     * {@code LossFunctions.score} against {@code transform(input)}, honoring the configured L2
     * coefficient and regularization flag.
     *
     * @return the negated loss value
     */
    @Override
    public double score() {
        LossFunctions.LossFunction lossFunction = this.conf.getLossFunction();
        if (lossFunction == LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY) {
            return -LossFunctions.reconEntropy(
                    this.input,
                    getParam("b"),
                    getParam(PretrainParamInitializer.VISIBLE_BIAS_KEY),
                    getParam(DefaultParamInitializer.WEIGHT_KEY),
                    this.conf.getActivationFunction());
        }
        return -LossFunctions.score(
                this.input,
                lossFunction,
                transform(this.input),
                this.conf.getL2(),
                this.conf.isUseRegularization());
    }

    /**
     * Adds the flattened gradient to the current parameters and stores the result via
     * {@code setParams}.
     *
     * <p>The sum is computed with {@code addi} on the array returned by {@code params()}.
     * NOTE(review): whether that array is a view of the live parameters or a copy is not visible
     * here; confirm before relying on either behavior.
     *
     * @param gradient the gradient to apply
     */
    @Override
    public void update(Gradient gradient) {
        setParams(params().addi(gradient.gradient()));
    }

    /**
     * Replaces the current input and performs one update step using {@code getGradient()}.
     *
     * @param input the new input for this layer
     */
    public void iterate(INDArray input) {
        this.input = input;
        update(getGradient());
    }

    /**
     * Packages the three parameter gradients into a {@link DefaultGradient}.
     *
     * <p>Entries are inserted in the order: visible bias, {@code "b"}, weights.
     * NOTE(review): if {@code DefaultGradient.gradient()} flattens in insertion order, verify that
     * this order matches the layout of {@code params()} used by {@link #update(Gradient)}.
     *
     * @param weightGradient gradient stored under the weight key
     * @param visibleBiasGradient gradient stored under the visible-bias key
     * @param biasGradient gradient stored under key {@code "b"}
     * @return a new gradient holding the three entries
     */
    public Gradient createGradient(
            INDArray weightGradient, INDArray visibleBiasGradient, INDArray biasGradient) {
        DefaultGradient result = new DefaultGradient();
        result.gradientLookupTable().put(PretrainParamInitializer.VISIBLE_BIAS_KEY, visibleBiasGradient);
        result.gradientLookupTable().put("b", biasGradient);
        result.gradientLookupTable().put(DefaultParamInitializer.WEIGHT_KEY, weightGradient);
        return result;
    }

    /**
     * Multiplies {@code input} in place by a dropout mask and records the mask in {@link #doMask}.
     *
     * <p>When the configured dropout is positive, the mask is
     * {@code Nd4j.rand(rows, columns).gt(dropOut)}. Otherwise it is all ones, so the input is left
     * numerically unchanged. Surviving entries are not rescaled.
     *
     * @param input the array to mask, addressed by rows and columns; modified in place
     */
    @Override
    public void applyDropOutIfNecessary(INDArray input) {
        int rows = input.rows();
        int columns = input.columns();
        double dropOut = this.conf.getDropOut();
        if (dropOut > 0.0d) {
            this.doMask = Nd4j.rand(rows, columns).gt(Double.valueOf(dropOut));
        } else {
            this.doMask = Nd4j.ones(rows, columns);
        }
        input.muli(this.doMask);
    }

    /**
     * Fits this layer by building a {@link Solver} over it, configured with this layer's
     * configuration and listeners, and running its optimizer.
     */
    @Override
    public void fit() {
        Solver solver = new Solver.Builder()
                .model(this)
                .configure(conf())
                .listeners(this.conf.getListeners())
                .build();
        solver.optimize();
    }

    /**
     * Optionally appends a column of ones to the input.
     *
     * @param input the raw input
     * @return {@code input} with an extra trailing column of ones when {@code concatBiases} is
     *     enabled in the configuration, otherwise the same {@code input} reference unchanged
     */
    protected INDArray preProcessInput(INDArray input) {
        if (!this.conf.isConcatBiases()) {
            return input;
        }
        return Nd4j.hstack(new INDArray[] {input, Nd4j.ones(input.rows(), 1)});
    }

    /**
     * Samples the hidden layer conditioned on a visible input.
     *
     * @param visible the visible-layer values
     * @return a pair whose contents are defined by the subclass (presumably
     *     probabilities/means and sampled values — TODO confirm per implementation)
     */
    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray visible);

    /**
     * Samples the visible layer conditioned on a hidden input.
     *
     * @param hidden the hidden-layer values
     * @return a pair whose contents are defined by the subclass (presumably
     *     probabilities/means and sampled values — TODO confirm per implementation)
     */
    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray hidden);
}
