/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation.back;

import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.Momentum;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.strategy.SmartLearningRate;
import org.encog.neural.networks.training.strategy.SmartMomentum;
import org.encog.util.validate.ValidateNetwork;

/**
 * Classic backpropagation trainer with a learning rate and momentum.
 *
 * <p>Each weight change is computed as
 * {@code delta = gradient * learningRate + previousDelta * momentum}, and the
 * resulting delta is remembered per weight so it can feed the momentum term of
 * the next iteration.
 *
 * <p>The per-weight delta history can be exported with {@link #pause()} and
 * restored with {@link #resume(TrainingContinuation)}. Note that
 * {@link #canContinue()} nevertheless reports {@code false}.
 */
public class Backpropagation
extends Propagation
implements Momentum,
LearningRate {
    /** Key under which the per-weight delta history is stored in a {@link TrainingContinuation}. */
    public static final String LAST_DELTA = "LAST_DELTA";

    /** Multiplier applied to each gradient when computing a weight change. */
    private double learningRate;

    /** Multiplier applied to the previous delta of the same weight. */
    private double momentum;

    /** Most recent delta applied to each weight; one entry per network weight. */
    private double[] lastDelta;

    /**
     * Creates a trainer whose learning rate and momentum both start at {@code 0.0}.
     *
     * <p>It also registers {@link SmartLearningRate} and {@link SmartMomentum}
     * strategies. Presumably these choose and adjust the rates during training
     * (TODO confirm against those classes).
     *
     * @param network  the network to train
     * @param training the training data
     */
    public Backpropagation(ContainsFlat network, MLDataSet training) {
        this(network, training, 0.0, 0.0);
        this.addStrategy(new SmartLearningRate());
        this.addStrategy(new SmartMomentum());
    }

    /**
     * Creates a trainer with an explicit learning rate and momentum.
     *
     * @param network      the network to train
     * @param training     the training data; validated against {@code network}
     * @param theLearnRate the learning rate
     * @param theMomentum  the momentum
     */
    public Backpropagation(ContainsFlat network, MLDataSet training, double theLearnRate, double theMomentum) {
        super(network, training);
        ValidateNetwork.validateMethodToData(network, training);
        this.momentum = theMomentum;
        this.learningRate = theLearnRate;
        // One delta slot per weight of the flat network, initially all zero.
        this.lastDelta = new double[network.getFlat().getWeights().length];
    }

    /**
     * Reports whether this trainer supports training continuation.
     *
     * @return always {@code false}
     */
    @Override
    public final boolean canContinue() {
        return false;
    }

    /**
     * Returns the per-weight delta history.
     *
     * <p>This is the live internal array, not a copy. Changes to it affect
     * subsequent momentum calculations.
     *
     * @return the most recent delta for each weight
     */
    public final double[] getLastDelta() {
        return this.lastDelta;
    }

    /** {@inheritDoc} */
    @Override
    public final double getLearningRate() {
        return this.learningRate;
    }

    /** {@inheritDoc} */
    @Override
    public final double getMomentum() {
        return this.momentum;
    }

    /**
     * Checks whether {@code state} can be used to resume this trainer.
     *
     * <p>The state is valid when all of the following hold:
     * <ul>
     *   <li>it is non-null;</li>
     *   <li>it was produced by a trainer of the same simple class name;</li>
     *   <li>it contains a {@code double[]} under {@link #LAST_DELTA};</li>
     *   <li>that array has exactly one entry per network weight.</li>
     * </ul>
     *
     * @param state the continuation to check; may be {@code null}
     * @return {@code true} if {@link #resume(TrainingContinuation)} will accept it
     */
    public final boolean isValidResume(TrainingContinuation state) {
        if (state == null || !state.getContents().containsKey(LAST_DELTA)) {
            return false;
        }
        // Compare in this order so a missing training type yields false instead of an NPE.
        if (!this.getClass().getSimpleName().equals(state.getTrainingType())) {
            return false;
        }
        final Object stored = state.get(LAST_DELTA);
        // Also rejects a null entry (instanceof is false for null).
        if (!(stored instanceof double[])) {
            return false;
        }
        final int weightCount = ((ContainsFlat)this.getMethod()).getFlat().getWeights().length;
        return ((double[])stored).length == weightCount;
    }

    /**
     * Captures the per-weight delta history so training can later be resumed.
     *
     * <p>The snapshot is a copy, so further training does not alter it. The
     * learning rate and momentum are not included in the snapshot.
     *
     * @return a continuation tagged with this class's simple name
     */
    @Override
    public final TrainingContinuation pause() {
        final TrainingContinuation result = new TrainingContinuation();
        result.setTrainingType(this.getClass().getSimpleName());
        result.set(LAST_DELTA, this.lastDelta.clone());
        return result;
    }

    /**
     * Restores the per-weight delta history from a previous {@link #pause()}.
     *
     * <p>The stored array is copied, so the caller's continuation is not
     * coupled to this trainer's state afterwards.
     *
     * @param state the continuation to restore from
     * @throws TrainingError if {@link #isValidResume(TrainingContinuation)} rejects {@code state}
     */
    @Override
    public final void resume(TrainingContinuation state) {
        if (!this.isValidResume(state)) {
            throw new TrainingError("Invalid training resume data length");
        }
        this.lastDelta = ((double[])state.get(LAST_DELTA)).clone();
    }

    /** {@inheritDoc} */
    @Override
    public final void setLearningRate(double rate) {
        this.learningRate = rate;
    }

    /** {@inheritDoc} */
    @Override
    public final void setMomentum(double m) {
        this.momentum = m;
    }

    /**
     * Computes the change for one weight and records it for the next momentum term.
     *
     * @param gradients    gradients for all weights
     * @param lastGradient gradients from the previous iteration (unused by this algorithm)
     * @param index        index of the weight being updated
     * @return the delta to apply to weight {@code index}
     */
    @Override
    public final double updateWeight(double[] gradients, double[] lastGradient, int index) {
        final double delta = gradients[index] * this.learningRate + this.lastDelta[index] * this.momentum;
        this.lastDelta[index] = delta;
        return delta;
    }

    /** No additional per-training state is needed for plain backpropagation. */
    @Override
    public void initOthers() {
    }
}

