/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.HMM;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import java.util.logging.Logger;

public class HMMTrainerByLikelihood
extends TransducerTrainer {
    private static Logger logger = MalletLogger.getLogger(HMMTrainerByLikelihood.class.getName());
    HMM hmm;
    InstanceList trainingSet;
    InstanceList unlabeledSet;
    int iterationCount = 0;
    boolean converged = false;

    public HMMTrainerByLikelihood(HMM hmm) {
        this.hmm = hmm;
    }

    @Override
    public Transducer getTransducer() {
        return this.hmm;
    }

    @Override
    public int getIteration() {
        return this.iterationCount;
    }

    @Override
    public boolean isFinishedTraining() {
        return this.converged;
    }

    @Override
    public boolean train(InstanceList trainingSet, int numIterations) {
        return this.train(trainingSet, null, numIterations);
    }

    public boolean train(InstanceList trainingSet, InstanceList unlabeledSet, int numIterations) {
        if (this.hmm.emissionEstimator == null) {
            this.hmm.reset();
        }
        this.converged = false;
        double threshold = 0.001;
        double logLikelihood = Double.NEGATIVE_INFINITY;
        int iter = 0;
        while (iter < numIterations) {
            double prevLogLikelihood = logLikelihood;
            logLikelihood = 0.0;
            for (Instance inst : trainingSet) {
                FeatureSequence input = (FeatureSequence)inst.getData();
                FeatureSequence output = (FeatureSequence)inst.getTarget();
                double obsLikelihood = new SumLatticeDefault(this.hmm, input, output, this.hmm.new HMM.Incrementor()).getTotalWeight();
                logLikelihood += obsLikelihood;
            }
            logger.info("getValue() (observed log-likelihood) = " + logLikelihood);
            if (unlabeledSet != null) {
                int numEx = 0;
                for (Instance inst : unlabeledSet) {
                    if (++numEx % 100 == 0) {
                        System.err.print(String.valueOf(numEx) + ". ");
                        System.err.flush();
                    }
                    FeatureSequence input = (FeatureSequence)inst.getData();
                    double hiddenLikelihood = new SumLatticeDefault(this.hmm, input, null, this.hmm.new HMM.Incrementor()).getTotalWeight();
                    logLikelihood += hiddenLikelihood;
                }
                System.err.println();
            }
            logger.info("getValue() (log-likelihood) = " + logLikelihood);
            this.hmm.estimate();
            ++this.iterationCount;
            logger.info("HMM finished one iteration of maximizer, i=" + iter);
            this.runEvaluators();
            if (Math.abs(logLikelihood - prevLogLikelihood) < threshold) {
                this.converged = true;
                logger.info("HMM training has converged, i=" + iter);
                break;
            }
            ++iter;
        }
        return this.converged;
    }
}

