/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.cpn.training;

import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.matrices.Matrix;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.cpn.CPN;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;

public class TrainOutstar
extends BasicTraining
implements LearningRate {
    private double learningRate;
    private final CPN network;
    private final MLDataSet training;
    private boolean mustInit = true;

    public TrainOutstar(CPN theNetwork, MLDataSet theTraining, double theLearningRate) {
        super(TrainingImplementationType.Iterative);
        this.network = theNetwork;
        this.training = theTraining;
        this.learningRate = theLearningRate;
    }

    @Override
    public boolean canContinue() {
        return false;
    }

    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    private void initWeight() {
        for (int i = 0; i < this.network.getOutstarCount(); ++i) {
            int j = 0;
            for (MLDataPair pair : this.training) {
                this.network.getWeightsInstarToOutstar().set(j++, i, pair.getIdeal().getData(i));
            }
        }
        this.mustInit = false;
    }

    @Override
    public void iteration() {
        if (this.mustInit) {
            this.initWeight();
        }
        ErrorCalculation error = new ErrorCalculation();
        for (MLDataPair pair : this.training) {
            MLData out = this.network.computeInstar(pair.getInput());
            int j = EngineArray.indexOfLargest(out.getData());
            for (int i = 0; i < this.network.getOutstarCount(); ++i) {
                double delta = this.learningRate * (pair.getIdeal().getData(i) - this.network.getWeightsInstarToOutstar().get(j, i));
                this.network.getWeightsInstarToOutstar().add(j, i, delta);
            }
            MLData out2 = this.network.computeOutstar(out);
            error.updateError(out2.getData(), pair.getIdeal().getData(), pair.getSignificance());
        }
        this.setError(error.calculate());
    }

    @Override
    public TrainingContinuation pause() {
        return null;
    }

    @Override
    public void resume(TrainingContinuation state) {
    }

    @Override
    public void setLearningRate(double rate) {
        this.learningRate = rate;
    }
}

