package org.deeplearning4j.optimize;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/GradientAdjustment.class */
/**
 * Legacy static helper that turns a raw gradient into a parameter update by chaining an optional
 * RMSProp step, a {@link GradientUpdater} (AdaGrad when none is supplied), optional momentum,
 * L2/L1 regularization, optional unit-norm scaling and division by the mini-batch size.
 *
 * <p>Not instantiable; all functionality is exposed through static methods.
 */
public class GradientAdjustment {
    private static final Logger log = LoggerFactory.getLogger(GradientAdjustment.class);

    /**
     * Parameter key identifying bias parameters, which are exempt from L1/L2 regularization.
     * NOTE(review): presumably mirrors {@code DefaultParamInitializer.BIAS_KEY} — confirm.
     */
    private static final String BIAS_PARAM_KEY = "b";

    private GradientAdjustment() {
    }

    /**
     * Applies RMSProp, the adaptive updater, momentum, regularization, unit-norm scaling and
     * batch-size averaging to {@code gradient}.
     *
     * <p>NOTE(review): the method returns nothing, and several steps write into arrays the caller
     * never receives. The RMSProp branch reassigns the local {@code gradient} reference, and the
     * momentum branch computes {@code lastStep.mul(...)}. Assuming ND4J's convention that
     * non-{@code i} operations allocate a new array, the later in-place steps (regularization,
     * unit-norm, batch division) then act on that fresh array. The only state the caller reliably
     * observes is the {@code lastStep.assign(...)} writes and whatever {@code updater} mutates
     * internally. Confirm against callers before relying on this method's effects.
     *
     * @param iteration current iteration, compared against the momentum schedule
     * @param batchSize mini-batch size; the final update is divided by this value
     * @param conf configuration supplying learning rate, momentum, RMS decay, L1/L2 and the
     *     unit-norm flag
     * @param params current parameter values; also used to size the default updater and step cache
     * @param gradient raw gradient for {@code params}
     * @param updater adaptive gradient updater; if {@code null}, a new {@link AdaGrad} is created for
     *     this call only, so its accumulated history is discarded on return
     * @param lastStep step cache, used as the RMSProp accumulator and as the momentum velocity and
     *     overwritten via {@code assign}; if {@code null}, a local array of ones is used and
     *     discarded on return
     * @param paramType key of the parameter being updated; {@code "b"} (bias) skips regularization,
     *     and {@code null} is treated as a non-bias parameter
     * @deprecated retained for backward compatibility only.
     */
    @Deprecated
    public static void updateGradientAccordingToParams(int iteration, int batchSize,
            NeuralNetConfiguration conf, INDArray params, INDArray gradient, GradientUpdater updater,
            INDArray lastStep, String paramType) {
        if (updater == null) {
            updater = new AdaGrad(params.shape());
        }
        if (lastStep == null) {
            lastStep = Nd4j.ones(params.shape());
        }
        double momentum = momentumFor(iteration, conf);

        double rmsDecay = conf.getRmsDecay();
        if (rmsDecay > 0.0) {
            // lastStep <- decay * lastStep + (1 - decay) * gradient^2
            lastStep.assign(lastStep.mul(rmsDecay)
                    .addi(Transforms.pow(gradient, 2).muli(1.0 - rmsDecay)));
            // gradient <- -lr * gradient / sqrt(lastStep + eps)
            gradient = gradient.mul(conf.getLr()).negi()
                    .divi(Transforms.sqrt(lastStep.add(Nd4j.EPS_THRESHOLD)));
        }

        INDArray update = updater.getGradient(gradient, 0);
        if (momentum > 0.0) {
            // NOTE(review): lastStep doubles as the RMSProp cache above; combining both reuses it.
            update = lastStep.mul(momentum).subi(update);
            lastStep.assign(update);
        }

        applyRegularization(conf, params, update, paramType);

        if (conf.isConstrainGradientToUnitNorm()) {
            // Integer.MAX_VALUE as the dimension requests the norm over the whole array.
            update.divi(update.norm2(new int[] {Integer.MAX_VALUE}));
        }
        update.divi(batchSize);
    }

    /**
     * Resolves the momentum for {@code iteration}: the base momentum, replaced by the scheduled
     * value once {@code iteration} reaches the schedule's switch point.
     *
     * <p>Only the first key returned by the schedule map's iterator is consulted.
     */
    private static double momentumFor(int iteration, NeuralNetConfiguration conf) {
        double momentum = conf.getMomentum();
        if (conf.getMomentumAfter() == null || conf.getMomentumAfter().isEmpty()) {
            return momentum;
        }
        int switchIteration = conf.getMomentumAfter().keySet().iterator().next().intValue();
        if (iteration >= switchIteration) {
            momentum = conf.getMomentumAfter().get(switchIteration).doubleValue();
        }
        return momentum;
    }

    /**
     * Subtracts the L2 (preferred) or L1 penalty from {@code update} in place, when regularization
     * is enabled and the parameter is not a bias.
     */
    private static void applyRegularization(NeuralNetConfiguration conf, INDArray params,
            INDArray update, String paramType) {
        // Bias parameters are exempt; the original compared the gradient array (not the key) to "b",
        // which was never equal, so biases were always regularized.
        if (!conf.isUseRegularization() || BIAS_PARAM_KEY.equals(paramType)) {
            return;
        }
        if (conf.getL2() > 0.0) {
            update.subi(params.mul(conf.getL2()));
        } else if (conf.getL1() > 0.0) {
            // Was `getL1() < 0`, inconsistent with the L2 guard; a positive L1 coefficient was ignored.
            update.subi(Transforms.sign(params).muli(conf.getL1()));
        }
    }
}
