/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Boostable;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntOptimizableByLabelLikelihood;
import cc.mallet.optimize.ConjugateGradient;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import java.io.Serializable;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trains a {@link MaxEnt} classifier by optimizing a
 * {@link MaxEntOptimizableByLabelLikelihood} built over a training set.
 *
 * <p>The optimizable and optimizer are cached between calls. The optimizable is rebuilt
 * only when none exists or the cached one was built for a different training list, or
 * after {@link #setClassifier} replaces the classifier. As a consequence, changes made via
 * {@link #setGaussianPriorVariance} or {@link #setL1Weight} only take effect the next time
 * the optimizable is rebuilt.
 */
public class MaxEntTrainer
extends ClassifierTrainer<MaxEnt>
implements ClassifierTrainer.ByOptimization<MaxEnt>,
Boostable,
Serializable {
    private static final Logger logger = MalletLogger.getLogger(MaxEntTrainer.class.getName());
    private static final Logger progressLogger =
        MalletProgressMessageLogger.getLogger(MaxEntTrainer.class.getName() + "-pl");

    /** Sentinel for {@link #runOptimizer(int)}: call the optimizer's unbounded {@code optimize()}. */
    private static final int UNTIL_CONVERGED = -1;

    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
    static final double DEFAULT_L1_WEIGHT = 0.0;
    static final Class DEFAULT_MAXIMIZER_CLASS = LimitedMemoryBFGS.class;

    /** Iteration budget used by {@link #train(InstanceList)}; MAX_VALUE means "run to convergence". */
    int numIterations = Integer.MAX_VALUE;
    /** Variance of the Gaussian prior; used only when {@link #l1Weight} is 0. */
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    /**
     * When non-zero, the Gaussian prior is disabled.
     * NOTE(review): the weight itself is never passed to any optimizer in this class, so a
     * non-zero value currently yields an unregularized model — confirm intended behavior.
     */
    double l1Weight = DEFAULT_L1_WEIGHT;
    // NOTE(review): not read anywhere in this class; getOptimizer(InstanceList) hard-codes L-BFGS.
    Class maximizerClass = DEFAULT_MAXIMIZER_CLASS;
    /** Training set the cached optimizable was requested for. */
    InstanceList trainingSet = null;
    /** Classifier to start training from; may be null. */
    MaxEnt initialClassifier;
    /** Cached optimizable; null until first requested or after the classifier is replaced. */
    MaxEntOptimizableByLabelLikelihood optimizable = null;
    /** Cached optimizer over {@link #optimizable}; null until first requested. */
    Optimizer optimizer = null;

    /** Creates a trainer with default settings and no initial classifier. */
    public MaxEntTrainer() {
    }

    /**
     * Creates a trainer that starts training from the given classifier.
     *
     * @param theClassifierToTrain the classifier whose parameters seed training
     */
    public MaxEntTrainer(MaxEnt theClassifierToTrain) {
        this.initialClassifier = theClassifierToTrain;
    }

    /**
     * Creates a trainer with the given Gaussian prior variance.
     *
     * @param gaussianPriorVariance variance of the Gaussian prior on the parameters
     */
    public MaxEntTrainer(double gaussianPriorVariance) {
        this.gaussianPriorVariance = gaussianPriorVariance;
    }

    /**
     * Returns the classifier being trained: the optimizable's classifier when one exists,
     * otherwise the initial classifier (which may be null).
     */
    @Override
    public MaxEnt getClassifier() {
        if (this.optimizable != null) {
            return this.optimizable.getClassifier();
        }
        return this.initialClassifier;
    }

    /**
     * Replaces the classifier to train. When it differs from the current initial
     * classifier, the cached optimizable and optimizer are discarded so they are rebuilt
     * around the new classifier. With assertions enabled, also checks that its alphabets
     * match the current training set (if any).
     *
     * @param theClassifierToTrain the classifier to train from now on
     */
    public void setClassifier(MaxEnt theClassifierToTrain) {
        assert (this.trainingSet == null || Alphabet.alphabetsMatch(theClassifierToTrain, this.trainingSet));
        if (this.initialClassifier != theClassifierToTrain) {
            this.initialClassifier = theClassifierToTrain;
            this.optimizable = null;
            this.optimizer = null;
        }
    }

    /** Returns the cached optimizable, or null if none has been built yet. */
    public Optimizable getOptimizable() {
        return this.optimizable;
    }

    /**
     * Returns an optimizable for {@code trainingSet}, seeded with the current classifier.
     *
     * @param trainingSet the instances to train on
     * @return the (possibly cached) optimizable
     */
    public MaxEntOptimizableByLabelLikelihood getOptimizable(InstanceList trainingSet) {
        return this.getOptimizable(trainingSet, this.getClassifier());
    }

    /**
     * Returns an optimizable for {@code trainingSet}, building a new one only when none is
     * cached or the cached one was built for a different training list.
     *
     * <p>If only {@code initialClassifier} differs while the cached optimizable already
     * covers {@code trainingSet}, the cached optimizable is kept and the new classifier is
     * merely recorded.
     *
     * @param trainingSet the instances to train on
     * @param initialClassifier classifier used to seed a newly built optimizable
     * @return the (possibly cached) optimizable
     */
    public MaxEntOptimizableByLabelLikelihood getOptimizable(InstanceList trainingSet, MaxEnt initialClassifier) {
        if (trainingSet != this.trainingSet || this.initialClassifier != initialClassifier) {
            this.trainingSet = trainingSet;
            this.initialClassifier = initialClassifier;
            if (this.optimizable == null || this.optimizable.trainingList != trainingSet) {
                this.optimizable = new MaxEntOptimizableByLabelLikelihood(trainingSet, initialClassifier);
                if (this.l1Weight == DEFAULT_L1_WEIGHT) {
                    this.optimizable.setGaussianPriorVariance(this.gaussianPriorVariance);
                } else {
                    this.optimizable.useNoPrior();
                }
                // The old optimizer was bound to the previous optimizable.
                this.optimizer = null;
            }
        }
        return this.optimizable;
    }

    /**
     * Returns the cached optimizer, creating a {@link ConjugateGradient} over the cached
     * optimizable if needed. Returns null when no optimizable exists yet.
     *
     * <p>NOTE(review): {@link #getOptimizer(InstanceList)} creates a {@link LimitedMemoryBFGS}
     * instead, so the optimizer used by {@code train} depends on which getter ran first.
     */
    @Override
    public Optimizer getOptimizer() {
        if (this.optimizer == null && this.optimizable != null) {
            this.optimizer = new ConjugateGradient(this.optimizable);
        }
        return this.optimizer;
    }

    /**
     * Returns an optimizer for {@code trainingSet}, (re)building the optimizable when the
     * training set changed or none exists, and creating a {@link LimitedMemoryBFGS} when no
     * optimizer is cached.
     *
     * @param trainingSet the instances to train on
     * @return the optimizer that {@code train} will drive
     */
    public Optimizer getOptimizer(InstanceList trainingSet) {
        if (trainingSet != this.trainingSet || this.optimizable == null) {
            this.getOptimizable(trainingSet);
            this.optimizer = null;
        }
        if (this.optimizer == null) {
            this.optimizer = new LimitedMemoryBFGS(this.optimizable);
        }
        return this.optimizer;
    }

    /**
     * Sets the iteration budget used by {@link #train(InstanceList)}.
     *
     * @param i maximum number of single-step optimizer calls; {@code Integer.MAX_VALUE}
     *     additionally triggers a final run-to-convergence pass
     * @return this trainer
     */
    public MaxEntTrainer setNumIterations(int i) {
        this.numIterations = i;
        return this;
    }

    /**
     * Returns 0 before any optimizable exists and {@code Integer.MAX_VALUE} afterwards;
     * the actual iteration count is not tracked.
     */
    @Override
    public int getIteration() {
        if (this.optimizable == null) {
            return 0;
        }
        return Integer.MAX_VALUE;
    }

    /**
     * Sets the Gaussian prior variance (used when the L1 weight is 0). It does not affect an
     * already-built optimizable.
     *
     * @param gaussianPriorVariance the prior variance
     * @return this trainer
     */
    public MaxEntTrainer setGaussianPriorVariance(double gaussianPriorVariance) {
        this.gaussianPriorVariance = gaussianPriorVariance;
        return this;
    }

    /**
     * Sets the L1 weight. A non-zero value disables the Gaussian prior for optimizables
     * built afterwards.
     *
     * @param l1Weight the L1 weight
     * @return this trainer
     */
    public MaxEntTrainer setL1Weight(double l1Weight) {
        this.l1Weight = l1Weight;
        return this;
    }

    /** Trains on {@code trainingSet} using the configured {@link #numIterations}. */
    @Override
    public MaxEnt train(InstanceList trainingSet) {
        return this.train(trainingSet, this.numIterations);
    }

    /**
     * Trains on {@code trainingSet} by calling the optimizer one iteration at a time until
     * it reports convergence or {@code numIterations} calls have been made.
     *
     * <p>When {@code numIterations == Integer.MAX_VALUE}, the optimizer is afterwards
     * discarded and a fresh one is run to convergence on the same optimizable.
     *
     * @param trainingSet the instances to train on
     * @param numIterations maximum number of single-iteration optimizer calls
     * @return the trained classifier
     */
    @Override
    public MaxEnt train(InstanceList trainingSet, int numIterations) {
        logger.fine("trainingSet.size() = " + trainingSet.size());
        // Sets up this.optimizable and this.optimizer for trainingSet.
        this.getOptimizer(trainingSet);
        for (int i = 0; i < numIterations; i++) {
            this.runOptimizer(1);
            if (this.finishedTraining) {
                break;
            }
        }
        if (numIterations == Integer.MAX_VALUE) {
            // NOTE(review): presumably a restarted optimizer, free of its accumulated
            // history, can still improve the objective after first convergence — confirm.
            this.optimizer = null;
            this.getOptimizer(trainingSet);
            this.runOptimizer(UNTIL_CONVERGED);
        }
        // Terminates the progress logger's current line (presumably progress output
        // is written without newlines).
        progressLogger.info("\n");
        return this.optimizable.getClassifier();
    }

    /**
     * Runs the current optimizer and stores its convergence flag in {@code finishedTraining}.
     *
     * <p>Optimizer failures are logged together with their stack trace and treated as
     * convergence, so training stops instead of propagating the exception.
     *
     * @param iterations iterations to run, or {@link #UNTIL_CONVERGED} for the optimizer's
     *     unbounded {@code optimize()}
     */
    private void runOptimizer(int iterations) {
        try {
            this.finishedTraining = iterations == UNTIL_CONVERGED
                ? this.optimizer.optimize()
                : this.optimizer.optimize(iterations);
        }
        catch (InvalidOptimizableException e) {
            logger.log(Level.WARNING, "Catching InvalidOptimizableException! saying converged.", e);
            this.finishedTraining = true;
        }
        catch (OptimizationException e) {
            logger.log(Level.INFO, "Catching OptimizationException; saying converged.", e);
            this.finishedTraining = true;
        }
    }

    /**
     * Describes the trainer configuration: the iteration budget (when bounded) and either
     * the L1 weight (when non-zero) or the Gaussian prior variance.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder("MaxEntTrainer");
        if (this.numIterations < Integer.MAX_VALUE) {
            builder.append(",numIterations=").append(this.numIterations);
        }
        if (this.l1Weight != 0.0) {
            builder.append(",l1Weight=").append(this.l1Weight);
        } else {
            builder.append(",gaussianPriorVariance=").append(this.gaussianPriorVariance);
        }
        return builder.toString();
    }
}

