/*
 * Decompiled with CFR 0.152.
 */
package nak.quasinewton;

import java.util.ArrayList;
import java.util.Arrays;
import nak.data.DataIndexer;
import nak.data.OnePassRealValueDataIndexer;
import nak.quasinewton.DifferentiableFunction;

/**
 * Log-likelihood objective of a conditional maximum-entropy (multinomial logistic) model over
 * the events supplied by a {@link DataIndexer}.
 *
 * <p>The parameter vector is flat with length {@code numOutcomes * numFeatures}. The weight of
 * feature {@code f} for outcome {@code o} is stored at index {@code o * numFeatures + f}. For
 * context {@code i}, the model is
 * {@code P(o | c_i) = exp(s_o) / sum_o' exp(s_o')} with {@code s_o = sum_j v_ij * w[o, f_ij]}.
 * The feature value {@code v_ij} is 1.0 unless the indexer is a
 * {@link OnePassRealValueDataIndexer}, which supplies real values.
 *
 * <p>Sign convention (kept from the original implementation):
 * <ul>
 *   <li>{@link #valueAt} returns the log-likelihood {@code sum_i n_i * log P(y_i | c_i)}.</li>
 *   <li>{@link #gradientAt} returns {@code expectedCount - empiricalCount}, which is the gradient
 *       of the <em>negative</em> log-likelihood.</li>
 * </ul>
 * NOTE(review): the consuming optimizer presumably depends on this pairing. Confirm against the
 * line search before changing either sign.
 *
 * <p>The most recent evaluation point is cached, so {@code valueAt} and {@code gradientAt} at the
 * same point compute the model only once. Instances hold mutable, unsynchronized evaluation state.
 */
public class LogLikelihoodFunction implements DifferentiableFunction {
    private final int domainDimension;
    private double value;
    private double[] gradient;
    private double[] lastX;
    /** Observed (weighted) feature counts per (outcome, feature); built lazily. */
    private double[] empiricalCount;
    private final int numOutcomes;
    private final int numFeatures;
    private final int numContexts;
    /** probModel[i][o] = P(o | context i) at the last evaluated point. */
    private final double[][] probModel;
    private final String[] outcomeLabels;
    private final String[] predLabels;
    /** outcomePatterns[f] = outcomes whose empirical count for feature f is positive. */
    private int[][] outcomePatterns;
    /** Real-valued feature values, or null when every active feature has value 1.0. */
    private final float[][] values;
    private final int[][] contexts;
    private final int[] outcomeList;
    private final int[] numTimesEventsSeen;

    /**
     * Captures the training events from {@code dataIndexer}.
     *
     * @param dataIndexer source of contexts, outcomes, event counts and labels. Real feature
     *                    values are only read when it is a {@link OnePassRealValueDataIndexer}.
     */
    public LogLikelihoodFunction(DataIndexer dataIndexer) {
        this.values = dataIndexer instanceof OnePassRealValueDataIndexer ? dataIndexer.getValues() : null;
        this.contexts = dataIndexer.getContexts();
        this.outcomeList = dataIndexer.getOutcomeList();
        this.numTimesEventsSeen = dataIndexer.getNumTimesEventsSeen();
        this.outcomeLabels = dataIndexer.getOutcomeLabels();
        this.predLabels = dataIndexer.getPredLabels();
        this.numOutcomes = this.outcomeLabels.length;
        this.numFeatures = this.predLabels.length;
        this.numContexts = this.contexts.length;
        this.domainDimension = this.numOutcomes * this.numFeatures;
        this.probModel = new double[this.numContexts][this.numOutcomes];
        this.gradient = null;
    }

    /**
     * Returns the log-likelihood at {@code x}.
     *
     * @throws IllegalArgumentException if {@code x.length != getDomainDimension()}
     */
    @Override
    public double valueAt(double[] x) {
        if (!this.checkLastX(x)) {
            this.calculate(x);
        }
        return this.value;
    }

    /**
     * Returns {@code expectedCount - empiricalCount} at {@code x}, the gradient of the negative
     * log-likelihood. The returned array is the cached instance; callers should not modify it.
     *
     * @throws IllegalArgumentException if {@code x.length != getDomainDimension()}
     */
    @Override
    public double[] gradientAt(double[] x) {
        if (!this.checkLastX(x)) {
            this.calculate(x);
        }
        return this.gradient;
    }

    @Override
    public int getDomainDimension() {
        return this.domainDimension;
    }

    /** Returns a fresh all-zero parameter vector of length {@link #getDomainDimension()}. */
    public double[] getInitialPoint() {
        return new double[this.domainDimension];
    }

    public String[] getPredLabels() {
        return this.predLabels;
    }

    public String[] getOutcomeLabels() {
        return this.outcomeLabels;
    }

    /**
     * Returns, for each feature, the outcomes with a positive empirical count for that feature.
     * The counts are computed on first use, so this no longer returns null before the first
     * evaluation.
     */
    public int[][] getOutcomePatterns() {
        if (this.outcomePatterns == null) {
            this.initEmpCount();
        }
        return this.outcomePatterns;
    }

    /** Evaluates value and gradient at {@code x} and caches them together with a copy of x. */
    private void calculate(double[] x) {
        if (x.length != this.domainDimension) {
            throw new IllegalArgumentException("x is invalid, its dimension is not equal to the function.");
        }
        if (this.empiricalCount == null) {
            this.initEmpCount();
        }
        this.value = this.computeLogLikelihood(x);
        this.gradient = this.computeGradient();
        this.lastX = x.clone();
    }

    /** Fills probModel with the softmax P(o | c_i) under x and returns the log-likelihood. */
    private double computeLogLikelihood(double[] x) {
        double logLikelihood = 0.0;
        for (int i = 0; i < this.numContexts; ++i) {
            double[] probs = this.probModel[i];
            // Score every outcome. The normalizer, and the expected counts used by the
            // gradient, depend on the weights of all outcomes, not just the observed one.
            double maxScore = Double.NEGATIVE_INFINITY;
            for (int o = 0; o < this.numOutcomes; ++o) {
                probs[o] = this.score(i, o, x);
                maxScore = Math.max(maxScore, probs[o]);
            }
            double observedScore = probs[this.outcomeList[i]];
            // Subtract the max before exponentiating so large scores cannot overflow exp().
            // The max term contributes exp(0) = 1, so sum >= 1 and log(sum) is finite.
            double sum = 0.0;
            for (int o = 0; o < this.numOutcomes; ++o) {
                probs[o] = Math.exp(probs[o] - maxScore);
                sum += probs[o];
            }
            for (int o = 0; o < this.numOutcomes; ++o) {
                probs[o] /= sum;
            }
            int seen = this.numTimesEventsSeen[i];
            if (seen != 0) {
                double logProbObserved = observedScore - maxScore - Math.log(sum);
                logLikelihood += seen * logProbObserved;
            }
        }
        return logLikelihood;
    }

    /** Returns expected counts under probModel minus the empirical counts. */
    private double[] computeGradient() {
        double[] grad = new double[this.domainDimension];
        for (int i = 0; i < this.numContexts; ++i) {
            int[] features = this.contexts[i];
            int seen = this.numTimesEventsSeen[i];
            for (int o = 0; o < this.numOutcomes; ++o) {
                double weight = this.probModel[i][o] * seen;
                for (int j = 0; j < features.length; ++j) {
                    double v = this.featureValue(i, j);
                    if (v == 0.0) {
                        continue;
                    }
                    grad[this.indexOf(o, features[j])] += v * weight;
                }
            }
        }
        for (int k = 0; k < grad.length; ++k) {
            grad[k] -= this.empiricalCount[k];
        }
        return grad;
    }

    /** Returns the linear score of {@code outcome} for context {@code i} under weights x. */
    private double score(int i, int outcome, double[] x) {
        int[] features = this.contexts[i];
        double s = 0.0;
        for (int j = 0; j < features.length; ++j) {
            double v = this.featureValue(i, j);
            // Skip zero-valued features so 0 * inf/NaN weights cannot poison the score.
            if (v == 0.0) {
                continue;
            }
            s += v * x[this.indexOf(outcome, features[j])];
        }
        return s;
    }

    /** Returns the value of the j-th active feature of context i (1.0 when not real-valued). */
    private double featureValue(int i, int j) {
        return this.values == null ? 1.0 : this.values[i][j];
    }

    /**
     * Reports whether x is bitwise equal to the last evaluated point, including length.
     * A shorter or longer x therefore never hits the cache.
     */
    private boolean checkLastX(double[] x) {
        return this.lastX != null && Arrays.equals(this.lastX, x);
    }

    /** Maps (outcome, feature) to the flat parameter index. */
    private int indexOf(int outcome, int feature) {
        return outcome * this.numFeatures + feature;
    }

    /** Computes empiricalCount and the per-feature outcome patterns derived from it. */
    private void initEmpCount() {
        this.empiricalCount = new double[this.domainDimension];
        for (int i = 0; i < this.numContexts; ++i) {
            int[] features = this.contexts[i];
            int observed = this.outcomeList[i];
            int seen = this.numTimesEventsSeen[i];
            for (int j = 0; j < features.length; ++j) {
                // Multiply in double; the original multiplied in float before widening.
                this.empiricalCount[this.indexOf(observed, features[j])] += this.featureValue(i, j) * seen;
            }
        }
        this.outcomePatterns = new int[this.numFeatures][];
        int[] buffer = new int[this.numOutcomes];
        for (int f = 0; f < this.numFeatures; ++f) {
            int count = 0;
            for (int o = 0; o < this.numOutcomes; ++o) {
                if (this.empiricalCount[this.indexOf(o, f)] > 0.0) {
                    buffer[count++] = o;
                }
            }
            this.outcomePatterns[f] = Arrays.copyOf(buffer, count);
        }
    }
}

