package org.deeplearning4j.optimize;

import java.util.Map;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/GradientAdjustment.class */
/**
 * Static helpers that turn raw gradients into the values applied to a model's parameters:
 * learning-rate or AdaGrad scaling, momentum, L1/L2 regularization, optional unit-norm
 * scaling and averaging over the mini-batch.
 *
 * <p>Every adjustment is written in place into the gradient array(s) passed in; nothing is
 * returned.
 */
public final class GradientAdjustment {
    private static final Logger log = LoggerFactory.getLogger(GradientAdjustment.class);

    private GradientAdjustment() {
    }

    /**
     * Adjusts, in place, the gradient of every variable contained in {@code gradient}.
     *
     * <p>One {@link AdaGrad} instance is kept per variable name in {@code adaGradForVariable};
     * missing entries are created (shaped like the model parameter of the same name) and stored
     * back into the map so later calls reuse them.
     *
     * @param conf               configuration supplying learning rate, momentum, regularization
     *                           and AdaGrad settings
     * @param iteration          current training iteration
     * @param gradient           gradients keyed by variable name; each array is modified in place
     * @param batchSize          mini-batch size the gradients are divided by; must be positive
     * @param adaGradForVariable per-variable AdaGrad state; may be populated by this call
     * @param model              model whose parameters are looked up by the same variable names
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public static void updateGradientAccordingToParams(NeuralNetConfiguration conf, int iteration,
            Gradient gradient, int batchSize, Map<String, AdaGrad> adaGradForVariable, Model model) {
        for (String variable : gradient.gradientForVariable().keySet()) {
            INDArray params = model.getParam(variable);
            AdaGrad adaGrad = adaGradForVariable.get(variable);
            if (adaGrad == null) {
                adaGrad = new AdaGrad(params.shape());
                adaGradForVariable.put(variable, adaGrad);
            }
            updateGradientAccordingToParams(conf, iteration, adaGrad, gradient.getGradientFor(variable),
                    params, batchSize);
        }
    }

    /**
     * Adjusts a single gradient array in place.
     *
     * <p>Steps, in order: optional AdaGrad history reset, AdaGrad or learning-rate scaling,
     * momentum, L2 and/or L1 regularization, optional division by the gradient's norm2, and
     * finally division by {@code batchSize}.
     *
     * @param conf      configuration supplying the hyper-parameters
     * @param iteration current training iteration
     * @param adaGrad   AdaGrad state to use; if {@code null} a fresh instance is created for this
     *                  call only, so no history is retained
     * @param gradient  gradient to adjust; modified in place
     * @param params    current parameter values, read for regularization and never modified
     * @param batchSize mini-batch size the gradient is divided by; must be positive
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public static void updateGradientAccordingToParams(NeuralNetConfiguration conf, int iteration,
            AdaGrad adaGrad, INDArray gradient, INDArray params, int batchSize) {
        if (batchSize <= 0) {
            // Dividing by zero (or a negative size) would silently corrupt the gradient.
            throw new IllegalArgumentException("batchSize must be positive but was " + batchSize);
        }
        if (adaGrad == null) {
            adaGrad = new AdaGrad(gradient.shape());
        }
        if (iteration != 0 && conf.getResetAdaGradIterations() > 0
                && iteration % conf.getResetAdaGradIterations() == 0) {
            // Clear AdaGrad's historical-gradient field (presumably its accumulated history).
            adaGrad.historicalGradient = null;
            log.info("Resetting adagrad");
        }

        double momentum = momentumFor(conf, iteration);
        double lr = conf.getLr();

        if (conf.isUseAdaGrad()) {
            // getGradient may return a new array; copy it back so the caller's array (and all
            // subsequent in-place steps) see the AdaGrad-adjusted values.
            gradient.assign(adaGrad.getGradient(gradient));
        } else {
            gradient.muli(lr);
        }

        if (momentum > 0.0d) {
            // NOTE(review): m*g + (1-m)*g == g, so this adds g to itself (doubles the step) for any
            // momentum > 0. Real momentum needs a velocity carried between calls, which is not
            // available here. Kept unchanged; TODO confirm the intended behaviour.
            gradient.addi(gradient.mul(momentum).addi(gradient.mul(1.0d - momentum)));
        }

        applyRegularization(conf, gradient, params, lr);

        if (conf.isConstrainGradientToUnitNorm()) {
            gradient.divi(gradient.norm2(Integer.MAX_VALUE));
        }
        gradient.divi(batchSize);
    }

    /**
     * Returns the momentum to use at {@code iteration}: the configured base momentum, or the
     * scheduled value once {@code iteration} reaches the scheduled switch iteration.
     */
    private static double momentumFor(NeuralNetConfiguration conf, int iteration) {
        double momentum = conf.getMomentum();
        if (conf.getMomentumAfter() != null && !conf.getMomentumAfter().isEmpty()) {
            // NOTE(review): only the first key yielded by the map's iterator is consulted, so a
            // schedule with several entries depends on the map's iteration order.
            int switchAt = conf.getMomentumAfter().keySet().iterator().next().intValue();
            if (iteration >= switchAt) {
                momentum = conf.getMomentumAfter().get(Integer.valueOf(switchAt)).doubleValue();
            }
        }
        return momentum;
    }

    /**
     * When regularization is enabled, subtracts in place the L2 penalty
     * ({@code l2 * lr * params}) and/or the L1 penalty ({@code l1 * lr * sign(params)}) from
     * {@code gradient}. Each penalty is applied only if its coefficient is positive.
     */
    private static void applyRegularization(NeuralNetConfiguration conf, INDArray gradient,
            INDArray params, double lr) {
        if (!conf.isUseRegularization()) {
            return;
        }
        if (conf.getL2() > 0.0d) {
            gradient.subi(params.mul(conf.getL2() * lr));
        }
        if (conf.getL1() > 0.0d) {
            // Work on a copy so sign() can never overwrite the model's parameters.
            gradient.subi(Transforms.sign(params.dup()).muli(conf.getL1() * lr));
        }
    }
}
