/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.classify.MaxEnt;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.logging.Logger;

/**
 * Gradient-based objective over the parameters of a {@link MaxEnt} classifier
 * whose exponentiated, unnormalized scores are scored against per-instance
 * count vectors.
 *
 * <p>For every training instance with a non-null target (a {@link FeatureVector}
 * of label counts) the value adds
 * <pre>
 *   sum_l [ logGamma(s_l + n_l) - logGamma(s_l) ]
 *     - [ logGamma(sum_l s_l + sum_l n_l) - logGamma(sum_l s_l) ]
 * </pre>
 * where {@code s_l = exp(score_l)} and {@code n_l} is the count for label
 * {@code l}. A Gaussian penalty on the parameters is then added, with a
 * separate variance for the per-label default (intercept) parameter.
 *
 * <p>Parameters are laid out row-major by label:
 * {@code parameters[label * numFeatures + feature]}.
 */
public class DMROptimizable
implements Optimizable.ByGradientValue {
    private static final Logger logger = MalletLogger.getLogger(DMROptimizable.class.getName());
    private static final Logger progressLogger = MalletProgressMessageLogger.getLogger(DMROptimizable.class.getName() + "-pl");

    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
    static final double DEFAULT_LARGE_GAUSSIAN_PRIOR_VARIANCE = 100.0;
    static final double DEFAULT_GAUSSIAN_PRIOR_MEAN = 0.0;

    /**
     * Counts below this limit use the finite sum {@code sum_i 1 / (s + i)}
     * instead of two digamma evaluations when computing
     * {@code digamma(s + n) - digamma(s)}.
     */
    private static final double EXACT_DIGAMMA_DIFFERENCE_LIMIT = 20.0;

    MaxEnt classifier;
    InstanceList trainingList;
    int numGetValueCalls = 0;
    int numGetValueGradientCalls = 0;
    int numIterations = Integer.MAX_VALUE;
    NumberFormat formatter = null;
    double gaussianPriorMean = DEFAULT_GAUSSIAN_PRIOR_MEAN;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    double defaultFeatureGaussianPriorVariance = DEFAULT_LARGE_GAUSSIAN_PRIOR_VARIANCE;
    double[] parameters;
    double[] cachedGradient;
    double cachedValue;
    boolean cachedValueStale;
    // NOTE(review): this flag is set by the constructor and the setters but is
    // never read; getValueGradient recomputes the gradient on every call.
    boolean cachedGradientStale;
    int numLabels;
    int numFeatures;
    int defaultFeatureIndex;

    /**
     * Creates an empty instance. No field beyond its declared default is
     * initialised, so the parameter array, training list and classifier are
     * all {@code null} until assigned elsewhere.
     */
    public DMROptimizable() {
    }

    /**
     * Builds the objective for the given training data.
     *
     * @param instances training instances; data and target are both cast to
     *        {@link FeatureVector}, and instances with a null target are ignored
     * @param initialClassifier classifier whose parameter array and default
     *        feature index are adopted, or {@code null} to create a new
     *        {@link MaxEnt} over a zero-filled parameter array
     */
    public DMROptimizable(InstanceList instances, MaxEnt initialClassifier) {
        this.trainingList = instances;
        Alphabet alphabet = instances.getDataAlphabet();
        Alphabet labelAlphabet = instances.getTargetAlphabet();
        this.numLabels = labelAlphabet.size();
        // One extra column per label holds the default (intercept) parameter.
        this.numFeatures = alphabet.size() + 1;
        this.defaultFeatureIndex = this.numFeatures - 1;
        int numParameters = this.numLabels * this.numFeatures;
        this.cachedGradient = new double[numParameters];
        if (initialClassifier != null) {
            this.classifier = initialClassifier;
            this.parameters = this.classifier.getParameters();
            this.defaultFeatureIndex = this.classifier.getDefaultFeatureIndex();
            assert (initialClassifier.getInstancePipe() == instances.getPipe());
        } else {
            this.parameters = new double[numParameters];
            this.classifier = new MaxEnt(instances.getPipe(), this.parameters);
        }
        this.formatter = new DecimalFormat("0.###E0");
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        logger.fine("Number of instances in training list = " + this.trainingList.size());
        this.reportNaNFeatures(alphabet);
    }

    /** Logs every NaN feature value found in instances that have a target. */
    private void reportNaNFeatures(Alphabet alphabet) {
        for (Instance instance : this.trainingList) {
            FeatureVector multinomialValues = (FeatureVector)instance.getTarget();
            if (multinomialValues == null) {
                continue;
            }
            FeatureVector features = (FeatureVector)instance.getData();
            assert (features.getAlphabet() == alphabet);
            boolean hasNaN = false;
            for (int loc = 0; loc < features.numLocations(); ++loc) {
                if (Double.isNaN(features.valueAtLocation(loc))) {
                    logger.info("NaN for feature " + alphabet.lookupObject(features.indexAtLocation(loc)).toString());
                    hasNaN = true;
                }
            }
            if (hasNaN) {
                logger.info("NaN in instance: " + instance.getName());
            }
        }
    }

    /**
     * Sets the Gaussian prior variance applied to each label's default
     * (intercept) parameter.
     *
     * @param sigmaSquared the variance
     */
    public void setInterceptGaussianPriorVariance(double sigmaSquared) {
        this.defaultFeatureGaussianPriorVariance = sigmaSquared;
    }

    /**
     * Sets the Gaussian prior variance applied to the non-default parameters.
     *
     * @param sigmaSquared the variance
     */
    public void setRegularGaussianPriorVariance(double sigmaSquared) {
        this.gaussianPriorVariance = sigmaSquared;
    }

    /** @return the classifier used to score instances */
    public MaxEnt getClassifier() {
        return this.classifier;
    }

    /** @return the parameter stored at {@code index} */
    @Override
    public double getParameter(int index) {
        return this.parameters[index];
    }

    /** Stores {@code v} at {@code index} and marks the cached value and gradient stale. */
    @Override
    public void setParameter(int index, double v) {
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        this.parameters[index] = v;
    }

    /** @return the length of the parameter array */
    @Override
    public int getNumParameters() {
        return this.parameters.length;
    }

    /**
     * Copies the current parameters into {@code buff}.
     *
     * @param buff destination array; must be non-null and exactly
     *        {@link #getNumParameters()} long
     * @throws IllegalArgumentException if {@code buff} is null or has the wrong
     *         length. A replacement array allocated here could never reach the
     *         caller, so the mismatch is reported instead of silently ignored.
     */
    @Override
    public void getParameters(double[] buff) {
        if (buff == null || buff.length != this.parameters.length) {
            throw new IllegalArgumentException("buff must be a non-null array of length " + this.parameters.length);
        }
        System.arraycopy(this.parameters, 0, buff, 0, this.parameters.length);
    }

    /**
     * Copies {@code buff} into the parameter array and marks the cached value
     * and gradient stale.
     *
     * @param buff the new parameter values
     */
    @Override
    public void setParameters(double[] buff) {
        assert (buff != null);
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        if (buff.length != this.parameters.length) {
            // NOTE(review): presumably the classifier shares the parameter array;
            // if so, replacing it here leaves the classifier reading the old
            // array. Confirm whether a length mismatch can happen in practice.
            this.parameters = new double[buff.length];
        }
        System.arraycopy(buff, 0, this.parameters, 0, buff.length);
    }

    /**
     * Returns the objective: the summed per-instance log-gamma terms plus the
     * Gaussian prior term. The result is cached until a parameter is set.
     *
     * <p>If an instance contributes an infinite value, evaluation stops there:
     * the running total minus that value is cached and the negated value is
     * returned.
     *
     * @return the objective value at the current parameters
     */
    @Override
    public double getValue() {
        if (!this.cachedValueStale) {
            return this.cachedValue;
        }
        ++this.numGetValueCalls;
        this.cachedValue = 0.0;
        double[] scores = new double[this.trainingList.getTargetAlphabet().size()];
        for (Instance instance : this.trainingList) {
            FeatureVector multinomialValues = (FeatureVector)instance.getTarget();
            if (multinomialValues == null) {
                continue;
            }
            double sumScores = this.exponentiateScores(instance, scores);

            // Contribution of this instance only. The accumulator is reset for each
            // instance so that the total is the plain sum that getValueGradient
            // differentiates; carrying it across iterations would add earlier
            // instances again on every later iteration.
            double value = 0.0;
            double totalLength = 0.0;
            for (int loc = 0; loc < multinomialValues.numLocations(); ++loc) {
                int label = multinomialValues.indexAtLocation(loc);
                double count = multinomialValues.valueAtLocation(loc);
                value += Dirichlet.logGammaStirling(scores[label] + count) - Dirichlet.logGammaStirling(scores[label]);
                totalLength += count;
            }
            value -= Dirichlet.logGammaStirling(sumScores + totalLength) - Dirichlet.logGammaStirling(sumScores);

            if (Double.isNaN(value)) {
                logger.fine("DCMMaxEntTrainer: Instance " + instance.getName() + " has NaN value.");
                for (int label : multinomialValues.getIndices()) {
                    logger.fine("log(scores)= " + Math.log(scores[label]) + " scores = " + scores[label]);
                }
            }
            if (Double.isInfinite(value)) {
                logger.warning("Instance " + instance.getSource() + " has infinite value; skipping value and gradient");
                // NOTE(review): the sign handling here is kept as found; confirm the
                // intended result for an infinite contribution.
                this.cachedValue -= value;
                this.cachedValueStale = false;
                return -value;
            }
            this.cachedValue += value;
        }

        double prior = 0.0;
        for (int label = 0; label < this.numLabels; ++label) {
            int rowOffset = label * this.numFeatures;
            for (int feature = 0; feature < this.numFeatures - 1; ++feature) {
                double param = this.parameters[rowOffset + feature];
                prior -= (param - this.gaussianPriorMean) * (param - this.gaussianPriorMean) / (2.0 * this.gaussianPriorVariance);
            }
            double param = this.parameters[rowOffset + this.defaultFeatureIndex];
            prior -= (param - this.gaussianPriorMean) * (param - this.gaussianPriorMean) / (2.0 * this.defaultFeatureGaussianPriorVariance);
        }

        double labelProbability = this.cachedValue;
        this.cachedValue += prior;
        this.cachedValueStale = false;
        progressLogger.info("Value (likelihood=" + this.formatter.format(labelProbability) + " prior=" + this.formatter.format(prior) + ") = " + this.formatter.format(this.cachedValue));
        return this.cachedValue;
    }

    /**
     * Computes the gradient of the objective with respect to every parameter
     * and copies it into {@code buffer}. The gradient is recomputed on every
     * call.
     *
     * @param buffer destination; expected to be non-null and as long as the
     *        parameter array (checked by assertion only)
     */
    @Override
    public void getValueGradient(double[] buffer) {
        MatrixOps.setAll(this.cachedGradient, 0.0);
        double[] scores = new double[this.trainingList.getTargetAlphabet().size()];
        for (Instance instance : this.trainingList) {
            FeatureVector multinomialValues = (FeatureVector)instance.getTarget();
            if (multinomialValues == null) {
                continue;
            }
            double sumScores = this.exponentiateScores(instance, scores);
            FeatureVector features = (FeatureVector)instance.getData();

            double totalLength = 0.0;
            for (double count : multinomialValues.getValues()) {
                totalLength += count;
            }
            double digammaDifferenceForSums = Dirichlet.digamma(sumScores + totalLength) - Dirichlet.digamma(sumScores);

            // The per-label difference depends only on the label's score and count,
            // so compute it once per instance rather than once per feature.
            int numLabelLocations = multinomialValues.numLocations();
            double[] labelDifferences = new double[numLabelLocations];
            for (int labelLoc = 0; labelLoc < numLabelLocations; ++labelLoc) {
                int label = multinomialValues.indexAtLocation(labelLoc);
                labelDifferences[labelLoc] = DMROptimizable.digammaDifference(scores[label], multinomialValues.valueAtLocation(labelLoc));
            }

            for (int loc = 0; loc < features.numLocations(); ++loc) {
                int index = features.indexAtLocation(loc);
                double value = features.valueAtLocation(loc);
                if (value == 0.0) {
                    continue;
                }
                for (int label = 0; label < this.numLabels; ++label) {
                    this.cachedGradient[label * this.numFeatures + index] -= value * scores[label] * digammaDifferenceForSums;
                }
                for (int labelLoc = 0; labelLoc < numLabelLocations; ++labelLoc) {
                    int label = multinomialValues.indexAtLocation(labelLoc);
                    this.cachedGradient[label * this.numFeatures + index] += value * scores[label] * labelDifferences[labelLoc];
                }
            }

            // The default feature is implicitly present with value 1 in every instance.
            for (int label = 0; label < this.numLabels; ++label) {
                this.cachedGradient[label * this.numFeatures + this.defaultFeatureIndex] -= scores[label] * digammaDifferenceForSums;
            }
            for (int labelLoc = 0; labelLoc < numLabelLocations; ++labelLoc) {
                int label = multinomialValues.indexAtLocation(labelLoc);
                this.cachedGradient[label * this.numFeatures + this.defaultFeatureIndex] += scores[label] * labelDifferences[labelLoc];
            }
        }
        ++this.numGetValueGradientCalls;

        // Gradient of the Gaussian prior term.
        for (int label = 0; label < this.numLabels; ++label) {
            int rowOffset = label * this.numFeatures;
            for (int feature = 0; feature < this.numFeatures - 1; ++feature) {
                double param = this.parameters[rowOffset + feature];
                this.cachedGradient[rowOffset + feature] -= (param - this.gaussianPriorMean) / this.gaussianPriorVariance;
            }
            double param = this.parameters[rowOffset + this.defaultFeatureIndex];
            this.cachedGradient[rowOffset + this.defaultFeatureIndex] -= (param - this.gaussianPriorMean) / this.defaultFeatureGaussianPriorVariance;
        }

        MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0);
        assert (buffer != null && buffer.length == this.parameters.length);
        System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
    }

    /**
     * Fills {@code scores} with the exponential of the classifier's
     * unnormalized scores for {@code instance}.
     *
     * @return the sum of the exponentiated scores
     */
    private double exponentiateScores(Instance instance, double[] scores) {
        this.classifier.getUnnormalizedClassificationScores(instance, scores);
        double sumScores = 0.0;
        for (int i = 0; i < scores.length; ++i) {
            scores[i] = Math.exp(scores[i]);
            sumScores += scores[i];
        }
        return sumScores;
    }

    /**
     * Returns {@code digamma(score + count) - digamma(score)}. For counts below
     * {@link #EXACT_DIGAMMA_DIFFERENCE_LIMIT} the finite sum
     * {@code sum_{i < count} 1 / (score + i)} is used, which equals the
     * difference when {@code count} is a non-negative integer.
     */
    private static double digammaDifference(double score, double count) {
        if (count < EXACT_DIGAMMA_DIFFERENCE_LIMIT) {
            double diff = 0.0;
            for (int i = 0; (double)i < count; ++i) {
                diff += 1.0 / (score + (double)i);
            }
            return diff;
        }
        return Dirichlet.digamma(score + count) - Dirichlet.digamma(score);
    }
}

