/*
 * Decompiled with CFR 0.152.
 */
package org.encog.app.analyst.commands;

import java.io.File;
import org.encog.app.analyst.AnalystError;
import org.encog.app.analyst.EncogAnalyst;
import org.encog.app.analyst.commands.Cmd;
import org.encog.ml.MLMethod;
import org.encog.ml.MLResettable;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.folded.FoldedDataSet;
import org.encog.ml.factory.MLTrainFactory;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.training.cross.CrossValidationKFold;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.util.logging.EncogLogging;
import org.encog.util.simple.EncogUtility;
import org.encog.util.validate.ValidateNetwork;

/**
 * Analyst command that trains the machine-learning method configured in the
 * analyst script and saves the result.
 * <p>
 * The command:
 * <ul>
 * <li>loads the training set named by {@code ML:CONFIG_trainingFile};</li>
 * <li>loads the method named by {@code ML:CONFIG_machineLearningFile};</li>
 * <li>builds a trainer from {@code ML:TRAIN_type} and
 * {@code ML:TRAIN_arguments}, optionally wrapped in k-fold cross validation
 * ({@code ML:TRAIN_cross});</li>
 * <li>trains until a stop condition is met;</li>
 * <li>writes the trained method back to the machine-learning file.</li>
 * </ul>
 */
public class CmdTrain
extends Cmd {
    /** Name of this command, as returned by {@link #getName()}. */
    public static final String COMMAND_NAME = "TRAIN";

    /** Prefix of the {@code ML:TRAIN_cross} value that selects k-fold cross validation. */
    private static final String KFOLD_PREFIX = "kfold:";

    /**
     * Number of cross-validation folds, or 0 when cross validation is off.
     * Assigned at the start of {@link #executeCommand(String)}.
     */
    private int kfold;

    /**
     * Construct the train command.
     *
     * @param analyst the analyst this command belongs to
     */
    public CmdTrain(EncogAnalyst analyst) {
        super(analyst);
    }

    /**
     * Build the trainer described by {@code ML:TRAIN_type} and
     * {@code ML:TRAIN_arguments}. When {@link #kfold} is positive, the trainer
     * is wrapped in {@link CrossValidationKFold}.
     *
     * @param method      the method to train
     * @param trainingSet the data to train on (a {@link FoldedDataSet} when
     *                    cross validation is enabled)
     * @return the trainer to run
     */
    private MLTrain createTrainer(MLMethod method, MLDataSet trainingSet) {
        MLTrainFactory factory = new MLTrainFactory();
        String type = this.getProp().getPropertyString("ML:TRAIN_type");
        String args = this.getProp().getPropertyString("ML:TRAIN_arguments");
        EncogLogging.log(EncogLogging.LEVEL_DEBUG, "training type:" + type);
        EncogLogging.log(EncogLogging.LEVEL_DEBUG, "training args:" + args);
        // NOTE(review): resettable methods are handed to the analyst before the
        // trainer is built, presumably so the analyst can reset them during
        // training -- confirm against EncogAnalyst.
        if (method instanceof MLResettable) {
            this.getAnalyst().setMethod(method);
        }
        MLTrain train = factory.create(method, trainingSet, type, args);
        if (this.kfold > 0) {
            train = new CrossValidationKFold(train, this.kfold);
        }
        return train;
    }

    /**
     * Load the data and method, train, and save the trained method.
     * <p>
     * The training set is always closed, even when loading the method,
     * training or saving fails.
     *
     * @param args command arguments (not used by this command)
     * @return the analyst's {@code shouldStopCommand()} flag after training
     */
    @Override
    public final boolean executeCommand(String args) {
        // kfold must be known before loading the data, because
        // obtainTrainingSet() wraps the set in a FoldedDataSet when it is positive.
        this.kfold = this.obtainCross();
        MLDataSet trainingSet = this.obtainTrainingSet();
        try {
            MLMethod method = this.obtainMethod();
            MLTrain trainer = this.createTrainer(method, trainingSet);
            if (method instanceof BayesianNetwork) {
                String query = this.getProp().getPropertyString("ML:CONFIG_query");
                ((BayesianNetwork) method).defineClassificationStructure(query);
            }
            EncogLogging.log(EncogLogging.LEVEL_DEBUG, "Beginning training");
            this.performTraining(trainer, method, trainingSet);
            String resourceID = this.getProp().getPropertyString("ML:CONFIG_machineLearningFile");
            File resourceFile = this.getAnalyst().getScript().resolveFilename(resourceID);
            // Persist the method as reported by the trainer, which may differ
            // from the instance that was passed in.
            EncogDirectoryPersistence.saveObject(resourceFile, trainer.getMethod());
            EncogLogging.log(EncogLogging.LEVEL_DEBUG, "save to:" + resourceID);
        } finally {
            // Release the training data on every path, not only on success.
            trainingSet.close();
        }
        return this.getAnalyst().shouldStopCommand();
    }

    /**
     * @return {@link #COMMAND_NAME}
     */
    @Override
    public final String getName() {
        return COMMAND_NAME;
    }

    /**
     * Parse the {@code ML:TRAIN_cross} property.
     * <p>
     * A missing or blank value disables cross validation. The only other
     * accepted form is {@code kfold:<n>}: the prefix is case-insensitive,
     * surrounding whitespace is allowed, and {@code n} must be a non-negative
     * integer.
     *
     * @return the number of folds, or 0 for no cross validation
     * @throws AnalystError if the value is malformed or the fold count is invalid
     */
    private int obtainCross() {
        String cross = this.getProp().getPropertyString("ML:TRAIN_cross");
        if (cross == null || cross.trim().length() == 0) {
            return 0;
        }
        cross = cross.trim();
        // regionMatches(ignoreCase=true) is locale-independent, unlike toLowerCase().
        if (cross.regionMatches(true, 0, KFOLD_PREFIX, 0, KFOLD_PREFIX.length())) {
            String str = cross.substring(KFOLD_PREFIX.length()).trim();
            int folds;
            try {
                folds = Integer.parseInt(str);
            }
            catch (NumberFormatException ex) {
                throw new AnalystError("Invalid kfold :" + str);
            }
            // A negative count would otherwise silently disable cross validation.
            if (folds < 0) {
                throw new AnalystError("Invalid kfold :" + str);
            }
            return folds;
        }
        throw new AnalystError("Unknown cross validation: " + cross);
    }

    /**
     * Load the method stored in {@code ML:CONFIG_machineLearningFile}.
     *
     * @return the loaded method
     * @throws AnalystError if the file does not hold an {@link MLMethod}
     */
    private MLMethod obtainMethod() {
        String resourceID = this.getProp().getPropertyString("ML:CONFIG_machineLearningFile");
        File resourceFile = this.getScript().resolveFilename(resourceID);
        // Check the type before casting. Casting first would throw
        // ClassCastException for a wrong type, or let null reach getClass(),
        // instead of raising the AnalystError below.
        Object loaded = EncogDirectoryPersistence.loadObject(resourceFile);
        if (!(loaded instanceof MLMethod)) {
            String found = loaded == null ? "null" : loaded.getClass().getSimpleName();
            throw new AnalystError("The object to be trained must be an instance of MLMethod. " + found);
        }
        return (MLMethod) loaded;
    }

    /**
     * Load the EGB training file named by {@code ML:CONFIG_trainingFile} into
     * memory. When {@link #kfold} is positive, the set is wrapped in a
     * {@link FoldedDataSet}.
     *
     * @return the training set; the caller is responsible for closing it
     */
    private MLDataSet obtainTrainingSet() {
        String trainingID = this.getProp().getPropertyString("ML:CONFIG_trainingFile");
        File trainingFile = this.getScript().resolveFilename(trainingID);
        MLDataSet trainingSet = EncogUtility.loadEGB2Memory(trainingFile);
        if (this.kfold > 0) {
            trainingSet = new FoldedDataSet(trainingSet);
        }
        return trainingSet;
    }

    /**
     * Run the training loop and report progress to the analyst.
     * <p>
     * One-pass trainers get exactly one iteration. Other trainers iterate until
     * one of these holds:
     * <ul>
     * <li>the error is at or below {@code ML:TRAIN_targetError};</li>
     * <li>the analyst asks to stop;</li>
     * <li>the trainer reports it is done;</li>
     * <li>the analyst's max-iteration limit is reached ({@code -1} means no limit).</li>
     * </ul>
     * Afterwards, the trained method is set on the analyst.
     *
     * @param train       the trainer to run
     * @param method      the method being trained (used for validation)
     * @param trainingSet the training data (used for validation)
     */
    private void performTraining(MLTrain train, MLMethod method, MLDataSet trainingSet) {
        // NOTE(review): presumably rejects a method/data size mismatch before
        // any training starts -- confirm against ValidateNetwork.
        ValidateNetwork.validateMethodToData(method, trainingSet);
        double targetError = this.getProp().getPropertyDouble("ML:TRAIN_targetError");
        this.getAnalyst().reportTrainingBegin();
        int maxIteration = this.getAnalyst().getMaxIteration();
        if (train.getImplementationType() == TrainingImplementationType.OnePass) {
            train.iteration();
            this.getAnalyst().reportTraining(train);
        } else {
            do {
                train.iteration();
                this.getAnalyst().reportTraining(train);
            } while (train.getError() > targetError
                    && !this.getAnalyst().shouldStopCommand()
                    && !train.isTrainingDone()
                    && (maxIteration == -1 || train.getIteration() < maxIteration));
        }
        train.finishTraining();
        this.getAnalyst().reportTrainingEnd();
        this.getAnalyst().setMethod(train.getMethod());
    }
}

