/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.classifier.maxent;

import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import opennlp.maxent.GISModel;
import opennlp.model.Context;
import opennlp.model.DataIndexer;
import opennlp.model.EvalParameters;
import opennlp.model.MutableContext;
import opennlp.model.Prior;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains a {@link GISModel} from the events exposed by a {@link DataIndexer}.
 *
 * <p>Training repeatedly computes, for every (predicate, outcome) pair, the expectation of the
 * feature under the current model and compares it with the expectation observed in the
 * training data; the parameters are nudged until the log-likelihood stops improving or the
 * iteration budget is exhausted. The per-event work of every iteration is split across a
 * configurable number of worker threads, each accumulating into its own expectation table.
 *
 * <p>An instance keeps the training state in fields, so a single instance must not run two
 * trainings concurrently.
 */
class GISTrainer {
    private static final Logger LOG = LoggerFactory.getLogger(GISTrainer.class);

    /** Training stops once the log-likelihood gain between two iterations drops below this. */
    private static final double LL_THRESHOLD = 1.0E-4;
    /** Upper bound on Newton steps taken by {@link #gaussianUpdate}. */
    private static final int NEWTON_MAX_ITERATIONS = 50;
    /** Newton iteration stops once two successive estimates differ by less than this. */
    private static final double NEWTON_TOLERANCE = 1.0E-6;

    private boolean useSimpleSmoothing = false;
    private boolean useGaussianSmoothing = false;
    private double sigma = 2.0;
    private double smoothingObservation = 0.1;
    private int numUniqueEvents;
    /** Per event: indices (into {@link #featNames}) of the predicates active in that event. */
    private int[][] trainingDataFeatNameIndices;
    /** Per event: feature values parallel to the predicate indices; may be null (whole or per row). */
    private float[][] trainingDataFeatValues;
    private int[] outcomeList;
    private int[] numTimesEventsSeen;
    private int[] predicateCounts;
    private int cutoff;
    private String[] labels;
    private String[] featNames;
    private MutableContext[] observedExpects;
    private MutableContext[] params;
    /** One expectation table per worker thread; merged into table 0 after every iteration. */
    private MutableContext[][] modelExpects;
    private Prior prior;
    private EvalParameters evalParams;

    GISTrainer() {
    }

    /**
     * Enables or disables simple smoothing. When enabled, every predicate is given a parameter
     * for every outcome, and unseen pairs receive the smoothing observation as their observed
     * count.
     *
     * @param smooth {@code true} to enable simple smoothing
     */
    public void setSmoothing(boolean smooth) {
        this.useSimpleSmoothing = smooth;
    }

    /**
     * Sets the observed count assigned to unseen (predicate, outcome) pairs when simple
     * smoothing is enabled.
     *
     * @param timesSeen the pseudo-count to use
     */
    public void setSmoothingObservation(double timesSeen) {
        this.smoothingObservation = timesSeen;
    }

    /**
     * Enables Gaussian smoothing of the parameter updates and sets its sigma.
     *
     * @param sigmaValue the sigma used by the Gaussian update
     */
    public void setGaussianSigma(double sigmaValue) {
        this.useGaussianSmoothing = true;
        this.sigma = sigmaValue;
    }

    /**
     * Trains a model on the indexed data.
     *
     * @param iterations maximum number of training iterations
     * @param di         source of the indexed training events
     * @param modelPrior prior used to initialise the outcome distribution of every event
     * @param cutoff     predicates whose count is below this value get no parameters
     * @param threads    number of worker threads; values {@code <= 0} are treated as 1
     * @return the trained model
     */
    public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff, int threads) {
        int threadCount = threads <= 0 ? 1 : threads;
        this.modelExpects = new MutableContext[threadCount][];

        LOG.debug("Incorporating indexed data for training...");
        this.trainingDataFeatNameIndices = di.getContexts();
        this.trainingDataFeatValues = di.getValues();
        this.cutoff = cutoff;
        this.predicateCounts = di.getPredCounts();
        this.numTimesEventsSeen = di.getNumTimesEventsSeen();
        this.numUniqueEvents = this.trainingDataFeatNameIndices.length;
        this.prior = modelPrior;
        this.labels = di.getOutcomeLabels();
        this.outcomeList = di.getOutcomeList();
        this.featNames = di.getPredLabels();
        this.prior.setLabels(this.labels, this.featNames);

        double correctionConstant = computeCorrectionConstant();
        LOG.debug("Number of Event Tokens: {}", this.numUniqueEvents);
        LOG.debug("Number of Outcomes: {}", this.labels.length);
        LOG.debug("Number of Predicates: {}", this.featNames.length);

        initializeContexts(countFeatureOutcomes());

        LOG.debug("Computing model parameters in {} threads...", threadCount);
        findParameters(iterations, correctionConstant);
        return new GISModel(this.params, this.featNames, this.labels, 1, this.evalParams.getCorrectionParam());
    }

    /** Returns the feature values of one event, or {@code null} when the event has none. */
    private float[] valuesFor(int eventIndex) {
        return this.trainingDataFeatValues == null ? null : this.trainingDataFeatValues[eventIndex];
    }

    /**
     * Computes the correction constant: the largest per-event feature mass, i.e. the sum of the
     * feature values of an event, or its number of predicates when it carries no values.
     */
    private double computeCorrectionConstant() {
        double correctionConstant = 0.0;
        for (int ei = 0; ei < this.trainingDataFeatNameIndices.length; ++ei) {
            float[] values = valuesFor(ei);
            double mass;
            if (values == null) {
                mass = this.trainingDataFeatNameIndices[ei].length;
            } else {
                // Summing from zero (rather than seeding with values[0]) keeps an event with an
                // empty value array from throwing.
                float total = 0.0f;
                for (float value : values) {
                    total += value;
                }
                mass = total;
            }
            if (mass > correctionConstant) {
                correctionConstant = mass;
            }
        }
        return correctionConstant;
    }

    /** Accumulates, per predicate and outcome, the weighted number of times they co-occur. */
    private float[][] countFeatureOutcomes() {
        float[][] featCount = new float[this.featNames.length][this.labels.length];
        for (int ei = 0; ei < this.numUniqueEvents; ++ei) {
            int[] context = this.trainingDataFeatNameIndices[ei];
            float[] values = valuesFor(ei);
            for (int j = 0; j < context.length; ++j) {
                float weight = (float) this.numTimesEventsSeen[ei];
                if (values != null) {
                    weight *= values[j];
                }
                featCount[context[j]][this.outcomeList[ei]] += weight;
            }
        }
        return featCount;
    }

    /**
     * Allocates the parameter, model-expectation and observed-expectation tables and fills the
     * observed expectations from the given co-occurrence counts.
     */
    private void initializeContexts(float[][] featCount) {
        int numFeats = this.featNames.length;
        int numLabels = this.labels.length;

        this.params = new MutableContext[numFeats];
        for (int t = 0; t < this.modelExpects.length; ++t) {
            this.modelExpects[t] = new MutableContext[numFeats];
        }
        this.observedExpects = new MutableContext[numFeats];
        // evalParams keeps a reference to the params array, which is populated below.
        this.evalParams = new EvalParameters(this.params, 0.0, 1.0, numLabels);

        // Shared pattern used whenever a predicate is active for every outcome.
        int[] allOutcomes = new int[numLabels];
        for (int oi = 0; oi < numLabels; ++oi) {
            allOutcomes[oi] = oi;
        }
        int[] activeOutcomes = new int[numLabels];

        for (int pi = 0; pi < numFeats; ++pi) {
            int numActiveOutcomes = 0;
            int[] outcomePattern;
            if (this.useSimpleSmoothing) {
                numActiveOutcomes = numLabels;
                outcomePattern = allOutcomes;
            } else {
                for (int oi = 0; oi < numLabels; ++oi) {
                    if (featCount[pi][oi] > 0.0f && this.predicateCounts[pi] >= this.cutoff) {
                        activeOutcomes[numActiveOutcomes++] = oi;
                    }
                }
                if (numActiveOutcomes == numLabels) {
                    outcomePattern = allOutcomes;
                } else {
                    outcomePattern = new int[numActiveOutcomes];
                    System.arraycopy(activeOutcomes, 0, outcomePattern, 0, numActiveOutcomes);
                }
            }

            // Freshly allocated double arrays are already zero, so no explicit reset is needed.
            this.params[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
            for (MutableContext[] modelExpect : this.modelExpects) {
                modelExpect[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
            }
            this.observedExpects[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);

            for (int aoi = 0; aoi < numActiveOutcomes; ++aoi) {
                float count = featCount[pi][outcomePattern[aoi]];
                if (count > 0.0f) {
                    this.observedExpects[pi].setParameter(aoi, (double) count);
                } else if (this.useSimpleSmoothing) {
                    this.observedExpects[pi].setParameter(aoi, this.smoothingObservation);
                }
            }
        }
    }

    /**
     * Runs training iterations until the iteration budget is used up, the log-likelihood gain
     * falls below {@link #LL_THRESHOLD}, or the log-likelihood decreases. Afterwards the
     * training-only tables are released.
     */
    private void findParameters(int iterations, double correctionConstant) {
        LOG.info("Performing max {} iterations.", iterations);
        // One pool for the whole training run; shut down in finally so its threads never
        // outlive a failed iteration.
        ExecutorService executor = Executors.newFixedThreadPool(this.modelExpects.length);
        try {
            double prevLL = 0.0;
            for (int i = 1; i <= iterations; ++i) {
                LOG.debug("Iteration {}", i);
                double currLL = nextIteration(executor, correctionConstant);
                if (i > 1) {
                    if (prevLL > currLL) {
                        LOG.error("Model Diverging: log likelihood decreased");
                        break;
                    }
                    if (currLL - prevLL < LL_THRESHOLD) {
                        break;
                    }
                }
                prevLL = currLL;
            }
        } finally {
            executor.shutdown();
        }
        this.observedExpects = null;
        this.modelExpects = null;
        this.numTimesEventsSeen = null;
        this.trainingDataFeatNameIndices = null;
    }

    /**
     * Solves for the parameter increment of one (predicate, active outcome) pair under Gaussian
     * smoothing, using Newton's method started at zero.
     *
     * @param predicate          predicate index
     * @param oid                index into the predicate's active outcomes
     * @param correctionConstant the correction constant of the training data
     * @return the increment to add to the current parameter
     */
    private double gaussianUpdate(int predicate, int oid, double correctionConstant) {
        double param = this.params[predicate].getParameters()[oid];
        double modelValue = this.modelExpects[0][predicate].getParameters()[oid];
        double observedValue = this.observedExpects[predicate].getParameters()[oid];
        double x0 = 0.0;
        for (int i = 0; i < NEWTON_MAX_ITERATIONS; ++i) {
            double tmp = modelValue * Math.exp(correctionConstant * x0);
            double f = tmp + (param + x0) / this.sigma - observedValue;
            double fp = tmp * correctionConstant + 1.0 / this.sigma;
            if (fp == 0.0) {
                // Flat derivative: a further Newton step is undefined, keep the current estimate.
                break;
            }
            double x = x0 - f / fp;
            boolean converged = Math.abs(x - x0) < NEWTON_TOLERANCE;
            x0 = x;
            if (converged) {
                break;
            }
        }
        return x0;
    }

    /**
     * Performs one training iteration: computes the model expectations in parallel, merges the
     * per-thread tables, updates the parameters and clears the tables for the next round.
     *
     * @return the log-likelihood of the training data under the parameters used in this iteration
     * @throws IllegalStateException if the calling thread is interrupted while waiting
     * @throws RuntimeException      if a worker task fails
     */
    private double nextIteration(ExecutorService executor, double correctionConstant) {
        int numberOfThreads = this.modelExpects.length;
        int taskSize = this.numUniqueEvents / numberOfThreads;
        int leftOver = this.numUniqueEvents % numberOfThreads;

        ArrayList<Future<ModelExpectationComputeTask>> futures =
                new ArrayList<Future<ModelExpectationComputeTask>>(numberOfThreads);
        for (int i = 0; i < numberOfThreads; ++i) {
            // The last task also takes the events left over by the integer division.
            int length = i == numberOfThreads - 1 ? taskSize + leftOver : taskSize;
            futures.add(executor.submit(new ModelExpectationComputeTask(i, i * taskSize, length)));
        }

        double loglikelihood = 0.0;
        int numEvents = 0;
        int numCorrect = 0;
        for (Future<ModelExpectationComputeTask> future : futures) {
            ModelExpectationComputeTask finishedTask;
            try {
                finishedTask = future.get();
            } catch (InterruptedException e) {
                // Restore the flag so callers further up can still observe the interruption.
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interruption is not supported!", e);
            } catch (ExecutionException e) {
                throw new RuntimeException("Exception during training: " + e.getMessage(), e);
            }
            numEvents += finishedTask.getNumEvents();
            numCorrect += finishedTask.getNumCorrect();
            loglikelihood += finishedTask.getLoglikelihood();
        }

        // Fold the expectations of threads 1..n-1 into the table of thread 0.
        for (int pi = 0; pi < this.featNames.length; ++pi) {
            int numActive = this.params[pi].getOutcomes().length;
            for (int aoi = 0; aoi < numActive; ++aoi) {
                for (int t = 1; t < this.modelExpects.length; ++t) {
                    this.modelExpects[0][pi].updateParameter(aoi, this.modelExpects[t][pi].getParameters()[aoi]);
                }
            }
        }

        for (int pi = 0; pi < this.featNames.length; ++pi) {
            double[] observed = this.observedExpects[pi].getParameters();
            double[] model = this.modelExpects[0][pi].getParameters();
            int[] activeOutcomes = this.params[pi].getOutcomes();
            for (int aoi = 0; aoi < activeOutcomes.length; ++aoi) {
                if (this.useGaussianSmoothing) {
                    this.params[pi].updateParameter(aoi, gaussianUpdate(pi, aoi, correctionConstant));
                } else {
                    if (model[aoi] == 0.0) {
                        // aoi indexes the active-outcome pattern; map it back to the real label.
                        LOG.error("Model expects == 0 for {} {}", this.featNames[pi], this.labels[activeOutcomes[aoi]]);
                    }
                    this.params[pi].updateParameter(aoi,
                            (Math.log(observed[aoi]) - Math.log(model[aoi])) / correctionConstant);
                }
                // Clear every thread's accumulator for the next iteration.
                for (MutableContext[] modelExpect : this.modelExpects) {
                    modelExpect[pi].setParameter(aoi, 0.0);
                }
            }
        }

        LOG.debug("loglikelihood = {}\taccuracy: {}", loglikelihood, (double) numCorrect / (double) numEvents * 100.0);
        return loglikelihood;
    }

    /**
     * Computes the model expectations for a contiguous slice of the training events and
     * accumulates them into the expectation table reserved for its thread index.
     *
     * <p>The result fields are written only inside {@link #call()} and read only after
     * {@code Future.get()} has returned, so they need no additional synchronization.
     */
    private class ModelExpectationComputeTask
    implements Callable<ModelExpectationComputeTask> {
        private final int startIndex;
        private final int length;
        private final int threadIndex;
        private double loglikelihood = 0.0;
        private int numEvents = 0;
        private int numCorrect = 0;

        ModelExpectationComputeTask(int threadIndex, int startIndex, int length) {
            this.startIndex = startIndex;
            this.length = length;
            this.threadIndex = threadIndex;
        }

        @Override
        public ModelExpectationComputeTask call() {
            MutableContext[] expects = GISTrainer.this.modelExpects[this.threadIndex];
            double[] modelDistribution = new double[GISTrainer.this.labels.length];
            int end = this.startIndex + this.length;
            for (int ei = this.startIndex; ei < end; ++ei) {
                int[] context = GISTrainer.this.trainingDataFeatNameIndices[ei];
                // May be null when the indexer supplies no real-valued features.
                // NOTE(review): relies on Prior.logPrior and GISModel.eval accepting a null
                // values array, as the stock OpenNLP implementations do — confirm for custom priors.
                float[] values = GISTrainer.this.valuesFor(ei);
                int timesSeen = GISTrainer.this.numTimesEventsSeen[ei];

                GISTrainer.this.prior.logPrior(modelDistribution, context, values);
                GISModel.eval(context, values, modelDistribution, GISTrainer.this.evalParams);

                for (int j = 0; j < context.length; ++j) {
                    int pi = context[j];
                    if (GISTrainer.this.predicateCounts[pi] < GISTrainer.this.cutoff) {
                        continue;
                    }
                    int[] activeOutcomes = expects[pi].getOutcomes();
                    for (int aoi = 0; aoi < activeOutcomes.length; ++aoi) {
                        int oi = activeOutcomes[aoi];
                        if (values != null) {
                            expects[pi].updateParameter(aoi,
                                    modelDistribution[oi] * (double) values[j] * (double) timesSeen);
                        } else {
                            expects[pi].updateParameter(aoi, modelDistribution[oi] * (double) timesSeen);
                        }
                    }
                }

                int expectedOutcome = GISTrainer.this.outcomeList[ei];
                this.loglikelihood += Math.log(modelDistribution[expectedOutcome]) * (double) timesSeen;
                this.numEvents += timesSeen;

                // The event counts as correct when the most probable outcome is the observed one.
                int maxIndex = 0;
                for (int labelIndex = 1; labelIndex < GISTrainer.this.labels.length; ++labelIndex) {
                    if (modelDistribution[labelIndex] > modelDistribution[maxIndex]) {
                        maxIndex = labelIndex;
                    }
                }
                if (maxIndex == expectedOutcome) {
                    this.numCorrect += timesSeen;
                }
            }
            return this;
        }

        /** Returns the number of event tokens processed by this task. */
        int getNumEvents() {
            return this.numEvents;
        }

        /** Returns the number of event tokens whose most probable outcome matched the observed one. */
        int getNumCorrect() {
            return this.numCorrect;
        }

        /** Returns the log-likelihood contribution of the events processed by this task. */
        double getLoglikelihood() {
            return this.loglikelihood;
        }
    }
}

