package aima.core.learning.neural;

import aima.core.util.math.Matrix;
import aima.core.util.math.Vector;

/* loaded from: input_file:aima/core/learning/neural/BackPropLearning.class */
public class BackPropLearning implements NNTrainingScheme {
    /** Scales the sensitivity-based term of every weight and bias update (alpha). */
    private final double learningRate;
    /** Weight given to the previous update; the new term is scaled by {@code 1 - momentum}. */
    private final double momentum;
    private Layer hiddenLayer;
    private Layer outputLayer;
    private LayerSensitivity hiddenSensitivity;
    private LayerSensitivity outputSensitivity;

    /**
     * Creates a back-propagation trainer with momentum.
     *
     * @param learningRate multiplier (alpha) applied to the sensitivity-based update term
     * @param momentum     weight of the previous update; the freshly computed term is scaled
     *                     by {@code 1 - momentum}
     */
    public BackPropLearning(double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    /**
     * Binds this trainer to a network and allocates sensitivity holders for its hidden and
     * output layers.
     *
     * @param functionApproximator the network to train; it is cast to
     *                             {@link FeedForwardNeuralNetwork}, so any other type fails
     *                             with a {@link ClassCastException}
     */
    @Override // aima.core.learning.neural.NNTrainingScheme
    public void setNeuralNetwork(FunctionApproximator functionApproximator) {
        FeedForwardNeuralNetwork network = (FeedForwardNeuralNetwork) functionApproximator;
        this.hiddenLayer = network.getHiddenLayer();
        this.outputLayer = network.getOutputLayer();
        this.hiddenSensitivity = new LayerSensitivity(this.hiddenLayer);
        this.outputSensitivity = new LayerSensitivity(this.outputLayer);
    }

    /**
     * Runs a forward pass through the bound hidden and output layers.
     *
     * <p>The {@code feedForwardNeuralNetwork} argument is not consulted; the layers captured
     * by {@link #setNeuralNetwork(FunctionApproximator)} are used instead.
     *
     * @param feedForwardNeuralNetwork ignored
     * @param vector                   input fed to the hidden layer
     * @return the output layer's activation values after this pass
     * @throws IllegalStateException if {@code setNeuralNetwork} has not been called
     */
    @Override // aima.core.learning.neural.NNTrainingScheme
    public Vector processInput(FeedForwardNeuralNetwork feedForwardNeuralNetwork, Vector vector) {
        requireNetworkBound();
        this.hiddenLayer.feedForward(vector);
        this.outputLayer.feedForward(this.hiddenLayer.getLastActivationValues());
        return this.outputLayer.getLastActivationValues();
    }

    /**
     * Back-propagates an output error through both layers and applies the resulting
     * weight and bias updates.
     *
     * <p>Relies on the values recorded by the most recent {@link #processInput} call.
     * The {@code feedForwardNeuralNetwork} argument is not consulted.
     *
     * @param feedForwardNeuralNetwork ignored
     * @param vector                   error vector at the output layer
     * @throws IllegalStateException if {@code setNeuralNetwork} has not been called
     */
    @Override // aima.core.learning.neural.NNTrainingScheme
    public void processError(FeedForwardNeuralNetwork feedForwardNeuralNetwork, Vector vector) {
        requireNetworkBound();

        // Sensitivities flow backwards: output layer first, hidden layer derived from it.
        this.outputSensitivity.sensitivityMatrixFromErrorMatrix(vector);
        this.hiddenSensitivity.sensitivityMatrixFromSucceedingLayer(this.outputSensitivity);

        // Each layer's weight update uses what was fed into it on the last forward pass:
        // hidden activations for the output layer, the hidden layer's own inputs for it.
        calculateWeightUpdates(this.outputSensitivity, this.hiddenLayer.getLastActivationValues(),
                this.learningRate, this.momentum);
        calculateWeightUpdates(this.hiddenSensitivity, this.hiddenLayer.getLastInputValues(),
                this.learningRate, this.momentum);
        calculateBiasUpdates(this.outputSensitivity, this.learningRate, this.momentum);
        calculateBiasUpdates(this.hiddenSensitivity, this.learningRate, this.momentum);

        // Updates are applied only after every update above has been computed.
        this.outputLayer.updateWeights();
        this.outputLayer.updateBiases();
        this.hiddenLayer.updateWeights();
        this.hiddenLayer.updateBiases();
    }

    /**
     * Computes a momentum-blended weight update, hands a copy to the layer and returns it.
     *
     * <p>{@code update = mu * previousUpdate + (1 - mu) * (-alpha * S * a^T)}, where
     * {@code S} is the layer's sensitivity matrix and {@code a} is {@code layerInput}.
     *
     * @param layerSensitivity sensitivity of the layer being updated
     * @param layerInput       values fed into that layer on the last forward pass
     * @param alpha            learning rate
     * @param mu               momentum
     * @return the update matrix (the layer receives a copy)
     */
    public Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity, Vector layerInput, double alpha, double mu) {
        Layer layer = layerSensitivity.getLayer();
        Matrix update = layer.getLastWeightUpdateMatrix().times(mu)
                .plus(momentumlessWeightUpdate(layerSensitivity, layerInput, alpha).times(1.0d - mu));
        layer.acceptNewWeightUpdate(update.copy());
        return update;
    }

    /**
     * Computes a plain (no momentum) weight update {@code -alpha * S * a^T}, hands a copy
     * to the layer and returns it.
     *
     * @param layerSensitivity sensitivity of the layer being updated
     * @param layerInput       values fed into that layer on the last forward pass
     * @param alpha            learning rate
     * @return the update matrix (the layer receives a copy)
     */
    public static Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity, Vector layerInput, double alpha) {
        Layer layer = layerSensitivity.getLayer();
        Matrix update = momentumlessWeightUpdate(layerSensitivity, layerInput, alpha);
        layer.acceptNewWeightUpdate(update.copy());
        return update;
    }

    /**
     * Computes a momentum-blended bias update, hands a copy to the layer and returns it.
     *
     * <p>{@code update = mu * previousBiasUpdate + (1 - mu) * (-alpha * S)}.
     *
     * @param layerSensitivity sensitivity of the layer being updated
     * @param alpha            learning rate
     * @param mu               momentum
     * @return column 0 of the update as a vector (the layer receives a copy)
     */
    public Vector calculateBiasUpdates(LayerSensitivity layerSensitivity, double alpha, double mu) {
        Layer layer = layerSensitivity.getLayer();
        Matrix update = layer.getLastBiasUpdateVector().times(mu)
                .plus(momentumlessBiasUpdate(layerSensitivity, alpha).times(1.0d - mu));
        return acceptBiasUpdate(layer, update);
    }

    /**
     * Computes a plain (no momentum) bias update {@code -alpha * S}, hands a copy to the
     * layer and returns it.
     *
     * @param layerSensitivity sensitivity of the layer being updated
     * @param alpha            learning rate
     * @return column 0 of the update as a vector (the layer receives a copy)
     */
    public static Vector calculateBiasUpdates(LayerSensitivity layerSensitivity, double alpha) {
        Layer layer = layerSensitivity.getLayer();
        Matrix update = momentumlessBiasUpdate(layerSensitivity, alpha);
        return acceptBiasUpdate(layer, update);
    }

    /** Returns {@code -alpha * S * a^T}; operation order is kept fixed so results are reproducible. */
    private static Matrix momentumlessWeightUpdate(LayerSensitivity layerSensitivity, Vector layerInput, double alpha) {
        return layerSensitivity.getSensitivityMatrix().times(layerInput.transpose()).times(alpha).times(-1.0d);
    }

    /** Returns {@code -alpha * S}; operation order is kept fixed so results are reproducible. */
    private static Matrix momentumlessBiasUpdate(LayerSensitivity layerSensitivity, double alpha) {
        return layerSensitivity.getSensitivityMatrix().times(alpha).times(-1.0d);
    }

    /** Copies column 0 of {@code update} into a vector, passes a copy to the layer, and returns it. */
    private static Vector acceptBiasUpdate(Layer layer, Matrix update) {
        int rows = update.getRowDimension();
        Vector bias = new Vector(rows);
        for (int row = 0; row < rows; row++) {
            bias.setValue(row, update.get(row, 0));
        }
        layer.acceptNewBiasUpdate(bias.copyVector());
        return bias;
    }

    /** Fails fast with a descriptive error instead of an NPE when no network is bound. */
    private void requireNetworkBound() {
        if (this.hiddenLayer == null || this.outputLayer == null) {
            throw new IllegalStateException(
                    "setNeuralNetwork must be called before processing input or error");
        }
    }
}
