/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.cross;

import org.encog.engine.network.flat.FlatNetwork;
import org.encog.neural.data.folded.FoldedDataSet;
import org.encog.neural.networks.training.Train;
import org.encog.neural.networks.training.cross.CrossTraining;
import org.encog.neural.networks.training.cross.NetworkFold;

/**
 * Cross-validation trainer that drives a wrapped {@link Train} over a
 * {@link FoldedDataSet} split into folds.
 *
 * <p>Each call to {@link #iteration()} visits every fold once as the
 * validation fold: the per-fold {@link NetworkFold} is copied into the shared
 * flat network, the wrapped trainer runs one iteration on each of the other
 * folds, the error is calculated against the validation fold, and the flat
 * network is copied back into that fold's {@link NetworkFold}. The reported
 * error is the mean of the per-fold validation errors.
 */
public class CrossValidationKFold extends CrossTraining {
    /** Wrapped trainer whose {@code iteration()} is run on each non-validation fold. */
    private final Train train;
    /** One {@link NetworkFold} per requested fold, indexed by validation fold. */
    private final NetworkFold[] networks;
    /** Flat network taken from the wrapped trainer's network structure. */
    private final FlatNetwork flatNetwork;

    /**
     * Creates the trainer, folds the training data into {@code k} folds and
     * allocates one {@link NetworkFold} per fold.
     *
     * @param train the trainer to wrap; its training set is cast to
     *              {@link FoldedDataSet}, so any other type fails with a
     *              {@link ClassCastException}
     * @param k     the number of folds requested, also used as the size of the
     *              per-fold network array
     */
    public CrossValidationKFold(Train train, int k) {
        super(train.getNetwork(), (FoldedDataSet)train.getTraining());
        this.train = train;
        this.getFolded().fold(k);

        FlatNetwork flat = train.getNetwork().getStructure().getFlat();
        this.flatNetwork = flat;

        // Build the array locally, then publish it to the final field.
        NetworkFold[] perFold = new NetworkFold[k];
        int slot = 0;
        while (slot < perFold.length) {
            perFold[slot] = new NetworkFold(flat);
            slot++;
        }
        this.networks = perFold;
    }

    /**
     * Runs one full pass in which every fold serves once as the validation
     * fold, then stores the mean validation error via {@code setError}.
     */
    @Override
    public void iteration() {
        double errorSum = 0.0;
        int heldOut = 0;
        while (heldOut < this.getFolded().getNumFolds()) {
            errorSum += this.runValidationPass(heldOut);
            heldOut++;
        }
        this.setError(errorSum / (double)this.getFolded().getNumFolds());
    }

    /**
     * Performs the work for a single validation fold and returns its error.
     *
     * @param heldOut index of the fold used for validation
     * @return the error calculated by the flat network on the validation fold
     */
    private double runValidationPass(int heldOut) {
        // Load this fold's state into the shared flat network.
        this.networks[heldOut].copyToNetwork(this.flatNetwork);

        this.trainOnAllFoldsExcept(heldOut);

        // Select the validation fold and measure the error on it.
        this.getFolded().setCurrentFold(heldOut);
        double foldError = this.flatNetwork.calculateError(this.getFolded());

        // Store the flat network back into this fold's state.
        this.networks[heldOut].copyFromNetwork(this.flatNetwork);
        return foldError;
    }

    /**
     * Runs one iteration of the wrapped trainer on every fold other than the
     * given one, selecting each fold as current before training on it.
     *
     * @param heldOut index of the fold to leave out of training
     */
    private void trainOnAllFoldsExcept(int heldOut) {
        int fold = 0;
        while (fold < this.getFolded().getNumFolds()) {
            if (fold != heldOut) {
                this.getFolded().setCurrentFold(fold);
                this.train.iteration();
            }
            fold++;
        }
    }
}

