/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation.resilient;

import org.encog.ml.data.MLDataSet;
import org.encog.neural.flat.train.prop.RPROPType;
import org.encog.neural.flat.train.prop.TrainFlatNetworkResilient;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;

/**
 * Trains a network that exposes a flat representation ({@link ContainsFlat}) using
 * resilient propagation. The actual weight updates are delegated to a
 * {@link TrainFlatNetworkResilient} instance installed as the flat trainer.
 *
 * <p>Training can be paused and resumed: {@link #pause()} captures the last-gradient and
 * update-value arrays of the flat trainer, and {@link #resume(TrainingContinuation)}
 * copies them back after validating them with {@link #isValidResume(TrainingContinuation)}.
 */
public class ResilientPropagation
extends Propagation {
    /** Continuation key under which the last gradient array is stored. */
    public static final String LAST_GRADIENTS = "LAST_GRADIENTS";
    /** Continuation key under which the per-weight update values are stored. */
    public static final String UPDATE_VALUES = "UPDATE_VALUES";

    /** Initial update value used by the two-argument constructor. */
    private static final double DEFAULT_INITIAL_UPDATE = 0.1;
    /** Maximum step used by the two-argument constructor. */
    private static final double DEFAULT_MAX_STEP = 50.0;
    /**
     * Third constructor argument of {@link TrainFlatNetworkResilient}.
     * NOTE(review): presumably the zero tolerance for gradients — confirm against that class.
     */
    private static final double DEFAULT_ZERO_TOLERANCE = 1.0E-17;

    /**
     * Creates a trainer with an initial update of 0.1 and a maximum step of 50.0.
     *
     * @param network  the network to train
     * @param training the training data
     */
    public ResilientPropagation(ContainsFlat network, MLDataSet training) {
        this(network, training, DEFAULT_INITIAL_UPDATE, DEFAULT_MAX_STEP);
    }

    /**
     * Creates a trainer with explicit step parameters.
     *
     * @param network       the network to train
     * @param training      the training data
     * @param initialUpdate initial update value passed to the flat trainer
     * @param maxStep       maximum step passed to the flat trainer
     */
    public ResilientPropagation(ContainsFlat network, MLDataSet training, double initialUpdate, double maxStep) {
        super(network, training);
        TrainFlatNetworkResilient rpropFlat = new TrainFlatNetworkResilient(
                network.getFlat(), this.getTraining(), DEFAULT_ZERO_TOLERANCE, initialUpdate, maxStep);
        this.setFlatTraining(rpropFlat);
    }

    /**
     * Reports that this trainer supports pause/resume.
     *
     * @return always {@code true}
     */
    @Override
    public final boolean canContinue() {
        return true;
    }

    /**
     * Checks whether a continuation can be resumed by this trainer.
     *
     * <p>The state is valid only if all of the following hold:
     * <ul>
     *   <li>it is non-null;</li>
     *   <li>it contains both {@link #LAST_GRADIENTS} and {@link #UPDATE_VALUES};</li>
     *   <li>it was produced by this class (same simple class name);</li>
     *   <li>both stored values are {@code double[]} whose length matches the network's
     *       weight count.</li>
     * </ul>
     *
     * @param state the continuation to check; may be {@code null}
     * @return {@code true} if {@link #resume(TrainingContinuation)} can safely consume it
     */
    public final boolean isValidResume(TrainingContinuation state) {
        if (state == null) {
            return false;
        }
        if (!state.getContents().containsKey(LAST_GRADIENTS) || !state.getContents().containsKey(UPDATE_VALUES)) {
            return false;
        }
        // Compare from the known non-null side so a missing training type yields false, not an NPE.
        if (!this.getClass().getSimpleName().equals(state.getTrainingType())) {
            return false;
        }
        Object lastGradient = state.get(LAST_GRADIENTS);
        Object updateValues = state.get(UPDATE_VALUES);
        if (!(lastGradient instanceof double[]) || !(updateValues instanceof double[])) {
            return false;
        }
        // Both arrays are copied element-for-element into the flat trainer on resume,
        // so both must match the weight count, not just the gradients.
        int weightCount = ((ContainsFlat)this.getMethod()).getFlat().getWeights().length;
        return ((double[])lastGradient).length == weightCount
                && ((double[])updateValues).length == weightCount;
    }

    /**
     * Captures the current RPROP state so training can be resumed later.
     *
     * <p>The arrays are copied so that further training iterations on this object do not
     * mutate the returned snapshot.
     *
     * @return a continuation holding copies of the last gradients and update values
     */
    @Override
    public final TrainingContinuation pause() {
        TrainFlatNetworkResilient rpropFlat = this.getRpropFlat();
        TrainingContinuation result = new TrainingContinuation();
        result.setTrainingType(this.getClass().getSimpleName());
        result.set(LAST_GRADIENTS, copyOf(rpropFlat.getLastGradient()));
        result.set(UPDATE_VALUES, copyOf(rpropFlat.getUpdateValues()));
        return result;
    }

    /**
     * Restores RPROP state previously captured by {@link #pause()}.
     *
     * @param state the continuation to resume from
     * @throws TrainingError if {@link #isValidResume(TrainingContinuation)} rejects the state
     */
    @Override
    public final void resume(TrainingContinuation state) {
        if (!this.isValidResume(state)) {
            throw new TrainingError("Invalid training resume data length");
        }
        TrainFlatNetworkResilient rpropFlat = this.getRpropFlat();
        double[] lastGradient = (double[])state.get(LAST_GRADIENTS);
        double[] updateValues = (double[])state.get(UPDATE_VALUES);
        EngineArray.arrayCopy(lastGradient, rpropFlat.getLastGradient());
        EngineArray.arrayCopy(updateValues, rpropFlat.getUpdateValues());
    }

    /**
     * Sets the RPROP variant used by the underlying flat trainer.
     *
     * @param t the RPROP type
     */
    public void setRPROPType(RPROPType t) {
        this.getRpropFlat().setRpropType(t);
    }

    /**
     * Returns the RPROP variant used by the underlying flat trainer.
     *
     * @return the RPROP type
     */
    public RPROPType getRPROPType() {
        return this.getRpropFlat().getRpropType();
    }

    /** Returns the flat trainer installed by the constructor, typed as the RPROP trainer. */
    private TrainFlatNetworkResilient getRpropFlat() {
        return (TrainFlatNetworkResilient)this.getFlatTraining();
    }

    /** Returns a shallow copy of {@code source}, or {@code null} if it is {@code null}. */
    private static double[] copyOf(double[] source) {
        return source == null ? null : source.clone();
    }
}

