package org.deeplearning4j.optimize.optimizers.rbm;

import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.optimize.optimizers.NeuralNetworkOptimizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/* loaded from: input_file:org/deeplearning4j/optimize/optimizers/rbm/RBMOptimizer.class */
/**
 * {@link NeuralNetworkOptimizer} that tracks an adaptive integer {@code k}.
 *
 * <p>{@code k} is seeded from {@code extraParams[0]} (default {@value #DEFAULT_K}) on the first
 * gradient request. After that it is incremented by one every {@value #K_INCREMENT_INTERVAL}
 * calls to {@link #getValueGradient(int)}. The gradient itself is delegated unchanged to the
 * superclass.
 */
public class RBMOptimizer extends NeuralNetworkOptimizer {
    private static final long serialVersionUID = 3676032651650426749L;

    /** Value used for {@code k} when {@code extraParams} does not supply one. */
    private static final int DEFAULT_K = 1;

    /** {@code k} is bumped by one every this many calls to {@link #getValueGradient(int)}. */
    private static final int K_INCREMENT_INTERVAL = 10;

    /**
     * Adaptive step count. It is -1 (unset) until the first call to {@link #getValueGradient(int)}.
     * It is not read for anything else in this class. NOTE(review): presumably the CD-k sampling
     * depth consumed by subclasses or the network; confirm against callers.
     */
    protected int k = -1;

    /** Number of times {@link #getValueGradient(int)} has been called on this optimizer. */
    protected int numTimesIterated = 0;

    /**
     * Creates the optimizer. All arguments are forwarded unchanged to the superclass constructor.
     *
     * @param baseNeuralNetwork network passed to the superclass
     * @param f float passed to the superclass (presumably the learning rate; confirm against
     *     {@code NeuralNetworkOptimizer})
     * @param objArr extra parameters; element 0, if present and numeric, seeds {@code k}
     * @param optimizationAlgorithm algorithm passed to the superclass
     * @param lossFunction loss function passed to the superclass
     */
    public RBMOptimizer(BaseNeuralNetwork baseNeuralNetwork, float f, Object[] objArr,
            NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm,
            LossFunctions.LossFunction lossFunction) {
        super(baseNeuralNetwork, f, objArr, optimizationAlgorithm, lossFunction);
        // A lone null slot is replaced with the default so later reads see an explicit value.
        // The null check matters: a null array previously caused an NPE here.
        // NOTE(review): assumes the superclass stores objArr as this.extraParams.
        if (this.extraParams != null && this.extraParams.length == 1 && this.extraParams[0] == null) {
            this.extraParams[0] = DEFAULT_K;
        }
    }

    /**
     * Updates the adaptive {@code k} bookkeeping and then returns the superclass gradient.
     *
     * @param i iteration index forwarded to the superclass
     * @return the gradient computed by {@code super.getValueGradient(i)}
     * @throws IllegalArgumentException if {@code k} needs seeding and {@code extraParams[0]} is
     *     a non-null value that is not a {@link Number}
     */
    @Override
    public INDArray getValueGradient(int i) {
        this.numTimesIterated++;
        // Seed k lazily. A non-positive k (including the initial -1) is treated as "unset".
        if (this.k <= 0) {
            this.k = initialK(this.extraParams);
        }
        // Grow k periodically so later iterations use a larger value.
        if (this.numTimesIterated % K_INCREMENT_INTERVAL == 0) {
            this.k++;
        }
        return super.getValueGradient(i);
    }

    /**
     * Returns the seed value for {@code k} taken from the first extra parameter.
     *
     * @param params extra parameter array; may be null or empty
     * @return {@code params[0]} as an int when it is a {@link Number}; {@link #DEFAULT_K} when
     *     the array is null or empty, or its first element is null
     * @throws IllegalArgumentException if {@code params[0]} is non-null and not a {@link Number}
     */
    private static int initialK(Object[] params) {
        if (params == null || params.length < 1 || params[0] == null) {
            return DEFAULT_K;
        }
        Object first = params[0];
        if (first instanceof Number) {
            return ((Number) first).intValue();
        }
        throw new IllegalArgumentException(
                "extraParams[0] must be a Number (k), but was " + first.getClass().getName());
    }
}
