/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.pr;

import cc.mallet.fst.CRF;
import cc.mallet.fst.semi_supervised.pr.PRAuxiliaryModel;
import cc.mallet.fst.semi_supervised.pr.SumLatticeDefaultCachedDot;
import cc.mallet.fst.semi_supervised.pr.SumLatticeKL;
import cc.mallet.fst.semi_supervised.pr.SumLatticePR;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;

/**
 * {@link Optimizable.ByGradientValue} objective that fits a {@link CRF} to fixed per-instance
 * target weights obtained from a {@link PRAuxiliaryModel}.
 *
 * <p>The targets (initial, final and transition weights for every training instance) and the
 * matching constraint counts are computed once, in the constructor, by
 * {@link #gatherConstraints}. Each value evaluation then sums, over the training set, the total
 * weight of a {@link SumLatticeKL} lattice built from those targets minus the total weight of a
 * {@link SumLatticeDefaultCachedDot} lattice (which also accumulates the model expectations),
 * adds the CRF's Gaussian prior and multiplies the result by {@code weight}.
 *
 * <p>Instances are processed in parallel on a fixed-size thread pool created in the constructor;
 * call {@link #shutdown()} once the objective is no longer needed. The value and gradient are
 * cached and recomputed only when {@link CRF#getWeightsValueChangeStamp()} changes.
 *
 * <p>NOTE(review): the class is declared {@link Serializable} but holds a
 * {@link ThreadPoolExecutor}, which does not implement {@code Serializable} — confirm whether
 * instances are ever serialized.
 */
public class CRFOptimizableByKL
implements Serializable,
Optimizable.ByGradientValue {
    private static final Logger logger = MalletLogger.getLogger(CRFOptimizableByKL.class.getName());
    private static final long serialVersionUID = 1L;
    /** Weights stamp for which {@link #cachedValue} is valid; -1 means "never computed". */
    protected int cachedValueWeightsStamp;
    /** Weights stamp for which {@link #cachedGradient} is valid; -1 means "never computed". */
    protected int cachedGradientWeightsStamp;
    protected int numParameters;
    protected int numThreads;
    /** Positive multiplier applied to both the value and the gradient. */
    protected double weight;
    protected double gaussianPriorVariance = 1.0;
    protected double cachedValue = -1.23456789E8;
    protected double[] cachedGradient;
    /** Per-instance normalized initial-state target distribution (indexed like trainingSet). */
    protected List<double[]> initialProbList;
    /** Per-instance normalized final-state target distribution (indexed like trainingSet). */
    protected List<double[]> finalProbList;
    /** Per-instance exponentiated (not normalized) transition targets (indexed like trainingSet). */
    protected List<double[][][]> transitionProbList;
    protected InstanceList trainingSet;
    protected CRF crf;
    /** Constraint counts accumulated once by {@link #gatherConstraints}. */
    protected CRF.Factors constraints;
    /** Model expectations from the last value computation; reused (and mutated) by the gradient. */
    protected CRF.Factors expectations;
    protected ThreadPoolExecutor executor;
    /** NOTE(review): never assigned in this class; the constructor argument is not stored here. */
    protected PRAuxiliaryModel auxModel;

    /**
     * Creates the objective, precomputes the target distributions and constraints, and starts
     * the worker thread pool.
     *
     * @param crf the CRF whose parameters are optimized
     * @param trainingSet training instances; each instance's data must be a {@link Sequence}
     * @param auxModel auxiliary model passed to {@link SumLatticePR} while gathering constraints
     * @param cachedDots per-instance cached dot products handed to {@link SumLatticePR}, indexed
     *     by instance position in {@code trainingSet}
     * @param numThreads number of worker threads (and work partitions); must be positive
     * @param weight positive multiplier applied to the value and gradient (checked by assertion)
     * @throws IllegalArgumentException if {@code numThreads <= 0}
     */
    public CRFOptimizableByKL(CRF crf, InstanceList trainingSet, PRAuxiliaryModel auxModel, double[][][][] cachedDots, int numThreads, double weight) {
        // Fail before the expensive constraint pass; the thread pool would reject this anyway.
        if (numThreads <= 0) {
            throw new IllegalArgumentException("numThreads must be positive: " + numThreads);
        }
        this.crf = crf;
        this.trainingSet = trainingSet;
        this.numParameters = crf.getParameters().getNumFactors();
        this.cachedGradient = new double[this.numParameters];
        this.cachedValueWeightsStamp = -1;
        this.cachedGradientWeightsStamp = -1;
        assert (weight > 0.0);
        this.weight = weight;
        // NOTE: calls an overridable method from the constructor; subclasses overriding it see a
        // partially initialized object (numThreads and executor are not yet set).
        this.gatherConstraints(auxModel, cachedDots);
        this.numThreads = numThreads;
        this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
    }

    /**
     * Exponentiates {@code weights} into a new array and normalizes it with
     * {@link MatrixOps#normalize(double[])}.
     *
     * @param weights values treated as log-weights; not modified
     * @return a new array of normalized probabilities
     */
    private double[] toProbabilities(double[] weights) {
        double[] probs = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            probs[i] = Math.exp(weights[i]);
        }
        MatrixOps.normalize(probs);
        return probs;
    }

    /**
     * Replaces every entry of {@code weights} with its exponential, in place. No normalization is
     * performed.
     */
    private void exponentiateInPlace(double[][][] weights) {
        for (double[][] plane : weights) {
            for (double[] row : plane) {
                for (int k = 0; k < row.length; k++) {
                    row[k] = Math.exp(row[k]);
                }
            }
        }
    }

    /**
     * Builds the per-instance target distributions and accumulates the constraint counts.
     *
     * <p>For every training instance a {@link SumLatticePR} is run; the first and last rows of
     * its gammas are exponentiated and normalized into the initial/final targets, and its xis are
     * exponentiated in place into the transition targets. A {@link SumLatticeKL} over those
     * targets then increments {@link #constraints}.
     *
     * @param auxModel auxiliary model handed to {@link SumLatticePR}
     * @param cachedDots per-instance cached dot products, indexed like {@code trainingSet}
     */
    protected void gatherConstraints(PRAuxiliaryModel auxModel, double[][][][] cachedDots) {
        this.initialProbList = new ArrayList<double[]>();
        this.finalProbList = new ArrayList<double[]>();
        this.transitionProbList = new ArrayList<double[][][]>();
        this.constraints = new CRF.Factors(this.crf.getParameters());
        this.expectations = new CRF.Factors(this.crf.getParameters());
        this.constraints.zero();
        for (int ii = 0; ii < this.trainingSet.size(); ii++) {
            Instance inst = (Instance) this.trainingSet.get(ii);
            Sequence input = (Sequence) inst.getData();
            SumLatticePR geLatt = new SumLatticePR(this.crf, ii, input, null, auxModel, cachedDots[ii], false, null, null, true);
            double[][] gammas = geLatt.getGammas();
            double[] initialProbs = this.toProbabilities(gammas[0]);
            this.initialProbList.add(initialProbs);
            double[] finalProbs = this.toProbabilities(gammas[gammas.length - 1]);
            this.finalProbList.add(finalProbs);
            double[][][] transitionProbs = geLatt.getXis();
            this.exponentiateInPlace(transitionProbs);
            this.transitionProbList.add(transitionProbs);
            // Constructed only for its side effect: the incrementor adds into this.constraints.
            new SumLatticeKL(this.crf, input, initialProbs, finalProbs, transitionProbs, null, new CRF.Factors.Incrementor(this.constraints));
        }
    }

    /**
     * Computes the unregularized, unweighted objective in parallel and refreshes
     * {@link #expectations}.
     *
     * <p>The training set is split into {@code numThreads} contiguous ranges; each task sums
     * (labeled weight - unlabeled weight) over its range and accumulates expectations into its
     * own copy, and the copies are added into {@link #expectations} afterwards.
     *
     * @return the sum of all task values
     * @throws IllegalStateException if a task fails or the calling thread is interrupted (the
     *     interrupt flag is restored in that case)
     */
    protected double getExpectationValue() {
        this.expectations.zero();
        List<ExpectationTask> tasks = this.createTasks();
        double value = 0.0;
        try {
            for (Future<Double> future : this.executor.invokeAll(tasks)) {
                value += future.get();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while computing expectations", e);
        }
        catch (ExecutionException e) {
            // A partial sum would silently corrupt the objective, so surface the failure.
            throw new IllegalStateException("Expectation task failed", e);
        }
        for (ExpectationTask task : tasks) {
            this.expectations.plusEquals(task.getExpectationsCopy(), 1.0);
        }
        return value;
    }

    /**
     * Splits the training set into {@code numThreads} contiguous ranges of
     * {@code size / numThreads} instances; the last range also absorbs the remainder.
     */
    private List<ExpectationTask> createTasks() {
        int size = this.trainingSet.size();
        int increment = size / this.numThreads;
        List<ExpectationTask> tasks = new ArrayList<ExpectationTask>(this.numThreads);
        for (int t = 0; t < this.numThreads; t++) {
            int start = t * increment;
            int end = (t == this.numThreads - 1) ? size : start + increment;
            tasks.add(new ExpectationTask(start, end, new CRF.Factors(this.expectations)));
        }
        return tasks;
    }

    /**
     * Returns {@code weight * (getExpectationValue() + gaussianPrior)}, recomputing it only when
     * the CRF weights stamp has changed. The stamp is recorded only after a successful
     * computation, so a failed evaluation is retried on the next call.
     */
    @Override
    public double getValue() {
        int stamp = this.crf.getWeightsValueChangeStamp();
        if (stamp != this.cachedValueWeightsStamp) {
            long startingTime = System.currentTimeMillis();
            double value = this.getExpectationValue();
            double priorValue = this.crf.getParameters().gaussianPrior(this.gaussianPriorVariance);
            value += priorValue;
            logger.info("Gaussian prior = " + priorValue);
            value *= this.weight;
            assert (!Double.isNaN(value) && !Double.isInfinite(value)) : "Label likelihood is NaN/Infinite";
            this.cachedValue = value;
            this.cachedValueWeightsStamp = stamp;
            logger.info("getValue() (loglikelihood, optimizable by klDiv) = " + this.cachedValue);
            long endingTime = System.currentTimeMillis();
            logger.fine("Inference milliseconds = " + (endingTime - startingTime));
        }
        return this.cachedValue;
    }

    /**
     * Copies the gradient into {@code buffer}, recomputing it when the weights stamp has changed.
     *
     * <p>The gradient is {@code -weight * (expectations - constraints + priorGradient)}, where the
     * prior term comes from {@code plusEqualsGaussianPriorGradient(..., -gaussianPriorVariance)}.
     * Computing it mutates {@link #expectations} in place.
     *
     * @param buffer destination; must have at least {@link #getNumParameters()} entries
     */
    @Override
    public void getValueGradient(double[] buffer) {
        int stamp = this.crf.getWeightsValueChangeStamp();
        if (stamp != this.cachedGradientWeightsStamp) {
            // Refreshes this.expectations if the weights changed since the last getValue().
            this.getValue();
            this.expectations.plusEquals(this.constraints, -1.0);
            this.expectations.plusEqualsGaussianPriorGradient(this.crf.getParameters(), -this.gaussianPriorVariance);
            this.expectations.assertNotNaNOrInfinite();
            this.expectations.getParameters(this.cachedGradient);
            MatrixOps.timesEquals(this.cachedGradient, -this.weight);
            this.cachedGradientWeightsStamp = stamp;
        }
        System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
    }

    @Override
    public int getNumParameters() {
        return this.numParameters;
    }

    @Override
    public void getParameters(double[] buffer) {
        this.crf.getParameters().getParameters(buffer);
    }

    @Override
    public double getParameter(int index) {
        return this.crf.getParameters().getParameter(index);
    }

    /** Sets all CRF parameters and bumps the CRF weights stamp, invalidating the caches. */
    @Override
    public void setParameters(double[] buff) {
        this.crf.getParameters().setParameters(buff);
        this.crf.weightsValueChanged();
    }

    /** Sets one CRF parameter and bumps the CRF weights stamp, invalidating the caches. */
    @Override
    public void setParameter(int index, double value) {
        this.crf.getParameters().setParameter(index, value);
        this.crf.weightsValueChanged();
    }

    /**
     * Sets the Gaussian prior variance used by later value/gradient computations. This does not
     * invalidate already cached values.
     */
    public void setGaussianPriorVariance(double value) {
        this.gaussianPriorVariance = value;
    }

    /** Initiates an orderly shutdown of the worker thread pool. */
    public void shutdown() {
        this.executor.shutdown();
    }

    /**
     * Computes the objective contribution of instances {@code [start, end)} and accumulates
     * their model expectations into a private {@link CRF.Factors} copy.
     */
    private class ExpectationTask
    implements Callable<Double> {
        private final int start;
        private final int end;
        private final CRF.Factors expectationsCopy;

        public ExpectationTask(int start, int end, CRF.Factors exCopy) {
            this.start = start;
            this.end = end;
            this.expectationsCopy = exCopy;
        }

        /** Returns the expectations accumulated by this task. */
        public CRF.Factors getExpectationsCopy() {
            return this.expectationsCopy;
        }

        /** Returns the sum over the range of (SumLatticeKL weight - SumLatticeDefaultCachedDot weight). */
        @Override
        public Double call() throws Exception {
            int numStates = crf.numStates();
            double value = 0.0;
            for (int ii = this.start; ii < this.end; ii++) {
                Instance inst = (Instance) trainingSet.get(ii);
                Sequence input = (Sequence) inst.getData();
                double[] initProbs = initialProbList.get(ii);
                double[] finalProbs = finalProbList.get(ii);
                double[][][] transProbs = transitionProbList.get(ii);
                // Filled by SumLatticeKL and then reused by SumLatticeDefaultCachedDot, so the
                // two lattices must be built in this order.
                double[][][] cachedDots = new double[input.size()][numStates][numStates];
                for (double[][] position : cachedDots) {
                    for (double[] row : position) {
                        Arrays.fill(row, Double.NEGATIVE_INFINITY);
                    }
                }
                double labeledWeight = new SumLatticeKL(crf, input, initProbs, finalProbs, transProbs, cachedDots, null).getTotalWeight();
                value += labeledWeight;
                double unlabeledWeight = new SumLatticeDefaultCachedDot(crf, input, null, cachedDots, new CRF.Factors.Incrementor(this.expectationsCopy), false, null).getTotalWeight();
                value -= unlabeledWeight;
            }
            return value;
        }
    }
}

