/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFCacheStaleIndicator;
import cc.mallet.fst.CRFOptimizableByBatchLabelLikelihood;
import cc.mallet.fst.ThreadedOptimizable;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import java.util.Random;
import java.util.logging.Logger;

/**
 * Trains a {@link CRF} by optimizing a {@link CRFOptimizableByBatchLabelLikelihood}
 * objective that is wrapped in a {@link ThreadedOptimizable} and driven by a
 * {@link LimitedMemoryBFGS} optimizer.
 *
 * <p>The optimizable, its threaded wrapper and the optimizer are built lazily and
 * cached; they are rebuilt when the CRF's weights-structure stamp changes or when a
 * different training set instance is supplied.
 *
 * <p>Callers should invoke {@link #shutdown()} once training is finished so that the
 * current {@link ThreadedOptimizable} can release its resources.
 */
public class CRFTrainerByThreadedLabelLikelihood
extends TransducerTrainer
implements TransducerTrainer.ByOptimization {
    private static final Logger logger = MalletLogger.getLogger(CRFTrainerByThreadedLabelLikelihood.class.getName());
    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
    private boolean useSparseWeights;
    private boolean useNoWeights;
    private transient boolean useSomeUnsupportedTrick;
    /** Result of the most recent {@link #train(InstanceList, int)} run. */
    private boolean converged;
    private int numThreads;
    /** Number of optimizer iterations that completed without throwing. */
    private int iterationCount;
    private double gaussianPriorVariance;
    private CRF crf;
    private CRFOptimizableByBatchLabelLikelihood optimizable;
    private ThreadedOptimizable threadedOptimizable;
    private Optimizer optimizer;
    /** Value of {@code crf.weightsStructureChangeStamp} when the weights were last dimensioned; -1 = never. */
    private int cachedWeightsStructureStamp;

    /**
     * Creates a trainer for the given CRF.
     *
     * @param crf        the model whose parameters will be optimized
     * @param numThreads passed to the batch optimizable when it is constructed
     */
    public CRFTrainerByThreadedLabelLikelihood(CRF crf, int numThreads) {
        this.crf = crf;
        this.useSparseWeights = true;
        this.useNoWeights = false;
        this.useSomeUnsupportedTrick = true;
        this.converged = false;
        this.numThreads = numThreads;
        this.iterationCount = 0;
        this.gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
        this.cachedWeightsStructureStamp = -1;
    }

    /** Returns the CRF being trained, typed as a {@link Transducer}. */
    @Override
    public Transducer getTransducer() {
        return this.crf;
    }

    /** Returns the CRF being trained. */
    public CRF getCRF() {
        return this.crf;
    }

    /**
     * Returns the currently cached optimizer, or {@code null} if none has been
     * built yet (or it was invalidated by a rebuild of the optimizable).
     */
    @Override
    public Optimizer getOptimizer() {
        return this.optimizer;
    }

    /** Returns the convergence result of the most recent training run. */
    public boolean isConverged() {
        return this.converged;
    }

    /** Same value as {@link #isConverged()}. */
    @Override
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /** Returns the number of optimizer iterations completed so far across all training calls. */
    @Override
    public int getIteration() {
        return this.iterationCount;
    }

    /**
     * Sets the Gaussian prior variance applied to optimizables built after this call.
     * An already-built optimizable is not updated.
     */
    public void setGaussianPriorVariance(double p) {
        this.gaussianPriorVariance = p;
    }

    public double getGaussianPriorVariance() {
        return this.gaussianPriorVariance;
    }

    /**
     * Chooses how weights are dimensioned on the next structure refresh:
     * {@code true} uses {@code CRF.setWeightsDimensionAsIn}, {@code false} uses
     * {@code CRF.setWeightsDimensionDensely}.
     */
    public void setUseSparseWeights(boolean b) {
        this.useSparseWeights = b;
    }

    public boolean getUseSparseWeights() {
        return this.useSparseWeights;
    }

    /** Sets the flag forwarded as the second argument of {@code CRF.setWeightsDimensionAsIn}. */
    public void setUseSomeUnsupportedTrick(boolean b) {
        this.useSomeUnsupportedTrick = b;
    }

    /** When {@code flag} is true, the trainer skips dimensioning the CRF weights altogether. */
    public void setAddNoFactors(boolean flag) {
        this.useNoWeights = flag;
    }

    /**
     * Shuts down the current {@link ThreadedOptimizable}, if one has been built.
     * Calling this before any training is a no-op rather than a
     * {@link NullPointerException}.
     */
    public void shutdown() {
        if (this.threadedOptimizable != null) {
            this.threadedOptimizable.shutdown();
        }
    }

    /**
     * Returns the batch optimizable for {@code trainingSet}, (re)building it together
     * with its threaded wrapper when needed.
     *
     * <p>If the CRF's weights-structure stamp differs from the cached one, the weights
     * are re-dimensioned (unless disabled via {@link #setAddNoFactors(boolean)}) and the
     * cached optimizable is discarded. A new optimizable is also built when the supplied
     * training set is a different instance from the cached one. Any rebuild invalidates
     * the cached optimizer.
     *
     * @param trainingSet the data to compute the objective over
     * @return the optimizable bound to {@code trainingSet}
     */
    public CRFOptimizableByBatchLabelLikelihood getOptimizableCRF(InstanceList trainingSet) {
        if (this.cachedWeightsStructureStamp != this.crf.weightsStructureChangeStamp) {
            if (!this.useNoWeights) {
                if (this.useSparseWeights) {
                    this.crf.setWeightsDimensionAsIn(trainingSet, this.useSomeUnsupportedTrick);
                } else {
                    this.crf.setWeightsDimensionDensely();
                }
            }
            this.optimizable = null;
            this.cachedWeightsStructureStamp = this.crf.weightsStructureChangeStamp;
        }
        if (this.optimizable == null || this.optimizable.trainingSet != trainingSet) {
            // Release the wrapper we are about to replace; otherwise its resources
            // are orphaned every time the training set or weight structure changes.
            if (this.threadedOptimizable != null) {
                this.threadedOptimizable.shutdown();
            }
            this.optimizable = new CRFOptimizableByBatchLabelLikelihood(this.crf, trainingSet, this.numThreads);
            this.optimizable.setGaussianPriorVariance(this.gaussianPriorVariance);
            this.threadedOptimizable = new ThreadedOptimizable(this.optimizable, trainingSet, this.crf.getParameters().getNumFactors(), new CRFCacheStaleIndicator(this.crf));
            this.optimizer = null;
        }
        return this.optimizable;
    }

    /**
     * Returns the optimizer for {@code trainingSet}, creating a new
     * {@link LimitedMemoryBFGS} only when none is cached or the cached one is bound
     * to a stale {@link ThreadedOptimizable}.
     *
     * @param trainingSet the data to compute the objective over
     * @return an optimizer driving the current threaded optimizable
     */
    public Optimizer getOptimizer(InstanceList trainingSet) {
        this.getOptimizableCRF(trainingSet);
        // The optimizer is constructed around threadedOptimizable, so that is the
        // object its getOptimizable() must be compared with. Comparing against the
        // unwrapped batch optimizable never matches and would rebuild every call.
        if (this.optimizer == null || this.optimizer.getOptimizable() != this.threadedOptimizable) {
            this.optimizer = new LimitedMemoryBFGS(this.threadedOptimizable);
        }
        return this.optimizer;
    }

    /** Trains on {@code training} with an iteration cap of {@link Integer#MAX_VALUE}. */
    public boolean trainIncremental(InstanceList training) {
        return this.train(training, Integer.MAX_VALUE);
    }

    /**
     * Runs up to {@code numIterations} single-step optimizer iterations, invoking the
     * evaluators after each successful step.
     *
     * <p>Any exception raised by an iteration is printed and logged, and the run is
     * reported as converged (deliberate best-effort behavior carried over from the
     * original). The outcome is stored so that {@link #isConverged()} and
     * {@link #isFinishedTraining()} reflect it.
     *
     * @param trainingSet   the data to train on; asserted to be non-empty
     * @param numIterations maximum iterations; values {@code <= 0} return {@code false}
     *                      immediately without touching any state
     * @return {@code true} if the optimizer reported convergence or an exception
     *         stopped the run, {@code false} if the iteration cap was reached first
     */
    @Override
    public boolean train(InstanceList trainingSet, int numIterations) {
        if (numIterations <= 0) {
            return false;
        }
        assert (trainingSet.size() > 0);
        // Builds (or reuses) the optimizable, its threaded wrapper and the optimizer.
        this.getOptimizer(trainingSet);
        boolean finished = false;
        logger.info("CRF about to train with " + numIterations + " iterations");
        for (int i = 0; i < numIterations; ++i) {
            try {
                finished = this.optimizer.optimize(1);
                ++this.iterationCount;
                logger.info("CRF finished one iteration of maximizer, i=" + i);
                this.runEvaluators();
            }
            catch (Exception e) {
                e.printStackTrace();
                logger.info("Catching exception; saying converged.");
                finished = true;
            }
            if (finished) {
                logger.info("CRF training has converged, i=" + i);
                break;
            }
        }
        // Publish the result; previously only a shadowing local was updated and
        // the field stayed false forever.
        this.converged = finished;
        return finished;
    }

    /**
     * Trains in successive rounds, each on a portion of {@code training}.
     *
     * <p>A proportion of exactly {@code 1.0} trains on the full list. Any other value
     * trains on the first part of a two-way split produced with a fixed-seed
     * {@code Random(1L)}.
     *
     * @param training                   the full training data
     * @param numIterationsPerProportion iteration cap for each round
     * @param trainingProportions        fractions of the data, each asserted {@code <= 1.0};
     *                                   asserted to be non-empty
     * @return the convergence result of the last round
     */
    public boolean train(InstanceList training, int numIterationsPerProportion, double[] trainingProportions) {
        assert (trainingProportions.length > 0);
        boolean lastRoundConverged = false;
        for (int i = 0; i < trainingProportions.length; ++i) {
            double proportion = trainingProportions[i];
            assert (proportion <= 1.0);
            // NOTE(review): the value logged is a fraction, although the message says "%".
            logger.info("Training on " + proportion + "% of the data this round.");
            InstanceList roundData = proportion == 1.0
                ? training
                : training.split(new Random(1L), new double[]{proportion, 1.0 - proportion})[0];
            lastRoundConverged = this.train(roundData, numIterationsPerProportion);
        }
        return lastRoundConverged;
    }
}

