/*
 * Decompiled with CFR 0.152.
 */
package edu.nyu.jet.chunk;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import edu.nyu.jet.chunk.GISTrainer;
import opennlp.maxent.GISModel;
import opennlp.model.DataIndexer;
import opennlp.model.EvalParameters;
import opennlp.model.MutableContext;
import opennlp.model.Prior;

/**
 * Trains a maximum-entropy {@link GISModel} by optimizing an L2-regularized
 * conditional log-likelihood with MALLET's {@link LimitedMemoryBFGS}, instead of
 * the iterative scaling procedure of {@link GISTrainer}.
 *
 * <p>The objective handed to the optimizer is
 * {@code sum_events count * log p(outcome | context) - c/2 * ||w||^2},
 * and its gradient is {@code observed - expected - c * w}.
 */
public class LBFGSTrainer
extends GISTrainer {
    /** L2 regularization strength: the objective subtracts {@code c/2 * ||w||^2}. */
    private double c = 1.0;
    /** When true every predicate gets a parameter for every outcome. No setter in this file, so always false here. */
    private boolean useSimpleSmoothing = false;
    /** When true the slack (correction-feature) observed expectation is computed. Always false in this file. */
    private boolean useSlackParameter = false;
    private boolean useGaussianSmoothing = false;
    private double sigma = 2.0;
    /** Observed count assigned to unseen predicate/outcome pairs when {@link #useSimpleSmoothing} is on. */
    private double _smoothingObservation = 0.1;
    /** Controls whether {@link #display(String)} prints progress messages. */
    private boolean printMessages = false;
    private int numUniqueEvents;
    private int numPreds;
    private int numOutcomes;
    /** Predicate indices of each unique event. */
    private int[][] contexts;
    /** Optional real-valued feature weights parallel to {@link #contexts}; may be null. */
    private float[][] values;
    /** Outcome index of each unique event. */
    private int[] outcomeList;
    /** Multiplicity of each unique event. */
    private int[] numTimesEventsSeen;
    private int[] predicateCounts;
    private int cutoff;
    private String[] outcomeLabels;
    private String[] predLabels;
    /** Empirical feature expectations, one context per predicate. */
    private MutableContext[] observedExpects;
    /** The weights being optimized, one context per predicate. */
    private MutableContext[] params;
    /** Scratch accumulator for model feature expectations, filled during gradient evaluation. */
    private MutableContext[] modelExpects;
    private Prior prior;
    /** Log of the observed slack-feature count; written when {@link #useSlackParameter} is set, not read in this file. */
    private double cfObservedExpect;
    private double CFMOD;
    /** Stand-in for a zero slack count so that its logarithm stays finite. */
    private final double NEAR_ZERO = 0.01;
    private final double LLThreshold = 1.0E-4;

    /** Creates a trainer with the default regularization strength {@code c = 1.0}. */
    public LBFGSTrainer() {
    }

    /**
     * Creates a trainer with the given regularization strength.
     *
     * @param c L2 penalty coefficient; larger values shrink the weights more strongly
     */
    public LBFGSTrainer(double c) {
        this.c = c;
    }

    /**
     * Builds the parameter structures from an indexed training set and fits them with L-BFGS.
     *
     * @param iterations not used by this trainer; L-BFGS runs until MALLET's optimizer stops
     * @param di indexed training events
     * @param modelPrior prior combined with the model scores of each event
     * @param cutoff predicates with a count below this value receive no parameters
     *     (unless simple smoothing is enabled)
     * @return a model wrapping the optimized parameters
     */
    public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff) {
        this.display("Incorporating indexed data for training...  \n");
        this.contexts = di.getContexts();
        this.values = di.getValues();
        this.cutoff = cutoff;
        this.predicateCounts = di.getPredCounts();
        this.numTimesEventsSeen = di.getNumTimesEventsSeen();
        this.numUniqueEvents = this.contexts.length;
        this.prior = modelPrior;
        int correctionConstant = this.computeCorrectionConstant();
        this.display("done.\n");
        this.outcomeLabels = di.getOutcomeLabels();
        this.outcomeList = di.getOutcomeList();
        this.numOutcomes = this.outcomeLabels.length;
        this.predLabels = di.getPredLabels();
        this.prior.setLabels(this.outcomeLabels, this.predLabels);
        this.numPreds = this.predLabels.length;
        this.display("\tNumber of Event Tokens: " + this.numUniqueEvents + "\n");
        this.display("\t    Number of Outcomes: " + this.numOutcomes + "\n");
        this.display("\t  Number of Predicates: " + this.numPreds + "\n");
        float[][] predCount = this.countPredicateOutcomes();
        this.initializeParameters(predCount);
        if (this.useSlackParameter) {
            this.cfObservedExpect = this.computeSlackObservedExpectation(correctionConstant);
        }
        this.display("...done.\n");
        this.modelDistribution = new double[this.numOutcomes];
        this.numfeats = new int[this.numOutcomes];
        this.display("Computing model parameters...\n");
        this.findParameters();
        return new GISModel(this.params, this.predLabels, this.outcomeLabels, 1, this.evalParams.getCorrectionParam());
    }

    /**
     * Returns the largest total feature mass of any event, and at least 1.
     * The total feature mass is the feature count, or the ceiling of the summed values
     * for real-valued events.
     */
    private int computeCorrectionConstant() {
        int correctionConstant = 1;
        for (int ci = 0; ci < this.contexts.length; ++ci) {
            if (this.values == null || this.values[ci] == null) {
                correctionConstant = Math.max(correctionConstant, this.contexts[ci].length);
                continue;
            }
            // Summing from zero (rather than reading values[ci][0]) tolerates empty value arrays.
            float total = 0.0f;
            for (float v : this.values[ci]) {
                total += v;
            }
            if (total > (float) correctionConstant) {
                correctionConstant = (int) Math.ceil(total);
            }
        }
        return correctionConstant;
    }

    /** Accumulates the weighted count of each (predicate, outcome) pair over all events. */
    private float[][] countPredicateOutcomes() {
        float[][] predCount = new float[this.numPreds][this.numOutcomes];
        for (int ti = 0; ti < this.numUniqueEvents; ++ti) {
            int outcome = this.outcomeList[ti];
            float[] eventValues = this.values == null ? null : this.values[ti];
            for (int j = 0; j < this.contexts[ti].length; ++j) {
                // Binary features contribute weight 1; real-valued features contribute their value.
                float weight = eventValues == null ? 1.0f : eventValues[j];
                predCount[this.contexts[ti][j]][outcome] += (float) this.numTimesEventsSeen[ti] * weight;
            }
        }
        return predCount;
    }

    /**
     * Allocates {@link #params}, {@link #modelExpects} and {@link #observedExpects}.
     * Each predicate gets one slot per outcome it was observed with (or per outcome, under
     * simple smoothing). Predicates below {@link #cutoff} get no slots unless smoothing is on.
     */
    private void initializeParameters(float[][] predCount) {
        this.params = new MutableContext[this.numPreds];
        this.modelExpects = new MutableContext[this.numPreds];
        this.observedExpects = new MutableContext[this.numPreds];
        // Created before the array is filled; presumably EvalParameters keeps a reference
        // to params rather than copying it. NOTE(review): confirm against OpenNLP.
        this.evalParams = new EvalParameters(this.params, 0.0, 1.0, this.numOutcomes);
        int[] allOutcomesPattern = new int[this.numOutcomes];
        for (int oi = 0; oi < this.numOutcomes; ++oi) {
            allOutcomesPattern[oi] = oi;
        }
        int[] activeOutcomes = new int[this.numOutcomes];
        for (int pi = 0; pi < this.numPreds; ++pi) {
            int[] outcomePattern;
            if (this.useSimpleSmoothing) {
                outcomePattern = allOutcomesPattern;
            } else {
                int numActiveOutcomes = 0;
                if (this.predicateCounts[pi] >= this.cutoff) {
                    for (int oi = 0; oi < this.numOutcomes; ++oi) {
                        if (predCount[pi][oi] > 0.0f) {
                            activeOutcomes[numActiveOutcomes++] = oi;
                        }
                    }
                }
                if (numActiveOutcomes == this.numOutcomes) {
                    outcomePattern = allOutcomesPattern;
                } else {
                    outcomePattern = new int[numActiveOutcomes];
                    System.arraycopy(activeOutcomes, 0, outcomePattern, 0, numActiveOutcomes);
                }
            }
            int slots = outcomePattern.length;
            this.params[pi] = new MutableContext(outcomePattern, new double[slots]);
            this.modelExpects[pi] = new MutableContext(outcomePattern, new double[slots]);
            this.observedExpects[pi] = new MutableContext(outcomePattern, new double[slots]);
            for (int aoi = 0; aoi < slots; ++aoi) {
                int oi = outcomePattern[aoi];
                // Optimization starts from the zero weight vector.
                this.params[pi].setParameter(aoi, 0.0);
                this.modelExpects[pi].setParameter(aoi, 0.0);
                if (predCount[pi][oi] > 0.0f) {
                    this.observedExpects[pi].setParameter(aoi, predCount[pi][oi]);
                } else if (this.useSimpleSmoothing) {
                    this.observedExpects[pi].setParameter(aoi, this._smoothingObservation);
                }
            }
        }
    }

    /**
     * Returns the log of the observed count of the slack (correction) feature.
     * A zero count is mapped to {@link #NEAR_ZERO} so that the logarithm stays finite.
     */
    private double computeSlackObservedExpectation(int correctionConstant) {
        int cfvalSum = 0;
        for (int ti = 0; ti < this.numUniqueEvents; ++ti) {
            for (int pi : this.contexts[ti]) {
                if (!this.modelExpects[pi].contains(this.outcomeList[ti])) {
                    cfvalSum += this.numTimesEventsSeen[ti];
                }
            }
            cfvalSum += (correctionConstant - this.contexts[ti].length) * this.numTimesEventsSeen[ti];
        }
        return cfvalSum == 0 ? Math.log(NEAR_ZERO) : Math.log(cfvalSum);
    }

    /**
     * Runs L-BFGS over {@link #params} in place, then drops the training-only buffers.
     * Any exception from the optimizer is printed and treated as non-convergence.
     */
    private void findParameters() {
        MaxEntOptimization optimizable = new MaxEntOptimization();
        LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(optimizable);
        boolean converged = false;
        try {
            converged = optimizer.optimize();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        if (converged) {
            System.out.println("L-BFGS converged.");
        } else {
            System.out.println("L-BFGS stopped before convergence.");
        }
        this.observedExpects = null;
        this.modelExpects = null;
        this.numTimesEventsSeen = null;
        this.contexts = null;
    }

    /** Prints {@code s} to standard output when {@link #printMessages} is enabled. */
    private void display(String s) {
        if (this.printMessages) {
            System.out.print(s);
        }
    }

    /**
     * Adapts the trainer's per-predicate parameter arrays to MALLET's flat-vector
     * {@code Optimizable.ByGradientValue} interface. Reads and writes the enclosing
     * trainer's {@code params} directly.
     */
    class MaxEntOptimization
    implements Optimizable.ByGradientValue {
        /** Total number of scalar weights across all predicates. */
        int numOfParameters = this.countParameterLength();
        /** Maps flat index i to {predicate index, slot within that predicate's parameter array}. */
        int[][] indexMapping;

        public MaxEntOptimization() {
            this.buildIndexMapping();
        }

        /** Fills {@link #indexMapping} in the same order in which the params are flattened. */
        private void buildIndexMapping() {
            this.indexMapping = new int[this.numOfParameters][2];
            int i = 0;
            for (int j = 0; j < params.length; ++j) {
                int slots = params[j].getParameters().length;
                for (int k = 0; k < slots; ++k) {
                    this.indexMapping[i][0] = j;
                    this.indexMapping[i][1] = k;
                    ++i;
                }
            }
        }

        /** Returns {@code sum_i a[i] * b[i]} over the length of {@code a}. */
        private double innerProduct(double[] a, double[] b) {
            double result = 0.0;
            for (int i = 0; i < a.length; ++i) {
                // Accumulate; the previous version overwrote and kept only the last term.
                result += a[i] * b[i];
            }
            return result;
        }

        /** Counts the scalar weights across all predicates. */
        private int countParameterLength() {
            int total = 0;
            for (MutableContext param : params) {
                total += param.getParameters().length;
            }
            return total;
        }

        /** Writes the model's outcome distribution for event {@code ei} into {@code modelDistribution}. */
        private void computeModelDistribution(int ei) {
            if (values != null) {
                prior.logPrior(modelDistribution, contexts[ei], values[ei]);
                GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams);
            } else {
                prior.logPrior(modelDistribution, contexts[ei]);
                GISModel.eval(contexts[ei], modelDistribution, evalParams);
            }
        }

        /**
         * Returns the penalized log-likelihood {@code sum count * log p(y|x) - c/2 * ||w||^2}.
         * The penalty covers every weight, matching the {@code -c * w} term in
         * {@link #getValueGradient(double[])}.
         */
        public double getValue() {
            double squaredNorm = 0.0;
            for (MutableContext param : params) {
                double[] weights = param.getParameters();
                squaredNorm += this.innerProduct(weights, weights);
            }
            double loglikelihood = -c / 2.0 * squaredNorm;
            for (int ei = 0; ei < numUniqueEvents; ++ei) {
                this.computeModelDistribution(ei);
                loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * (double) numTimesEventsSeen[ei];
            }
            display(".");
            return loglikelihood;
        }

        /**
         * Writes {@code observed - expected - c * w} into {@code gradient}, in the same
         * flattened order used by {@link #getParameters(double[])}.
         */
        public void getValueGradient(double[] gradient) {
            // Reset the accumulators up front so an aborted earlier call cannot leak into this one.
            for (MutableContext expect : modelExpects) {
                int slots = expect.getParameters().length;
                for (int k = 0; k < slots; ++k) {
                    expect.setParameter(k, 0.0);
                }
            }
            int i = 0;
            for (MutableContext param : params) {
                for (double w : param.getParameters()) {
                    gradient[i++] = -c * w;
                }
            }
            for (int ei = 0; ei < numUniqueEvents; ++ei) {
                this.computeModelDistribution(ei);
                float[] eventValues = values == null ? null : values[ei];
                double count = numTimesEventsSeen[ei];
                for (int j = 0; j < contexts[ei].length; ++j) {
                    int pi = contexts[ei][j];
                    if (predicateCounts[pi] < cutoff) {
                        continue;
                    }
                    int[] activeOutcomes = modelExpects[pi].getOutcomes();
                    for (int aoi = 0; aoi < activeOutcomes.length; ++aoi) {
                        double p = modelDistribution[activeOutcomes[aoi]];
                        double delta = eventValues != null
                                ? p * (double) eventValues[j] * count
                                : p * count;
                        modelExpects[pi].updateParameter(aoi, delta);
                    }
                }
            }
            i = 0;
            for (int pi = 0; pi < params.length; ++pi) {
                double[] expected = modelExpects[pi].getParameters();
                double[] observed = observedExpects[pi].getParameters();
                if (expected.length != observed.length) {
                    System.err.println("Length of modelExpects and observedExpects should equal.");
                    return;
                }
                for (int j = 0; j < expected.length; ++j) {
                    gradient[i] = gradient[i] + observed[j] - expected[j];
                    ++i;
                }
            }
        }

        public int getNumParameters() {
            return this.numOfParameters;
        }

        public double getParameter(int i) {
            int j = this.indexMapping[i][0];
            int k = this.indexMapping[i][1];
            return params[j].getParameters()[k];
        }

        /** Copies all weights into {@code buffer}, predicate by predicate. */
        public void getParameters(double[] buffer) {
            int i = 0;
            for (MutableContext param : params) {
                for (double w : param.getParameters()) {
                    buffer[i++] = w;
                }
            }
        }

        public void setParameter(int i, double r) {
            int j = this.indexMapping[i][0];
            int k = this.indexMapping[i][1];
            params[j].getParameters()[k] = r;
        }

        /** Overwrites all weights from {@code newParameters}, in the flattened order of {@link #getParameters(double[])}. */
        public void setParameters(double[] newParameters) {
            int i = 0;
            for (MutableContext param : params) {
                double[] weights = param.getParameters();
                for (int j = 0; j < weights.length; ++j) {
                    weights[j] = newParameters[i++];
                }
            }
        }
    }
}

