/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.strategy;

import org.encog.neural.networks.training.Momentum;
import org.encog.neural.networks.training.Strategy;
import org.encog.neural.networks.training.Train;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Training strategy that adjusts the momentum of a trainer based on how the
 * training error changes from one iteration to the next.
 *
 * <p>Before each iteration the current error is recorded; after the
 * iteration the relative change is computed. While the error keeps growing,
 * or changes by less than {@link #MIN_IMPROVEMENT}, a stall counter is
 * incremented, and every time it exceeds {@link #MOMENTUM_CYCLES} the
 * momentum is raised. Any other iteration resets the momentum to zero.
 *
 * <p>The trainer passed to {@link #init(Train)} must also implement
 * {@link Momentum}; otherwise {@code init} fails with a
 * {@link ClassCastException}.
 */
public class SmartMomentum
implements Strategy {
    /**
     * Magnitude of relative error change below which an iteration is
     * counted as stalled.
     */
    public static final double MIN_IMPROVEMENT = 1.0E-4;

    /** Upper bound applied to the momentum when it is raised. */
    public static final double MAX_MOMENTUM = 4.0;

    /** Momentum used as the base when raising from zero. */
    public static final double START_MOMENTUM = 0.1;

    /** Fractional increase applied to the momentum on each adjustment. */
    public static final double MOMENTUM_INCREASE = 0.01;

    /**
     * Number of stalled iterations that must be exceeded before the
     * momentum is adjusted.
     */
    public static final double MOMENTUM_CYCLES = 10.0;

    /** Trainer whose error is observed. */
    private Train train;

    /** The same trainer, viewed through its momentum setter. */
    private Momentum setter;

    /** Relative error change measured by the most recent post-iteration. */
    private double lastImprovement;

    /** Error captured just before the current iteration ran. */
    private double lastError;

    /** False until the first post-iteration call after {@code init}. */
    private boolean ready;

    /** Stalled iterations counted since the last momentum adjustment. */
    private int lastMomentum;

    /** Momentum most recently handed to the trainer. */
    private double currentMomentum;

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Attaches this strategy to a trainer and resets all adaptive state,
     * including the trainer's momentum, to zero.
     *
     * @param train the trainer to manage; must also implement
     *              {@link Momentum}
     * @throws ClassCastException if {@code train} does not implement
     *                            {@link Momentum}
     */
    @Override
    public void init(Train train) {
        this.train = train;
        // Cast through Object, as in the decompiled form, so the cast is
        // accepted no matter how Train and Momentum are declared.
        this.setter = (Momentum)((Object)train);
        this.ready = false;
        // Reset the stall counter too, so re-initialising starts clean.
        this.lastMomentum = 0;
        this.currentMomentum = 0.0;
        this.setter.setMomentum(0.0);
    }

    /**
     * Evaluates the iteration that just finished and adapts the momentum.
     * The first call after {@link #init(Train)} only marks the strategy as
     * ready and makes no adjustment.
     */
    @Override
    public void postIteration() {
        if (!this.ready) {
            this.ready = true;
            return;
        }

        double currentError = this.train.getError();
        // Relative change; for a positive lastError a positive result means
        // the error grew. A zero lastError yields NaN or an infinity: NaN
        // fails both comparisons below and so resets the momentum.
        this.lastImprovement = (currentError - this.lastError) / this.lastError;
        if (this.logger.isTraceEnabled()) {
            this.logger.trace("Last improvement: {}", (Object)this.lastImprovement);
        }

        boolean stalled = this.lastImprovement > 0.0
                || Math.abs(this.lastImprovement) < MIN_IMPROVEMENT;
        if (stalled) {
            ++this.lastMomentum;
            if ((double)this.lastMomentum > MOMENTUM_CYCLES) {
                this.lastMomentum = 0;
                this.raiseMomentum();
            }
        } else {
            this.resetMomentum();
        }
    }

    /**
     * Records the error before the iteration runs so that
     * {@link #postIteration()} can measure the change.
     */
    @Override
    public void preIteration() {
        this.lastError = this.train.getError();
    }

    /** Raises the momentum by one step, capped at {@link #MAX_MOMENTUM}. */
    private void raiseMomentum() {
        // Compare the double itself: an (int) cast truncates every value
        // below 1.0 to zero, which restarted the momentum on every
        // adjustment and kept it from ever growing.
        if (this.currentMomentum == 0.0) {
            this.currentMomentum = START_MOMENTUM;
        }
        this.currentMomentum = Math.min(
                this.currentMomentum * (1.0 + MOMENTUM_INCREASE), MAX_MOMENTUM);
        this.setter.setMomentum(this.currentMomentum);
        // Guard matches the level of the call it protects.
        if (this.logger.isTraceEnabled()) {
            this.logger.trace("Adjusting momentum: {}", (Object)this.currentMomentum);
        }
    }

    /** Drops the momentum back to zero. */
    private void resetMomentum() {
        if (this.logger.isTraceEnabled()) {
            this.logger.trace("Setting momentum back to zero.");
        }
        this.currentMomentum = 0.0;
        this.setter.setMomentum(0.0);
    }
}

