/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.PRAuxClassifier;
import cc.mallet.classify.constraints.pr.MaxEntPRConstraint;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;

/**
 * Gradient-based {@link Optimizable.ByGradientValue} wrapper around the parameters of a
 * {@link PRAuxClassifier}.
 *
 * <p>The classifier stores its parameters as a ragged {@code double[][]}; the rows are
 * paired positionally with the classifier's constraints (row {@code ci} is handed to the
 * {@code ci}-th constraint). This class exposes those parameters to an optimizer as one
 * flat vector formed by concatenating the rows in order, and caches the most recently
 * computed value and gradient until a parameter is changed.
 *
 * <p>The parameter matrix is shared with the classifier (not copied), so setting parameters
 * here mutates the classifier's parameters directly.
 */
public class PRAuxClassifierOptimizable
implements Optimizable.ByGradientValue {
    private static final Logger logger = MalletLogger.getLogger(PRAuxClassifierOptimizable.class.getName());

    /** True when the parameters changed since the cached value/gradient were computed. */
    private boolean cacheStale;
    /** Total number of entries over all rows of {@link #parameters}. */
    private final int numParameters;
    private double cachedValue;
    private final double[] cachedGradient;
    /** Ragged parameter matrix obtained from (and shared with) the classifier. */
    private final double[][] parameters;
    /** Optional per-instance, per-label base distribution; may be {@code null}. */
    private final double[][] baseDist;
    private final PRAuxClassifier classifier;
    private final ArrayList<MaxEntPRConstraint> constraints;
    private final InstanceList trainingData;

    /**
     * Creates an optimizable view over the parameters of {@code classifier}.
     *
     * @param trainingData     instances over which value and gradient are computed
     * @param baseDistribution indexed as {@code [instance][label]}; when non-null, the log of
     *                         each entry is added to the corresponding classifier score, and a
     *                         zero entry forces that label's score to negative infinity.
     *                         May be {@code null}, in which case scores are used unchanged.
     * @param classifier       supplies the parameter matrix, the constraints and the scores
     */
    public PRAuxClassifierOptimizable(InstanceList trainingData, double[][] baseDistribution, PRAuxClassifier classifier) {
        this.trainingData = trainingData;
        this.baseDist = baseDistribution;
        this.classifier = classifier;
        this.parameters = classifier.getParameters();
        this.constraints = classifier.getConstraintFeatures();
        int total = 0;
        for (double[] row : this.parameters) {
            total += row.length;
        }
        this.numParameters = total;
        this.cachedValue = Double.NEGATIVE_INFINITY;
        this.cachedGradient = new double[this.numParameters];
        this.cacheStale = true;
    }

    /** Returns the length of the flattened parameter vector. */
    @Override
    public int getNumParameters() {
        return this.numParameters;
    }

    /**
     * Copies all parameters into {@code buffer}, row after row.
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} entries
     */
    @Override
    public void getParameters(double[] buffer) {
        int offset = 0;
        for (double[] row : this.parameters) {
            System.arraycopy(row, 0, buffer, offset, row.length);
            offset += row.length;
        }
    }

    /**
     * Returns the parameter at position {@code index} of the flattened parameter vector.
     *
     * @param index position in the row-concatenated parameter vector
     * @return the parameter value at that position
     * @throws RuntimeException if {@code index} is negative or not less than
     *                          {@link #getNumParameters()}
     */
    @Override
    public double getParameter(int index) {
        // Walk the rows, converting the flat index into an offset within the current row.
        int offset = index;
        if (offset >= 0) {
            for (double[] row : this.parameters) {
                if (offset < row.length) {
                    return row[offset];
                }
                offset -= row.length;
            }
        }
        throw new RuntimeException(index + " out of bounds.");
    }

    /**
     * Overwrites all parameters from {@code params}, row after row, and invalidates the cache.
     *
     * @param params source; must hold at least {@link #getNumParameters()} entries
     */
    @Override
    public void setParameters(double[] params) {
        int offset = 0;
        for (double[] row : this.parameters) {
            System.arraycopy(params, offset, row, 0, row.length);
            offset += row.length;
        }
        this.cacheStale = true;
    }

    /**
     * Sets the parameter at position {@code index} of the flattened parameter vector and
     * invalidates the cache.
     *
     * @param index position in the row-concatenated parameter vector
     * @param value new parameter value
     * @throws RuntimeException if {@code index} is negative or not less than
     *                          {@link #getNumParameters()}
     */
    @Override
    public void setParameter(int index, double value) {
        int offset = index;
        if (offset >= 0) {
            for (double[] row : this.parameters) {
                if (offset < row.length) {
                    row[offset] = value;
                    this.cacheStale = true;
                    // Stop here: exactly one slot corresponds to a flat index.
                    return;
                }
                offset -= row.length;
            }
        }
        throw new RuntimeException(index + " out of bounds.");
    }

    /**
     * Computes the objective value and writes its gradient into {@code gradient}.
     *
     * <p>For every training instance the classifier scores are (optionally) combined with the
     * log base distribution, their log-sum {@code logZ} is computed, and
     * {@code instanceWeight * logZ} is subtracted from the running value. Unless the running
     * value has become NaN or infinite, the scores are then exp-normalized and passed to every
     * constraint's {@code incrementExpectations}. Finally each constraint adds its auxiliary
     * value contribution and writes its gradient into the slice of {@code gradient} that
     * matches its parameter row.
     *
     * <p>This method does not consult or update the cache; it resets the classifier's
     * expectations on every call.
     *
     * @param gradient output buffer; zeroed first, must hold at least
     *                 {@link #getNumParameters()} entries
     * @return the objective value
     */
    public double getValueAndGradient(double[] gradient) {
        Arrays.fill(gradient, 0.0);
        this.classifier.zeroExpectations();
        int numLabels = this.trainingData.getTargetAlphabet().size();
        double value = 0.0;
        for (int ii = 0; ii < this.trainingData.size(); ii++) {
            double[] scores = new double[numLabels];
            Instance instance = (Instance) this.trainingData.get(ii);
            FeatureVector input = (FeatureVector) instance.getData();
            double instanceWeight = this.trainingData.getInstanceWeight(ii);
            this.classifier.getClassificationScores(instance, scores);

            double logZ = Double.NEGATIVE_INFINITY;
            for (int li = 0; li < numLabels; li++) {
                if (this.baseDist != null) {
                    double base = this.baseDist[ii][li];
                    if (base == 0.0) {
                        // A label with zero base probability is excluded outright.
                        scores[li] = Double.NEGATIVE_INFINITY;
                    } else {
                        scores[li] += Math.log(base);
                    }
                }
                logZ = Maths.sumLogProb(logZ, scores[li]);
            }
            assert (!Double.isNaN(logZ));

            // The contribution is accumulated before the checks below, so a NaN/infinite
            // running value persists and causes all later instances to be skipped as well.
            value -= instanceWeight * logZ;
            if (Double.isNaN(value)) {
                logger.warning("Instance " + instance.getName() + " has NaN value.");
                continue;
            }
            if (Double.isInfinite(value)) {
                logger.warning("Instance " + instance.getName() + " has infinite value; skipping value and gradient");
                continue;
            }

            MatrixOps.expNormalize(scores);
            for (MaxEntPRConstraint constraint : this.constraints) {
                // NOTE(review): expectations are incremented with weight 1.0, not
                // instanceWeight, although the value term above is weighted -- confirm intended.
                constraint.incrementExpectations(input, scores, 1.0);
            }
        }

        // Constraint ci owns parameter row ci and the matching contiguous gradient slice.
        int ci = 0;
        int offset = 0;
        for (MaxEntPRConstraint constraint : this.constraints) {
            double[] row = this.parameters[ci];
            double[] rowGradient = new double[row.length];
            value += constraint.getAuxiliaryValueContribution(row);
            constraint.getGradient(row, rowGradient);
            System.arraycopy(rowGradient, 0, gradient, offset, rowGradient.length);
            offset += rowGradient.length;
            ++ci;
        }
        logger.info("PR auxiliary value = " + value);
        return value;
    }

    /** Returns the objective value, recomputing value and gradient only if the cache is stale. */
    @Override
    public double getValue() {
        this.refreshCacheIfStale();
        return this.cachedValue;
    }

    /**
     * Copies the (possibly cached) gradient into {@code gradient}.
     *
     * @param gradient destination; {@code gradient.length} entries are copied, so it must not
     *                 be longer than {@link #getNumParameters()}
     */
    @Override
    public void getValueGradient(double[] gradient) {
        this.refreshCacheIfStale();
        System.arraycopy(this.cachedGradient, 0, gradient, 0, gradient.length);
    }

    /** Recomputes the cached value and gradient when parameters changed since the last evaluation. */
    private void refreshCacheIfStale() {
        if (this.cacheStale) {
            this.cachedValue = this.getValueAndGradient(this.cachedGradient);
            this.cacheStale = false;
        }
    }
}

