package org.deeplearning4j.nn.layers;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Common base for layers that support unsupervised pretraining.
 *
 * <p>Provides three things to subclasses:
 * <ul>
 *   <li>a score computed by comparing the layer input against {@code transform(input)};</li>
 *   <li>a helper that randomly zeroes input entries ({@link #getCorruptedInput});</li>
 *   <li>a helper that packages weight / bias / visible-bias gradients into a {@link Gradient}.</li>
 * </ul>
 * Subclasses supply the conditional sampling steps between visible and hidden units.
 */
public abstract class BasePretrainNetwork extends BaseLayer {
    private static final long serialVersionUID = -7074102204433996574L;

    /**
     * Creates a pretrainable layer from the given configuration.
     *
     * @param conf layer configuration, passed straight to {@link BaseLayer}
     */
    public BasePretrainNetwork(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Creates a pretrainable layer from the given configuration and initial input.
     *
     * @param conf  layer configuration, passed straight to {@link BaseLayer}
     * @param input initial input array, passed straight to {@link BaseLayer}
     */
    public BasePretrainNetwork(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Recomputes {@code score} for the current input.
     *
     * <p>The loss is {@link LossFunctions#score} between the input and {@code transform(input)},
     * using the configured loss function, L2 coefficient and regularization flag.
     * NOTE(review): {@code transform} is presumably the layer's reconstruction of its input.
     *
     * <p>Sign convention: the stored value is the negated loss, unless
     * {@code conf.isMinimize()} is true, in which case the loss is stored as-is.
     * If no input has been set, this method returns without touching {@code score}.
     */
    @Override
    public void setScore() {
        if (this.input == null) {
            // Nothing to score yet; the previous score is left unchanged.
            return;
        }
        this.score = -LossFunctions.score(
                this.input,
                this.conf.getLossFunction(),
                transform(this.input),
                this.conf.getL2(),
                this.conf.isUseRegularization());
        // Undo the negation when the optimizer is configured to minimize.
        if (this.conf.isMinimize()) {
            this.score = -this.score;
        }
    }

    /**
     * Returns a corrupted copy of {@code input}.
     *
     * <p>Each entry is kept with probability {@code 1 - corruptionLevel} and zeroed otherwise.
     * The caller's {@code input} array is not modified.
     * NOTE(review): assumes {@code corruptionLevel} lies in [0, 1]; no validation is done here.
     *
     * @param input           array to corrupt
     * @param corruptionLevel probability of zeroing each entry
     * @return a new array of the same shape as {@code input}
     */
    public INDArray getCorruptedInput(INDArray input, double corruptionLevel) {
        // A one-trial binomial draw per entry yields a 0/1 keep-mask.
        INDArray mask = Nd4j.getDistributions()
                .createBinomial(1, 1.0d - corruptionLevel)
                .sample(input.shape());
        // Multiply in place on the freshly sampled mask, so the caller's array is untouched.
        mask.muli(input);
        return mask;
    }

    /**
     * Packages the pretraining gradients into a {@link Gradient} keyed by parameter name.
     *
     * <p>Keys used:
     * <ul>
     *   <li>{@code PretrainParamInitializer.VISIBLE_BIAS_KEY}</li>
     *   <li>{@code DefaultParamInitializer.BIAS_KEY}</li>
     *   <li>{@code DefaultParamInitializer.WEIGHT_KEY}</li>
     * </ul>
     *
     * <p>NOTE(review): decompiler metadata says this method was originally {@code protected}.
     * It is kept {@code public} so that existing callers keep compiling.
     *
     * @param wGradient     gradient for the weight matrix
     * @param vBiasGradient gradient for the visible bias
     * @param hBiasGradient gradient stored under the bias key (presumably the hidden bias)
     * @return a new gradient holding the three arrays
     */
    public Gradient createGradient(INDArray wGradient, INDArray vBiasGradient, INDArray hBiasGradient) {
        DefaultGradient gradient = new DefaultGradient();
        // Insertion order is preserved from the original in case gradientForVariable()
        // is order-sensitive (e.g. when flattened) — TODO confirm.
        gradient.gradientForVariable().put(PretrainParamInitializer.VISIBLE_BIAS_KEY, vBiasGradient);
        // Use the shared constant rather than a bare "b" literal, so the key cannot drift.
        gradient.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, hBiasGradient);
        gradient.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, wGradient);
        return gradient;
    }

    /**
     * Samples hidden units conditioned on the given visible units; implemented by subclasses.
     *
     * @param visible visible-unit values
     * @return a pair of arrays. NOTE(review): presumably (hidden means, hidden samples);
     *         confirm against implementations.
     */
    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray visible);

    /**
     * Samples visible units conditioned on the given hidden units; implemented by subclasses.
     *
     * @param hidden hidden-unit values
     * @return a pair of arrays. NOTE(review): presumably (visible means, visible samples);
     *         confirm against implementations.
     */
    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray hidden);
}
