/*
 * Decompiled with CFR 0.152.
 */
package nak.maxent;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import nak.maxent.GISModel;
import nak.model.DataIndexer;
import nak.model.EvalParameters;
import nak.model.EventStream;
import nak.model.MutableContext;
import nak.model.OnePassDataIndexer;
import nak.model.Prior;
import nak.model.UniformPrior;

/**
 * Trains a {@link GISModel} with Generalized Iterative Scaling.
 *
 * <p>Each iteration computes, for every (predicate, outcome) parameter, the model expectation
 * over all training events. The events can be split across several worker threads. Each
 * parameter is then moved towards its observed expectation. Training stops in three cases:
 * <ul>
 *   <li>the requested number of iterations has run;</li>
 *   <li>the log-likelihood decreases;</li>
 *   <li>the log-likelihood improves by less than {@code LL_THRESHOLD}.</li>
 * </ul>
 *
 * <p>{@code trainModel} stores the training data and parameters in instance fields. One
 * instance must therefore not be used to train two models concurrently.
 */
class GISTrainer {
    /** Minimum log-likelihood gain between two iterations before training stops early. */
    private static final double LL_THRESHOLD = 1.0E-4;
    /** Maximum number of Newton steps taken by {@link #gaussianUpdate}. */
    private static final int NEWTON_MAX_STEPS = 50;
    /** Newton step size below which {@link #gaussianUpdate} considers itself converged. */
    private static final double NEWTON_TOLERANCE = 1.0E-6;

    /** When true every predicate keeps all outcomes and unseen pairs get {@link #smoothingObservation}. */
    private boolean useSimpleSmoothing = false;
    /** When true parameters are updated with {@link #gaussianUpdate} instead of the plain GIS step. */
    private boolean useGaussianSmoothing = false;
    /** Penalty width used by {@link #gaussianUpdate}. */
    private double sigma = 2.0;
    /** Observed expectation assigned to unseen (predicate, outcome) pairs under simple smoothing. */
    private double smoothingObservation = 0.1;
    /** Whether progress messages are written to standard output. */
    private final boolean printMessages;

    private int numUniqueEvents;
    private int numPreds;
    private int numOutcomes;
    /** Predicate indices of each unique event. */
    private int[][] contexts;
    /** Optional per-predicate feature values of each event; may be null, or null per event. */
    private float[][] values;
    /** Outcome index of each unique event. */
    private int[] outcomeList;
    /** Number of times each unique event occurred. */
    private int[] numTimesEventsSeen;
    /** Occurrence count of each predicate, compared against {@link #cutoff}. */
    private int[] predicateCounts;
    private int cutoff;
    private String[] outcomeLabels;
    private String[] predLabels;
    private MutableContext[] observedExpects;
    private MutableContext[] params;
    /** Model expectations, one row per worker thread; rows are summed into row 0 every iteration. */
    private MutableContext[][] modelExpects;
    private Prior prior;
    private EvalParameters evalParams;

    /** Creates a trainer that prints no progress messages. */
    GISTrainer() {
        this.printMessages = false;
    }

    /**
     * Creates a trainer.
     *
     * @param printMessages whether to print progress messages to standard output
     */
    GISTrainer(boolean printMessages) {
        this.printMessages = printMessages;
    }

    /**
     * Enables or disables simple smoothing.
     *
     * <p>When enabled, every predicate gets a parameter for every outcome, including predicates
     * below the cutoff. Pairs never observed receive the smoothing observation value.
     *
     * @param smooth whether to apply simple smoothing
     */
    public void setSmoothing(boolean smooth) {
        this.useSimpleSmoothing = smooth;
    }

    /**
     * Sets the observed expectation assigned to unseen pairs when simple smoothing is on.
     *
     * @param timesSeen the pseudo-count to use (default {@code 0.1})
     */
    public void setSmoothingObservation(double timesSeen) {
        this.smoothingObservation = timesSeen;
    }

    /**
     * Enables the Gaussian-penalised parameter update and sets its width.
     *
     * @param sigmaValue the width used by the update (default {@code 2.0})
     */
    public void setGaussianSigma(double sigmaValue) {
        this.useGaussianSmoothing = true;
        this.sigma = sigmaValue;
    }

    /**
     * Indexes {@code eventStream} with a {@link OnePassDataIndexer} and trains a model.
     *
     * <p>Training uses a {@link UniformPrior} and a single thread.
     *
     * @param eventStream the training events
     * @param iterations maximum number of GIS iterations
     * @param cutoff predicates seen fewer times than this get no parameters (unless smoothing)
     * @return the trained model
     * @throws IOException if reading the event stream fails
     */
    public GISModel trainModel(EventStream eventStream, int iterations, int cutoff) throws IOException {
        return trainModel(iterations, new OnePassDataIndexer(eventStream, cutoff), cutoff);
    }

    /**
     * Trains a model from already indexed data with a {@link UniformPrior} on a single thread.
     *
     * @param iterations maximum number of GIS iterations
     * @param indexer the indexed training data
     * @param cutoff predicates seen fewer times than this get no parameters (unless smoothing)
     * @return the trained model
     */
    public GISModel trainModel(int iterations, DataIndexer indexer, int cutoff) {
        return trainModel(iterations, indexer, new UniformPrior(), cutoff, 1);
    }

    /**
     * Trains a model from already indexed data.
     *
     * @param iterations maximum number of GIS iterations
     * @param indexer the indexed training data
     * @param prior the prior applied to the outcome distribution of every event
     * @param cutoff predicates seen fewer times than this get no parameters (unless smoothing)
     * @param threads number of worker threads used to compute model expectations
     * @return the trained model
     * @throws IllegalArgumentException if {@code threads} is not positive
     */
    public GISModel trainModel(int iterations, DataIndexer indexer, Prior prior, int cutoff, int threads) {
        if (threads <= 0) {
            throw new IllegalArgumentException("threads must be at least one or greater but is " + threads + "!");
        }
        this.modelExpects = new MutableContext[threads][];

        display("Incorporating indexed data for training...  \n");
        this.contexts = indexer.getContexts();
        this.values = indexer.getValues();
        this.cutoff = cutoff;
        this.predicateCounts = indexer.getPredCounts();
        this.numTimesEventsSeen = indexer.getNumTimesEventsSeen();
        this.numUniqueEvents = this.contexts.length;
        this.prior = prior;
        double correctionConstant = computeCorrectionConstant();
        display("done.\n");

        this.outcomeLabels = indexer.getOutcomeLabels();
        this.outcomeList = indexer.getOutcomeList();
        this.numOutcomes = this.outcomeLabels.length;
        this.predLabels = indexer.getPredLabels();
        this.prior.setLabels(this.outcomeLabels, this.predLabels);
        this.numPreds = this.predLabels.length;

        display("\tNumber of Event Tokens: " + this.numUniqueEvents + "\n");
        display("\t    Number of Outcomes: " + this.numOutcomes + "\n");
        display("\t  Number of Predicates: " + this.numPreds + "\n");

        initializeParameters(countPredicateOutcomes());
        display("...done.\n");

        if (threads == 1) {
            display("Computing model parameters ...\n");
        } else {
            display("Computing model parameters in " + threads + " threads...\n");
        }
        findParameters(iterations, correctionConstant);
        return new GISModel(this.params, this.predLabels, this.outcomeLabels, 1, this.evalParams.getCorrectionParam());
    }

    /**
     * Returns the largest total feature mass of any single event.
     *
     * <p>For an event without values, the total is its number of predicates. Otherwise it is
     * the sum of its feature values. Summing from zero also handles empty value arrays.
     */
    private double computeCorrectionConstant() {
        double maxTotal = 0.0;
        for (int ei = 0; ei < contexts.length; ei++) {
            double eventTotal;
            if (values == null || values[ei] == null) {
                eventTotal = contexts[ei].length;
            } else {
                float sum = 0.0f;
                for (float v : values[ei]) {
                    sum += v;
                }
                eventTotal = sum;
            }
            if (eventTotal > maxTotal) {
                maxTotal = eventTotal;
            }
        }
        return maxTotal;
    }

    /**
     * Returns the observed (value-weighted) count of every (predicate, outcome) pair.
     *
     * <p>The result is indexed {@code [predicate][outcome]}.
     */
    private float[][] countPredicateOutcomes() {
        float[][] counts = new float[numPreds][numOutcomes];
        for (int ei = 0; ei < numUniqueEvents; ei++) {
            int outcome = outcomeList[ei];
            float timesSeen = (float) numTimesEventsSeen[ei];
            boolean hasValues = values != null && values[ei] != null;
            for (int ci = 0; ci < contexts[ei].length; ci++) {
                float increment = hasValues ? timesSeen * values[ei][ci] : timesSeen;
                counts[contexts[ei][ci]][outcome] += increment;
            }
        }
        return counts;
    }

    /**
     * Allocates the parameters, per-thread model expectations and observed expectations.
     *
     * <p>The parameters are built for each predicate's active outcomes. Without simple
     * smoothing, an outcome is active only if it was observed with the predicate and the
     * predicate meets the cutoff.
     *
     * @param counts observed counts indexed {@code [predicate][outcome]}
     */
    private void initializeParameters(float[][] counts) {
        this.params = new MutableContext[numPreds];
        for (int t = 0; t < modelExpects.length; t++) {
            modelExpects[t] = new MutableContext[numPreds];
        }
        this.observedExpects = new MutableContext[numPreds];
        // evalParams is built over the same params array whose entries are filled in below.
        this.evalParams = new EvalParameters(this.params, 0.0, 1.0, numOutcomes);

        int[] allOutcomes = new int[numOutcomes];
        for (int oid = 0; oid < numOutcomes; oid++) {
            allOutcomes[oid] = oid;
        }
        int[] scratch = new int[numOutcomes];

        for (int pi = 0; pi < numPreds; pi++) {
            int[] activeOutcomes;
            if (useSimpleSmoothing) {
                activeOutcomes = allOutcomes;
            } else {
                int activeCount = 0;
                if (predicateCounts[pi] >= cutoff) {
                    for (int oid = 0; oid < numOutcomes; oid++) {
                        if (counts[pi][oid] > 0.0f) {
                            scratch[activeCount++] = oid;
                        }
                    }
                }
                // Share the full outcome array when nothing was filtered out.
                activeOutcomes = activeCount == numOutcomes ? allOutcomes : Arrays.copyOf(scratch, activeCount);
            }
            int numActive = activeOutcomes.length;

            params[pi] = new MutableContext(activeOutcomes, new double[numActive]);
            for (MutableContext[] threadExpects : modelExpects) {
                threadExpects[pi] = new MutableContext(activeOutcomes, new double[numActive]);
            }
            observedExpects[pi] = new MutableContext(activeOutcomes, new double[numActive]);

            for (int j = 0; j < numActive; j++) {
                // Explicitly start every parameter and model expectation at zero.
                params[pi].setParameter(j, 0.0);
                for (MutableContext[] threadExpects : modelExpects) {
                    threadExpects[pi].setParameter(j, 0.0);
                }
                float observed = counts[pi][activeOutcomes[j]];
                if (observed > 0.0f) {
                    observedExpects[pi].setParameter(j, observed);
                } else if (useSimpleSmoothing) {
                    observedExpects[pi].setParameter(j, smoothingObservation);
                }
            }
        }
    }

    /**
     * Runs GIS iterations until the iteration limit, divergence or convergence.
     *
     * <p>Afterwards it releases the per-training expectation and event arrays. A single
     * worker pool is used for the whole run and is always shut down, even if an iteration
     * fails.
     *
     * @param iterations maximum number of iterations
     * @param correctionConstant the GIS correction constant
     */
    private void findParameters(int iterations, double correctionConstant) {
        display("Performing " + iterations + " iterations.\n");
        ExecutorService executor = Executors.newFixedThreadPool(modelExpects.length);
        try {
            double previousLoglikelihood = 0.0;
            for (int i = 1; i <= iterations; i++) {
                // Right-align the iteration number to three columns.
                String padding = i < 10 ? "  " : (i < 100 ? " " : "");
                display(padding + i + ":  ");
                double currentLoglikelihood = nextIteration(correctionConstant, executor);
                if (i > 1) {
                    if (previousLoglikelihood > currentLoglikelihood) {
                        System.err.println("Model Diverging: loglikelihood decreased");
                        break;
                    }
                    if (currentLoglikelihood - previousLoglikelihood < LL_THRESHOLD) {
                        break;
                    }
                }
                previousLoglikelihood = currentLoglikelihood;
            }
        } finally {
            executor.shutdownNow();
        }
        this.observedExpects = null;
        this.modelExpects = null;
        this.numTimesEventsSeen = null;
        this.contexts = null;
    }

    /**
     * Computes the Gaussian-penalised increment for one parameter using Newton's method.
     *
     * <p>Solves {@code model * exp(C * x) + (param + x) / sigma = observed} for {@code x}. It
     * stops after {@code NEWTON_MAX_STEPS} steps, when a step is smaller than
     * {@code NEWTON_TOLERANCE}, or when the derivative is zero.
     *
     * <p>NOTE(review): the penalty divides by {@code sigma}, not {@code sigma * sigma}, so
     * {@code sigma} acts as a variance here. This is preserved as-is; confirm it is intended.
     *
     * @param predicate predicate index
     * @param outcomeIndex index into the predicate's active outcomes
     * @param correctionConstant the GIS correction constant {@code C}
     * @return the increment to add to the parameter
     */
    private double gaussianUpdate(int predicate, int outcomeIndex, double correctionConstant) {
        double param = params[predicate].getParameters()[outcomeIndex];
        double modelValue = modelExpects[0][predicate].getParameters()[outcomeIndex];
        double observedValue = observedExpects[predicate].getParameters()[outcomeIndex];
        double delta = 0.0;
        for (int step = 0; step < NEWTON_MAX_STEPS; step++) {
            double scaledModel = modelValue * Math.exp(correctionConstant * delta);
            double f = scaledModel + (param + delta) / sigma - observedValue;
            double fPrime = scaledModel * correctionConstant + 1.0 / sigma;
            if (fPrime == 0.0) {
                break;
            }
            double next = delta - f / fPrime;
            boolean converged = Math.abs(next - delta) < NEWTON_TOLERANCE;
            delta = next;
            if (converged) {
                break;
            }
        }
        return delta;
    }

    /**
     * Performs one GIS iteration and returns the training log-likelihood.
     *
     * <p>The log-likelihood is computed with the parameters as they were before this
     * iteration's update.
     *
     * @param correctionConstant the GIS correction constant
     * @param executor pool that runs the per-thread expectation tasks
     * @return the log-likelihood of the training events
     */
    private double nextIteration(double correctionConstant, ExecutorService executor) {
        int numThreads = modelExpects.length;
        int chunkSize = numUniqueEvents / numThreads;
        int remainder = numUniqueEvents % numThreads;

        // Contiguous chunks of events; the last task also takes the remainder.
        List<Future<ModelExpectationComputeTask>> futures =
                new ArrayList<Future<ModelExpectationComputeTask>>(numThreads);
        for (int t = 0; t < numThreads; t++) {
            int length = (t == numThreads - 1) ? chunkSize + remainder : chunkSize;
            futures.add(executor.submit(new ModelExpectationComputeTask(t, t * chunkSize, length)));
        }

        int numEvents = 0;
        int numCorrect = 0;
        double loglikelihood = 0.0;
        for (Future<ModelExpectationComputeTask> future : futures) {
            ModelExpectationComputeTask task;
            try {
                task = future.get();
            } catch (InterruptedException e) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interruption is not supported!", e);
            } catch (ExecutionException e) {
                throw new RuntimeException("Exception during training: " + e.getMessage(), e);
            }
            numEvents += task.getNumEvents();
            numCorrect += task.getNumCorrect();
            loglikelihood += task.getLoglikelihood();
        }
        display(".");

        // Sum every thread's model expectations into row 0.
        for (int pi = 0; pi < numPreds; pi++) {
            int numActive = params[pi].getOutcomes().length;
            for (int j = 0; j < numActive; j++) {
                for (int t = 1; t < numThreads; t++) {
                    modelExpects[0][pi].updateParameter(j, modelExpects[t][pi].getParameters()[j]);
                }
            }
        }
        display(".");

        for (int pi = 0; pi < numPreds; pi++) {
            double[] observed = observedExpects[pi].getParameters();
            double[] model = modelExpects[0][pi].getParameters();
            int[] outcomes = params[pi].getOutcomes();
            for (int j = 0; j < outcomes.length; j++) {
                if (useGaussianSmoothing) {
                    params[pi].updateParameter(j, gaussianUpdate(pi, j, correctionConstant));
                } else {
                    if (model[j] == 0.0) {
                        // j indexes the active outcomes; map it back to the outcome id for the label.
                        System.err.println("Model expects == 0 for " + predLabels[pi] + " " + outcomeLabels[outcomes[j]]);
                    }
                    params[pi].updateParameter(j, (Math.log(observed[j]) - Math.log(model[j])) / correctionConstant);
                }
                // Reset the expectations for the next iteration.
                for (MutableContext[] threadExpects : modelExpects) {
                    threadExpects[pi].setParameter(j, 0.0);
                }
            }
        }
        display(". loglikelihood=" + loglikelihood + "\t" + (double) numCorrect / (double) numEvents + "\n");
        return loglikelihood;
    }

    /** Prints {@code message} to standard output when progress messages are enabled. */
    private void display(String message) {
        if (printMessages) {
            System.out.print(message);
        }
    }

    /**
     * Computes model expectations for one contiguous range of events.
     *
     * <p>The expectations are accumulated into this task's own row of {@code modelExpects}.
     * The task also accumulates the range's log-likelihood and event count. The number of
     * correctly classified events is counted only when messages are printed.
     */
    private class ModelExpectationComputeTask implements Callable<ModelExpectationComputeTask> {
        private final int threadIndex;
        private final int startIndex;
        private final int length;
        private double loglikelihood = 0.0;
        private int numEvents = 0;
        private int numCorrect = 0;

        /**
         * @param threadIndex row of {@code modelExpects} this task writes to
         * @param startIndex first event index of the range
         * @param length number of events in the range
         */
        ModelExpectationComputeTask(int threadIndex, int startIndex, int length) {
            this.threadIndex = threadIndex;
            this.startIndex = startIndex;
            this.length = length;
        }

        @Override
        public ModelExpectationComputeTask call() {
            double[] outcomeProbs = new double[numOutcomes];
            MutableContext[] expects = modelExpects[threadIndex];
            int end = startIndex + length;
            for (int ei = startIndex; ei < end; ei++) {
                int[] context = contexts[ei];
                if (values != null) {
                    prior.logPrior(outcomeProbs, context, values[ei]);
                    GISModel.eval(context, values[ei], outcomeProbs, evalParams);
                } else {
                    prior.logPrior(outcomeProbs, context);
                    GISModel.eval(context, outcomeProbs, evalParams);
                }

                int timesSeen = numTimesEventsSeen[ei];
                boolean hasValues = values != null && values[ei] != null;
                for (int ci = 0; ci < context.length; ci++) {
                    int pi = context[ci];
                    if (predicateCounts[pi] < cutoff) {
                        continue;
                    }
                    int[] activeOutcomes = expects[pi].getOutcomes();
                    for (int j = 0; j < activeOutcomes.length; j++) {
                        double prob = outcomeProbs[activeOutcomes[j]];
                        if (hasValues) {
                            expects[pi].updateParameter(j, prob * (double) values[ei][ci] * (double) timesSeen);
                        } else {
                            expects[pi].updateParameter(j, prob * (double) timesSeen);
                        }
                    }
                }

                loglikelihood += Math.log(outcomeProbs[outcomeList[ei]]) * (double) timesSeen;
                numEvents += timesSeen;
                if (printMessages) {
                    int best = 0;
                    for (int oid = 1; oid < numOutcomes; oid++) {
                        if (outcomeProbs[oid] > outcomeProbs[best]) {
                            best = oid;
                        }
                    }
                    if (best == outcomeList[ei]) {
                        numCorrect += timesSeen;
                    }
                }
            }
            return this;
        }

        synchronized int getNumEvents() {
            return numEvents;
        }

        synchronized int getNumCorrect() {
            return numCorrect;
        }

        synchronized double getLoglikelihood() {
            return loglikelihood;
        }
    }
}

