/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFOptimizableByLabelLikelihood;
import cc.mallet.fst.MEMM;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerEvaluator;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import java.util.BitSet;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trains a {@link MEMM} with limited-memory BFGS.
 *
 * <p>Before optimization, the transitions of every training sequence are copied into the
 * {@code trainingSet} of their source {@link MEMM.State} (see {@link #gatherTrainingSets}).
 * The optimizable then works on those per-state instance lists rather than on whole
 * sequences.
 */
public class MEMMTrainer extends TransducerTrainer {
    private static final Logger logger = MalletLogger.getLogger(MEMMTrainer.class.getName());
    MEMM memm;
    /** The training list whose transitions were gathered into the per-state training sets. */
    private InstanceList trainingGatheredFor;
    /** Optimizable created by the most recent call to {@link #train(InstanceList, int)}. */
    MEMMOptimizableByLabelLikelihood omemm;
    /** Number of optimizer iterations that completed without throwing, across all train calls. */
    private int iterationCount = 0;
    /** Whether the optimizer has reported convergence (or was treated as converged). */
    private boolean finishedTraining = false;

    /**
     * Creates a trainer for the given MEMM.
     *
     * @param memm the model whose parameters will be trained
     */
    public MEMMTrainer(MEMM memm) {
        this.memm = memm;
    }

    /**
     * Creates a new optimizable over this trainer's MEMM and the given training set.
     * No constraints are gathered here.
     *
     * @param trainingSet the training instances handed to the optimizable
     * @return a freshly constructed optimizable
     */
    public MEMMOptimizableByLabelLikelihood getOptimizableMEMM(InstanceList trainingSet) {
        return new MEMMOptimizableByLabelLikelihood(this.memm, trainingSet);
    }

    /**
     * Trains with no iteration limit; equivalent to
     * {@code train(training, Integer.MAX_VALUE)}.
     */
    @Override
    public boolean train(InstanceList training) {
        return this.train(training, Integer.MAX_VALUE);
    }

    /**
     * Runs up to {@code numIterations} single-step L-BFGS iterations, running the registered
     * evaluators after each successful step.
     *
     * @param training labeled sequences; each instance's data is cast to
     *     {@link FeatureVectorSequence} and its target to {@link FeatureSequence}
     * @param numIterations maximum number of optimizer iterations; non-positive values make
     *     this method return {@code false} without doing anything
     * @return {@code true} if the optimizer reported convergence, or if it threw an
     *     {@link IllegalArgumentException} (which is treated as convergence)
     * @throws UnsupportedOperationException if transitions were already gathered from a
     *     different training list
     */
    @Override
    public boolean train(InstanceList training, int numIterations) {
        if (numIterations <= 0) {
            return false;
        }
        assert (training.size() > 0);
        // Reference comparison on purpose: only re-gather when handed a different list object.
        if (this.trainingGatheredFor != training) {
            this.gatherTrainingSets(training);
        }
        this.omemm = new MEMMOptimizableByLabelLikelihood(this.memm, training);
        // Constraints are gathered once, up front. Expectations are gathered through
        // getExpectationValue(), presumably invoked by the superclass as the value is recomputed.
        this.omemm.gatherExpectationsOrConstraints(true);
        LimitedMemoryBFGS maximizer = new LimitedMemoryBFGS(this.omemm);
        boolean converged = false;
        logger.info("MEMM about to train with " + numIterations + " iterations");
        for (int i = 0; i < numIterations; ++i) {
            try {
                converged = maximizer.optimize(1);
                // Update progress before running evaluators so they observe the current state.
                this.finishedTraining = converged;
                ++this.iterationCount;
                logger.info("MEMM finished one iteration of maximizer, i=" + i);
                this.runEvaluators();
            } catch (IllegalArgumentException e) {
                // Deliberate best-effort: an IllegalArgumentException from the optimizer (or an
                // evaluator) ends training as if converged rather than propagating.
                logger.log(Level.WARNING, "Catching exception; saying converged.", e);
                converged = true;
            }
            if (converged) {
                logger.info("MEMM training has converged, i=" + i);
                break;
            }
        }
        this.finishedTraining = converged;
        logger.info("MEMM training finished, converged=" + converged);
        return converged;
    }

    /**
     * Distributes the transitions of every training sequence into the {@code trainingSet} of
     * each transition's source state.
     *
     * <p>A {@link SumLatticeDefault} is built for each instance with its target sequence as the
     * output. Every transition the lattice reports with a non-zero count is added to its source
     * state's list as an {@link Instance}. That instance has the transition input as data, the
     * transition output as target, and the count as instance weight. Initial- and final-state
     * increments are ignored.
     *
     * @param training labeled sequences (data: {@link FeatureVectorSequence}, target:
     *     {@link FeatureSequence})
     * @throws UnsupportedOperationException if transitions have already been gathered from any
     *     training list
     */
    void gatherTrainingSets(InstanceList training) {
        if (this.trainingGatheredFor != null) {
            throw new UnsupportedOperationException("Training with multiple sets not supported.");
        }
        this.trainingGatheredFor = training;
        for (int i = 0; i < training.size(); ++i) {
            Instance instance = (Instance) training.get(i);
            FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
            FeatureSequence output = (FeatureSequence) instance.getTarget();
            // Constructed only for the side effects of its Incrementor callbacks.
            new SumLatticeDefault(this.memm, input, output, new Transducer.Incrementor() {

                @Override
                public void incrementFinalState(Transducer.State s, double count) {
                }

                @Override
                public void incrementInitialState(Transducer.State s, double count) {
                }

                @Override
                public void incrementTransition(Transducer.TransitionIterator ti, double count) {
                    MEMM.State source = (MEMM.State) ti.getSourceState();
                    if (count != 0.0) {
                        if (source.trainingSet == null) {
                            source.trainingSet = new InstanceList(null);
                        }
                        source.trainingSet.add(new Instance(ti.getInput(), ti.getOutput(), null, null), count);
                    }
                }
            });
        }
    }

    /**
     * Not supported by this trainer.
     *
     * @throws UnsupportedOperationException always
     */
    public boolean train(InstanceList training, InstanceList validation, InstanceList testing, TransducerEvaluator eval, int numIterations, int numIterationsPerProportion, double[] trainingProportions) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this trainer.
     *
     * @throws UnsupportedOperationException always
     */
    public boolean trainWithFeatureInduction(InstanceList trainingData, InstanceList validationData, InstanceList testingData, TransducerEvaluator eval, int numIterations, int numIterationsBetweenFeatureInductions, int numFeatureInductions, int numFeaturesPerFeatureInduction, double trueLabelProbThreshold, boolean clusteredFeatureInduction, double[] trainingProportions, String gainName) {
        throw new UnsupportedOperationException();
    }

    /**
     * Debugging aid: prints every state's gathered training instances to standard output.
     */
    public void printInstanceLists() {
        for (int i = 0; i < this.memm.numStates(); ++i) {
            MEMM.State state = (MEMM.State) this.memm.getState(i);
            InstanceList training = state.trainingSet;
            System.out.println("State " + i + " : " + state.getName());
            if (training == null) {
                System.out.println("No data");
                continue;
            }
            for (int j = 0; j < training.size(); ++j) {
                Instance inst = (Instance) training.get(j);
                System.out.println("From : " + state.getName() + " To : " + inst.getTarget());
                System.out.println("Instance " + j);
                System.out.println(inst.getTarget());
                System.out.println(inst.getData());
            }
        }
    }

    /**
     * @return the number of optimizer iterations completed without throwing so far
     */
    @Override
    public int getIteration() {
        return this.iterationCount;
    }

    @Override
    public Transducer getTransducer() {
        return this.memm;
    }

    /**
     * @return {@code true} once the optimizer has reported (or been treated as having reached)
     *     convergence
     */
    @Override
    public boolean isFinishedTraining() {
        return this.finishedTraining;
    }

    /**
     * Label-likelihood optimizable that computes constraints and expectations from the
     * per-state training sets gathered by {@link MEMMTrainer#gatherTrainingSets}, instead of
     * from whole sequences.
     */
    public class MEMMOptimizableByLabelLikelihood
    extends CRFOptimizableByLabelLikelihood
    implements Optimizable.ByGradientValue {
        /** Set up on the first gather call; no bit is ever set (infinite weights are rejected). */
        BitSet infiniteValues;

        protected MEMMOptimizableByLabelLikelihood(MEMM memm, InstanceList trainingData) {
            super(memm, trainingData);
            this.infiniteValues = null;
            this.expectations = new CRF.Factors(memm);
            this.constraints = new CRF.Factors(memm);
        }

        /**
         * Accumulates transition counts from every state's training set into either the
         * constraints or the expectations.
         *
         * <p>For each gathered instance, every transition yielded by a {@link MEMM.TransitionIterator}
         * is incremented by {@code exp(weight) * instanceWeight}. Afterwards, the initial and
         * final weights of the chosen factors are set to 0.0.
         *
         * @param gatherConstraints {@code true} to fill {@code constraints}, {@code false} to
         *     fill {@code expectations}
         * @return the instance-weighted sum of the weights of transitions whose output equals the
         *     instance's label. This is only accumulated when {@code gatherConstraints} is
         *     {@code false}; otherwise 0.0 is returned.
         * @throws IllegalStateException if a labeled transition has an infinite weight during an
         *     expectation pass
         */
        protected double gatherExpectationsOrConstraints(boolean gatherConstraints) {
            boolean initializingInfiniteValues = false;
            CRF.Factors factors = gatherConstraints ? this.constraints : this.expectations;
            // Incrementor is an inner class of Factors, so it must be bound to this instance.
            CRF.Factors.Incrementor factorIncrementor = factors.new Incrementor();
            if (this.infiniteValues == null) {
                this.infiniteValues = new BitSet();
                initializingInfiniteValues = true;
            }
            double labelLogProb = 0.0;
            int numStates = MEMMTrainer.this.memm.numStates();
            for (int i = 0; i < numStates; ++i) {
                MEMM.State s = (MEMM.State) MEMMTrainer.this.memm.getState(i);
                if (s.trainingSet == null) {
                    System.out.println("Empty training set for state " + s.name);
                    continue;
                }
                for (int j = 0; j < s.trainingSet.size(); ++j) {
                    Instance instance = (Instance) s.trainingSet.get(j);
                    double instWeight = s.trainingSet.getInstanceWeight(j);
                    FeatureVector fv = (FeatureVector) instance.getData();
                    String labelString = (String) instance.getTarget();
                    // The observed label is passed only for constraints; null for expectations
                    // (presumably restricting iteration to the labeled transition vs. all of them).
                    MEMM.TransitionIterator iter = new MEMM.TransitionIterator(
                            s, fv, gatherConstraints ? labelString : null, (CRF) MEMMTrainer.this.memm);
                    while (iter.hasNext()) {
                        iter.nextState();
                        double weight = iter.getWeight();
                        factorIncrementor.incrementTransition(iter, Math.exp(weight) * instWeight);
                        // Value comparison: equal labels held in different String objects must match.
                        if (gatherConstraints || !Objects.equals(iter.getOutput(), labelString)) {
                            continue;
                        }
                        if (!Double.isInfinite(weight)) {
                            labelLogProb += instWeight * weight;
                            continue;
                        }
                        logger.warning("State " + i + " transition " + j + " has infinite cost; skipping.");
                        if (initializingInfiniteValues) {
                            throw new IllegalStateException("Infinite-cost transitions not yet supported");
                        }
                        // NOTE(review): j indexes within one state's list, and no bit is ever set,
                        // so this always throws.
                        if (!this.infiniteValues.get(j)) {
                            throw new IllegalStateException("Instance " + j + " used to have non-infinite value, but now it has infinite value.");
                        }
                    }
                }
            }
            // Constraints and expectations both get 0.0 here, so their difference for the
            // initial/final weights is always zero.
            for (int i = 0; i < numStates; ++i) {
                factors.initialWeights[i] = 0.0;
                factors.finalWeights[i] = 0.0;
            }
            return labelLogProb;
        }

        /**
         * Gathers expectations from the per-state training sets.
         *
         * @return the instance-weighted sum of labeled-transition weights
         */
        @Override
        protected double getExpectationValue() {
            return this.gatherExpectationsOrConstraints(false);
        }
    }
}

