/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import java.util.ArrayList;
import java.util.Collections;

public class CRFTrainerByStochasticGradient
extends TransducerTrainer.ByInstanceIncrements {
    protected CRF crf;
    protected double learningRate;
    protected double t;
    protected double lambda;
    protected int iterationCount = 0;
    protected boolean converged = false;
    protected CRF.Factors expectations;
    protected CRF.Factors constraints;

    public CRFTrainerByStochasticGradient(CRF crf, InstanceList trainingSample) {
        this.crf = crf;
        this.expectations = new CRF.Factors(crf);
        this.constraints = new CRF.Factors(crf);
        this.setLearningRateByLikelihood(trainingSample);
    }

    public CRFTrainerByStochasticGradient(CRF crf, double learningRate) {
        this.crf = crf;
        this.learningRate = learningRate;
        this.expectations = new CRF.Factors(crf);
        this.constraints = new CRF.Factors(crf);
    }

    @Override
    public int getIteration() {
        return this.iterationCount;
    }

    @Override
    public Transducer getTransducer() {
        return this.crf;
    }

    @Override
    public boolean isFinishedTraining() {
        return this.converged;
    }

    public void setLearningRateByLikelihood(InstanceList trainingSample) {
        int numIterations = 5;
        double bestLearningRate = Double.NEGATIVE_INFINITY;
        double bestLikelihoodChange = Double.NEGATIVE_INFINITY;
        double currLearningRate = 5.0E-11;
        while (currLearningRate < 1.0) {
            this.crf.parameters.zero();
            double beforeLikelihood = this.computeLikelihood(trainingSample);
            double likelihoodChange = this.trainSample(trainingSample, numIterations, currLearningRate *= 2.0) - beforeLikelihood;
            System.out.println("likelihood change = " + likelihoodChange + " for learningrate=" + currLearningRate);
            if (!(likelihoodChange > bestLikelihoodChange)) continue;
            bestLikelihoodChange = likelihoodChange;
            bestLearningRate = currLearningRate;
        }
        this.crf.parameters.zero();
        System.out.println("Setting learning rate to " + (bestLearningRate /= 2.0));
        this.setLearningRate(bestLearningRate);
    }

    private double trainSample(InstanceList trainingSample, int numIterations, double rate) {
        double lambda = trainingSample.size();
        double t = 1.0 / (lambda * rate);
        double loglik = Double.NEGATIVE_INFINITY;
        int i = 0;
        while (i < numIterations) {
            loglik = 0.0;
            int j = 0;
            while (j < trainingSample.size()) {
                rate = 1.0 / (lambda * t);
                loglik += this.trainIncrementalLikelihood((Instance)trainingSample.get(j), rate);
                t += 1.0;
                ++j;
            }
            ++i;
        }
        return loglik;
    }

    private double computeLikelihood(InstanceList trainingSample) {
        double loglik = 0.0;
        int i = 0;
        while (i < trainingSample.size()) {
            Instance trainingInstance = (Instance)trainingSample.get(i);
            FeatureVectorSequence fvs = (FeatureVectorSequence)trainingInstance.getData();
            Sequence labelSequence = (Sequence)trainingInstance.getTarget();
            loglik += new SumLatticeDefault(this.crf, fvs, labelSequence, null).getTotalWeight();
            loglik -= new SumLatticeDefault(this.crf, fvs, null, null).getTotalWeight();
            ++i;
        }
        this.constraints.zero();
        this.expectations.zero();
        return loglik;
    }

    public void setLearningRate(double r) {
        this.learningRate = r;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    @Override
    public boolean train(InstanceList trainingSet, int numIterations) {
        return this.train(trainingSet, numIterations, 1);
    }

    public boolean train(InstanceList trainingSet, int numIterations, int numIterationsBetweenEvaluation) {
        assert (this.expectations.structureMatches(this.crf.parameters));
        assert (this.constraints.structureMatches(this.crf.parameters));
        this.lambda = 1.0 / (double)trainingSet.size();
        this.t = 1.0 / (this.lambda * this.learningRate);
        this.converged = false;
        ArrayList<Integer> trainingIndices = new ArrayList<Integer>();
        int i = 0;
        while (i < trainingSet.size()) {
            trainingIndices.add(i);
            ++i;
        }
        double oldLoglik = Double.NEGATIVE_INFINITY;
        while (numIterations-- > 0) {
            ++this.iterationCount;
            Collections.shuffle(trainingIndices);
            double loglik = 0.0;
            int i2 = 0;
            while (i2 < trainingSet.size()) {
                this.learningRate = 1.0 / (this.lambda * this.t);
                loglik += this.trainIncrementalLikelihood((Instance)trainingSet.get((Integer)trainingIndices.get(i2)));
                this.t += 1.0;
                ++i2;
            }
            System.out.println("loglikelihood[" + numIterations + "] = " + loglik);
            if (Math.abs(loglik - oldLoglik) < 0.001) {
                this.converged = true;
                break;
            }
            oldLoglik = loglik;
            Runtime.getRuntime().gc();
            if (this.iterationCount % numIterationsBetweenEvaluation != 0) continue;
            this.runEvaluators();
        }
        return this.converged;
    }

    @Override
    public boolean trainIncremental(InstanceList trainingSet) {
        this.train(trainingSet, 1);
        return false;
    }

    @Override
    public boolean trainIncremental(Instance trainingInstance) {
        assert (this.expectations.structureMatches(this.crf.parameters));
        this.trainIncrementalLikelihood(trainingInstance);
        return false;
    }

    public double trainIncrementalLikelihood(Instance trainingInstance) {
        return this.trainIncrementalLikelihood(trainingInstance, this.learningRate);
    }

    public double trainIncrementalLikelihood(Instance trainingInstance, double rate) {
        this.constraints.zero();
        this.expectations.zero();
        FeatureVectorSequence fvs = (FeatureVectorSequence)trainingInstance.getData();
        Sequence labelSequence = (Sequence)trainingInstance.getTarget();
        double singleLoglik = new SumLatticeDefault(this.crf, fvs, labelSequence, new CRF.Factors.Incrementor(this.constraints)).getTotalWeight();
        this.constraints.plusEquals(this.expectations, -1.0);
        this.crf.parameters.plusEquals(this.constraints, rate, true);
        return singleLoglik -= new SumLatticeDefault(this.crf, fvs, null, new CRF.Factors.Incrementor(this.expectations)).getTotalWeight();
    }
}

