/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ensemble;

import java.util.ArrayList;
import org.encog.ensemble.EnsembleAggregator;
import org.encog.ensemble.EnsembleML;
import org.encog.ensemble.EnsembleMLMethodFactory;
import org.encog.ensemble.EnsembleTrainFactory;
import org.encog.ensemble.EnsembleTypes;
import org.encog.ensemble.GenericEnsembleML;
import org.encog.ensemble.aggregator.WeightedAveraging;
import org.encog.ensemble.data.EnsembleDataSet;
import org.encog.ensemble.data.factories.EnsembleDataSetFactory;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;

public abstract class Ensemble {
    /** Default per-attempt iteration cap used when a caller does not pass {@code maxIterations}. */
    private static final int DEFAULT_MAX_ITERATIONS = 2000;
    /** Default cap on training attempts per member used when a caller does not pass {@code maxLoops}. */
    private static final int DEFAULT_MAX_LOOPS = 2000;

    protected EnsembleDataSetFactory dataSetFactory;
    protected EnsembleTrainFactory trainFactory;
    protected EnsembleAggregator aggregator;
    protected ArrayList<EnsembleML> members;
    protected EnsembleMLMethodFactory mlFactory;
    /** Data set used to train the aggregator; only populated when the aggregator needs training. */
    protected MLDataSet aggregatorDataSet;

    /**
     * (Re)builds the ensemble members. Invoked by {@link #setTrainingMethod},
     * {@link #setTrainingData} and {@link #setTrainingDataFactory} after the corresponding
     * collaborator has been replaced.
     */
    public abstract void initMembers();

    /**
     * Creates a new, untrained member without adding it to the ensemble.
     *
     * @return a member built by the ML factory, with its own data set from the data set factory
     *     and a trainer from the train factory
     */
    public EnsembleML generateNewMember() {
        return this.createMember();
    }

    /** Creates a new member via {@link #generateNewMember()} and appends it to the ensemble. */
    public void addNewMember() {
        this.members.add(this.generateNewMember());
    }

    /**
     * Appends {@code splits} freshly created members, each with its own data set drawn from the
     * data set factory. If the aggregator needs training, one further data set is drawn from the
     * factory and kept as {@link #aggregatorDataSet}.
     *
     * <p>Does nothing when there is no data set factory, when {@code splits <= 0}, or when the
     * factory reports no source data.
     *
     * @param splits number of members to create
     */
    public void initMembersBySplits(int splits) {
        if (this.dataSetFactory == null || splits <= 0 || !this.dataSetFactory.hasSource()) {
            return;
        }
        for (int i = 0; i < splits; ++i) {
            this.members.add(this.createMember());
        }
        if (this.aggregator.needsTraining()) {
            this.aggregatorDataSet = this.dataSetFactory.getNewDataSet();
        }
    }

    /**
     * Shared construction logic for a single member. Kept private so that the split initializer
     * does not depend on the overridable {@link #generateNewMember()}.
     */
    private GenericEnsembleML createMember() {
        GenericEnsembleML newML = new GenericEnsembleML(
                this.mlFactory.createML(this.dataSetFactory.getInputCount(), this.dataSetFactory.getOutputCount()),
                this.mlFactory.getLabel());
        newML.setTrainingSet(this.dataSetFactory.getNewDataSet());
        newML.setTraining(this.trainFactory.getTraining(newML.getMl(), newML.getTrainingSet()));
        return newML;
    }

    /**
     * Replaces the training factory and rebuilds the members.
     *
     * @param newTrainFactory factory used to create each member's trainer
     */
    public void setTrainingMethod(EnsembleTrainFactory newTrainFactory) {
        this.trainFactory = newTrainFactory;
        this.initMembers();
    }

    /**
     * Feeds new input data to the current data set factory and rebuilds the members.
     * Requires a data set factory to have been set.
     *
     * @param data the source data for the members' training sets
     */
    public void setTrainingData(MLDataSet data) {
        this.dataSetFactory.setInputData(data);
        this.initMembers();
    }

    /**
     * Replaces the data set factory and rebuilds the members.
     *
     * @param dataSetFactory factory that supplies per-member training sets
     */
    public void setTrainingDataFactory(EnsembleDataSetFactory dataSetFactory) {
        this.dataSetFactory = dataSetFactory;
        this.initMembers();
    }

    /**
     * Trains the member at {@code index} using the default attempt cap.
     *
     * @see #trainMember(EnsembleML, double, double, int, int, EnsembleDataSet, boolean)
     */
    public void trainMember(int index, double targetError, double selectionError, int maxIterations,
            EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
        EnsembleML current = this.members.get(index);
        this.trainMember(current, targetError, selectionError, maxIterations, DEFAULT_MAX_LOOPS, selectionSet, verbose);
    }

    /**
     * Repeatedly re-initialises and trains {@code current} until its error on
     * {@code selectionSet} is at most {@code selectionError}.
     *
     * <p>At least one attempt is always made, and at most {@code maxLoops} attempts are made when
     * {@code maxLoops >= 1}. A NaN selection error counts as a failed attempt.
     *
     * @param current the member to train
     * @param targetError training error target passed to the member on each attempt
     * @param selectionError maximum acceptable error on {@code selectionSet}
     * @param maxIterations iteration cap passed to the member on each attempt
     * @param maxLoops maximum number of training attempts
     * @param selectionSet data used to judge each attempt
     * @param verbose if true, prints timing and selection error after each attempt
     * @throws TrainingAborted if no attempt reaches {@code selectionError} within {@code maxLoops}
     */
    public void trainMember(EnsembleML current, double targetError, double selectionError, int maxIterations,
            int maxLoops, EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
        for (int attempt = 1; ; ++attempt) {
            long startTime = System.nanoTime();
            this.mlFactory.reInit(current.getMl());
            current.train(targetError, maxIterations, verbose);
            long endTime = System.nanoTime();
            // Evaluate once per attempt; the same value drives logging and the stop decision.
            double error = current.getError(selectionSet);
            if (verbose) {
                System.out.println("training took " + (double) (endTime - startTime) / 1.0E9);
                System.out.println("test MSE: " + error + " on " + selectionSet.size() + " data points");
            }
            if (error <= selectionError) {
                return;
            }
            // Abort only after a failing attempt, so a successful final attempt is never discarded.
            if (attempt >= maxLoops) {
                throw new TrainingAborted("Too many attempts at training ensemble member");
            }
        }
    }

    /**
     * Trains {@code current} with the default iteration and attempt caps.
     *
     * @see #trainMember(EnsembleML, double, double, int, int, EnsembleDataSet, boolean)
     */
    public void trainMember(EnsembleML current, double targetError, double selectionError,
            EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
        this.trainMember(current, targetError, selectionError, DEFAULT_MAX_ITERATIONS, DEFAULT_MAX_LOOPS,
                selectionSet, verbose);
    }

    /**
     * Trains the member at {@code index} with the default iteration and attempt caps.
     *
     * @see #trainMember(EnsembleML, double, double, int, int, EnsembleDataSet, boolean)
     */
    public void trainMember(int index, double targetError, double selectionError,
            EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
        this.trainMember(index, targetError, selectionError, DEFAULT_MAX_ITERATIONS, selectionSet, verbose);
    }

    /**
     * Builds a training set for the aggregator and trains it.
     *
     * <p>For each pair in {@link #aggregatorDataSet}, the member outputs are concatenated in member
     * order into one input vector and paired with the original ideal vector. The loop assumes each
     * member's output has exactly {@code aggregatorDataSet.getIdealSize()} values.
     *
     * @throws IllegalStateException if no aggregator data set has been prepared
     */
    public void retrainAggregator() {
        if (this.aggregatorDataSet == null) {
            throw new IllegalStateException("aggregatorDataSet has not been initialised; cannot train the aggregator");
        }
        int idealSize = this.aggregatorDataSet.getIdealSize();
        int stackedSize = this.members.size() * idealSize;
        EnsembleDataSet aggTrainingSet = new EnsembleDataSet(stackedSize, idealSize);
        for (MLDataPair trainingInput : this.aggregatorDataSet) {
            BasicMLData trainingInstance = new BasicMLData(stackedSize);
            int index = 0;
            for (EnsembleML member : this.members) {
                for (double val : member.compute(trainingInput.getInput()).getData()) {
                    trainingInstance.add(index++, val);
                }
            }
            aggTrainingSet.add(trainingInstance, trainingInput.getIdeal());
        }
        this.aggregator.setTrainingSet(aggTrainingSet);
        this.aggregator.train();
    }

    /**
     * Trains every member in order. If the aggregator needs training, it is retrained
     * afterwards.
     *
     * @throws TrainingAborted if any member fails to reach {@code selectionError}
     * @see #trainMember(EnsembleML, double, double, int, int, EnsembleDataSet, boolean)
     */
    public void train(double targetError, double selectionError, int maxIterations, int maxLoops,
            EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
        for (EnsembleML current : this.members) {
            this.trainMember(current, targetError, selectionError, maxIterations, maxLoops, selectionSet, verbose);
        }
        if (this.aggregator.needsTraining()) {
            this.retrainAggregator();
        }
    }

    /** Trains the ensemble with the default iteration and attempt caps. */
    public void train(double targetError, double selectionError, EnsembleDataSet selectionSet, boolean verbose)
            throws TrainingAborted {
        this.train(targetError, selectionError, DEFAULT_MAX_ITERATIONS, DEFAULT_MAX_LOOPS, selectionSet, verbose);
    }

    /** Trains the ensemble quietly with the default iteration and attempt caps. */
    public void train(double targetError, double selectionError, EnsembleDataSet testset) throws TrainingAborted {
        this.train(targetError, selectionError, testset, false);
    }

    /** Trains the ensemble quietly with the given iteration cap and the default attempt cap. */
    public void train(double targetError, double selectionError, int maxIterations, EnsembleDataSet testset)
            throws TrainingAborted {
        this.train(targetError, selectionError, maxIterations, DEFAULT_MAX_LOOPS, testset, false);
    }

    /**
     * Returns the training set of the member at {@code setNumber}.
     *
     * @param setNumber member index
     * @return that member's training set
     */
    public MLDataSet getTrainingSet(int setNumber) {
        return this.members.get(setNumber).getTrainingSet();
    }

    /**
     * Returns the member at {@code memberNumber}.
     *
     * @param memberNumber member index
     * @return the member
     */
    public EnsembleML getMember(int memberNumber) {
        return this.members.get(memberNumber);
    }

    /**
     * Appends an externally built member.
     *
     * @param newMember the member to add
     * @throws NotPossibleInThisMethod declared so that subclasses can refuse; this base
     *     implementation never throws it
     */
    public void addMember(EnsembleML newMember) throws NotPossibleInThisMethod {
        this.members.add(newMember);
    }

    /**
     * Runs {@code input} through every member and combines the outputs with the aggregator.
     *
     * @param input the input vector
     * @return the aggregator's combined output
     * @throws WeightedAveraging.WeightMismatchException propagated from the aggregator
     */
    public MLData compute(MLData input) throws WeightedAveraging.WeightMismatchException {
        ArrayList<MLData> outputs = new ArrayList<MLData>(this.members.size());
        for (EnsembleML member : this.members) {
            outputs.add(member.compute(input));
        }
        return this.aggregator.evaluate(outputs);
    }

    /** @return the aggregator that combines member outputs */
    public EnsembleAggregator getAggregator() {
        return this.aggregator;
    }

    /** @param aggregator the aggregator that combines member outputs */
    public void setAggregator(EnsembleAggregator aggregator) {
        this.aggregator = aggregator;
    }

    /** @return the kind of problem this ensemble is built for */
    public abstract EnsembleTypes.ProblemType getProblemType();

    /**
     * Thrown when a member cannot reach the requested selection error within the allowed number
     * of training attempts. Static so that instances do not capture the enclosing ensemble.
     */
    public static class TrainingAborted
    extends Exception {
        private static final long serialVersionUID = -5074472788684621859L;

        public TrainingAborted(String string) {
            super(string);
        }
    }

    /**
     * Signals that an operation is not supported by a particular ensemble type. Static so that
     * instances do not capture the enclosing ensemble.
     */
    public static class NotPossibleInThisMethod
    extends Exception {
        private static final long serialVersionUID = 5118253806179408868L;
    }
}

