package org.deeplearning4j.optimize.optimizers;

import java.io.Serializable;
import java.util.List;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.gradient.OutputLayerGradient;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Adapter exposing the output layer of a {@link BaseMultiLayerNetwork} as an
 * {@link OptimizableByGradientValueMatrix}.
 *
 * <p>The flat parameter vector used by every accessor in this class is laid out as all elements of
 * the output layer's weights {@code W} (in the linear index order of {@code getScalar}/{@code
 * putScalar}) followed by all elements of its bias {@code B}. The value gradient uses the same
 * layout.
 */
public class MultiLayerNetworkOptimizer implements Serializable, OptimizableByGradientValueMatrix {
    private static final long serialVersionUID = -3012638773299331828L;
    protected BaseMultiLayerNetwork network;
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetworkOptimizer.class);
    /** Learning rate handed to the output layer when computing the value gradient. */
    private double lr;
    /** Last value passed to {@link #setCurrentIteration(int)}; stored but not otherwise used here. */
    private int currentIteration;

    /**
     * @param baseMultiLayerNetwork the network whose output layer is optimized
     * @param learningRate learning rate used by {@link #getValueGradient(double[])}
     */
    public MultiLayerNetworkOptimizer(BaseMultiLayerNetwork baseMultiLayerNetwork, double learningRate) {
        this.network = baseMultiLayerNetwork;
        this.lr = learningRate;
    }

    @Override // org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix
    public void setCurrentIteration(int iteration) {
        this.currentIteration = iteration;
    }

    /**
     * Sets {@code labels} on the output layer and trains the network.
     *
     * <p>When the network does not force the number of epochs, this only runs backpropagation, and
     * only if {@code isShouldBackProp()} is set. Otherwise the output layer is first trained for
     * {@code epochs} passes on the last activation returned by {@code feedForward()}. The AdaGrad
     * history is cleared every {@code getResetAdaGradIterations()} passes. Backpropagation then
     * runs if enabled.
     *
     * @param labels labels for the output layer
     * @param learningRate learning rate for output-layer training and backprop
     * @param epochs number of output-layer passes; also forwarded to backprop
     * @param trainingEvaluator evaluator forwarded to backprop
     */
    public void optimize(INDArray labels, double learningRate, int epochs, TrainingEvaluator trainingEvaluator) {
        network.getOutputLayer().setLabels(labels);
        if (!network.isForceNumEpochs()) {
            if (network.isShouldBackProp()) {
                network.backProp(learningRate, epochs, trainingEvaluator);
            }
            return;
        }

        log.info("Training for {} epochs", epochs);
        List<INDArray> activations = network.feedForward();
        INDArray lastActivation = activations.get(activations.size() - 1);
        // Widened to long so this compiles whether the config getter returns int or long.
        long resetInterval = network.getDefaultConfiguration().getResetAdaGradIterations();
        for (int epoch = 0; epoch < epochs; epoch++) {
            // An interval of 0 would make '%' throw ArithmeticException; treat it as "never reset".
            if (resetInterval != 0 && epoch % resetInterval == 0) {
                network.getOutputLayer().getAdaGrad().historicalGradient = null;
            }
            network.getOutputLayer().train(lastActivation, labels, learningRate);
        }
        if (network.isShouldBackProp()) {
            network.backProp(learningRate, epochs, trainingEvaluator);
        }
    }

    /**
     * Sets {@code labels} on the output layer and runs backpropagation.
     *
     * <p>Backprop runs unconditionally; when the network forces the number of epochs, the
     * iteration count is logged first.
     *
     * @param labels labels for the output layer
     * @param learningRate learning rate for backprop
     * @param iterations iteration count forwarded to backprop
     */
    public void optimize(INDArray labels, double learningRate, int iterations) {
        network.getOutputLayer().setLabels(labels);
        if (network.isForceNumEpochs()) {
            log.info("Training for {} iteration", iterations);
        }
        network.backProp(learningRate, iterations);
    }

    /** @return element count of the output layer's {@code W} plus that of its {@code B} */
    @Override // org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix
    public int getNumParameters() {
        return network.getOutputLayer().getW().length() + network.getOutputLayer().getB().length();
    }

    /**
     * Copies the current parameters (W then B) into {@code buffer}.
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} elements
     */
    public void getParameters(double[] buffer) {
        int next = copyInto(network.getOutputLayer().getW(), buffer, 0);
        copyInto(network.getOutputLayer().getB(), buffer, next);
    }

    /**
     * Returns the parameter at flat index {@code index}. Indices below {@code W.length()} address
     * {@code W}; the rest address {@code B}, offset by {@code W.length()}.
     */
    @Override // org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix
    public double getParameter(int index) {
        INDArray weights = network.getOutputLayer().getW();
        int weightCount = weights.length();
        if (index < weightCount) {
            return readScalar(weights, index);
        }
        // Bias follows the weights in the flat layout, so offset by the weight count.
        return readScalar(network.getOutputLayer().getB(), index - weightCount);
    }

    /**
     * Overwrites the parameters (W then B) from {@code params}.
     *
     * @param params source; must hold at least {@link #getNumParameters()} elements
     */
    public void setParameters(double[] params) {
        int next = copyFrom(params, 0, network.getOutputLayer().getW());
        copyFrom(params, next, network.getOutputLayer().getB());
    }

    /**
     * Sets the parameter at flat index {@code index} to {@code value}, using the same W-then-B
     * layout as {@link #getParameter(int)}.
     */
    @Override // org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix
    public void setParameter(int index, double value) {
        INDArray weights = network.getOutputLayer().getW();
        int weightCount = weights.length();
        if (index < weightCount) {
            weights.putScalar(index, value);
        } else {
            // Bias follows the weights in the flat layout, so offset by the weight count.
            network.getOutputLayer().getB().putScalar(index - weightCount, value);
        }
    }

    /**
     * Writes the output layer's gradient (W gradient then B gradient), computed with this
     * optimizer's learning rate, into {@code buffer}.
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} elements
     */
    public void getValueGradient(double[] buffer) {
        OutputLayerGradient gradient = network.getOutputLayer().getGradient(lr);
        int next = copyInto(gradient.getwGradient(), buffer, 0);
        copyInto(gradient.getbGradient(), buffer, next);
    }

    /** @return {@code network.score()} */
    @Override // org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix
    public double getValue() {
        return network.score();
    }

    /** @return a new array holding the current parameters in W-then-B order */
    @Override // org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix
    public INDArray getParameters() {
        double[] buffer = new double[getNumParameters()];
        getParameters(buffer);
        return Nd4j.create(buffer);
    }

    /**
     * Overwrites the parameters from {@code params.data().asDouble()}.
     *
     * <p>NOTE(review): this reads the backing data buffer, not the logical view; confirm callers
     * never pass an offset/strided view.
     */
    @Override // org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix
    public void setParameters(INDArray params) {
        setParameters(params.data().asDouble());
    }

    /**
     * @param iteration ignored
     * @return a new array holding the current value gradient in W-then-B order
     */
    @Override // org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix
    public INDArray getValueGradient(int iteration) {
        double[] buffer = new double[getNumParameters()];
        getValueGradient(buffer);
        return Nd4j.create(buffer);
    }

    /** Reads element {@code index} of {@code array} as a double. */
    private static double readScalar(INDArray array, int index) {
        return ((Double) array.getScalar(index).element()).doubleValue();
    }

    /** Copies all elements of {@code source} into {@code dest} from {@code offset}; returns the next free index. */
    private static int copyInto(INDArray source, double[] dest, int offset) {
        int count = source.length();
        for (int k = 0; k < count; k++) {
            dest[offset++] = readScalar(source, k);
        }
        return offset;
    }

    /** Fills {@code dest} from {@code source} starting at {@code offset}; returns the next unread index. */
    private static int copyFrom(double[] source, int offset, INDArray dest) {
        int count = dest.length();
        for (int k = 0; k < count; k++) {
            dest.putScalar(k, source[offset++]);
        }
        return offset;
    }
}
