/*
 * Decompiled with CFR 0.152.
 */
package nak.perceptron;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import nak.core.AbstractModel;
import nak.core.Context;
import nak.core.Event;
import nak.core.MutableContext;
import nak.core.Sequence;
import nak.data.OnePassDataIndexer;
import nak.data.SequenceStream;
import nak.data.SequenceStreamEventStream;
import nak.perceptron.PerceptronModel;
import nak.util.IndexHashTable;

/**
 * Trains a {@link PerceptronModel} from a {@link SequenceStream} using sequence-level perceptron
 * updates.
 *
 * <p>On every pass each sequence is re-tagged with the current parameters. If any predicted
 * outcome differs from the gold outcome, the gold feature counts are added to the parameters and
 * the predicted feature counts are subtracted.
 *
 * <p>When averaging is requested, the returned model holds each parameter averaged over all
 * (iteration, sequence) steps, weighted by how many steps each value was in effect. This is
 * tracked lazily through {@link #updates}.
 *
 * <p>All training state lives in instance fields and nothing here is synchronized.
 */
public class SimplePerceptronSequenceTrainer {
    /** Whether progress messages are written to standard output. */
    private boolean printMessages = true;
    /** Number of training passes requested in {@link #trainModel}. */
    private int iterations;
    private SequenceStream sequenceStream;
    private int numEvents;
    private int numPreds;
    private int numOutcomes;
    private String[] outcomeLabels;
    double[] modelDistribution;
    private MutableContext[] averageParams;
    /** Predicate label -> predicate index; {@code get} yields -1 for unknown predicates. */
    private IndexHashTable<String> pmap;
    /** Outcome label -> outcome index. */
    private Map<String, Integer> omap;
    private MutableContext[] params;
    private boolean useAverage;
    /**
     * Lazy-averaging bookkeeping indexed {@code [predicate][outcome][VALUE | ITER | EVENT]}.
     * Each entry holds the parameter value after its last change, plus the 0-based iteration and
     * sequence index at which that change happened.
     *
     * <p>Kept as {@code double} so that fractional parameter values (real-valued features) are
     * not truncated. This also keeps the weighted products below from overflowing {@code int}.
     */
    private double[][][] updates;
    private static final int VALUE = 0;
    private static final int ITER = 1;
    private static final int EVENT = 2;
    private int[] allOutcomesPattern;
    private String[] predLabels;
    int numSequences;

    /**
     * Indexes the training data, runs the perceptron for the requested number of passes and
     * builds the resulting model.
     *
     * @param iterations number of passes over {@code sequenceStream}
     * @param sequenceStream training sequences; iterated once for indexing, once to count
     *     sequences, once per pass, and once more for the final training statistics
     * @param cutoff predicate count cutoff forwarded to {@link OnePassDataIndexer}
     * @param useAverage if {@code true} the averaged parameters are returned, otherwise the final
     *     parameters
     * @return the trained {@link PerceptronModel}
     * @throws IOException if indexing the event stream fails
     */
    public AbstractModel trainModel(int iterations, SequenceStream sequenceStream, int cutoff, boolean useAverage)
            throws IOException {
        this.iterations = iterations;
        this.sequenceStream = sequenceStream;
        this.useAverage = useAverage;
        OnePassDataIndexer indexer =
                new OnePassDataIndexer(new SequenceStreamEventStream(sequenceStream), cutoff, false);

        // The sequence count converts (iteration, sequence index) pairs into one step counter
        // for lazy averaging.
        this.numSequences = 0;
        for (Sequence ignored : sequenceStream) {
            ++this.numSequences;
        }

        this.predLabels = indexer.getPredLabels();
        this.pmap = new IndexHashTable<String>(this.predLabels, 0.7);
        display("Incorporating indexed data for training...  \n");
        this.numEvents = indexer.getNumEvents();
        this.outcomeLabels = indexer.getOutcomeLabels();
        this.omap = new HashMap<String, Integer>();
        for (int oi = 0; oi < outcomeLabels.length; ++oi) {
            omap.put(outcomeLabels[oi], oi);
        }
        this.numPreds = predLabels.length;
        this.numOutcomes = outcomeLabels.length;
        if (useAverage) {
            this.updates = new double[numPreds][numOutcomes][3];
        }

        display("done.\n");
        display("\tNumber of Event Tokens: " + numEvents + "\n");
        display("\t    Number of Outcomes: " + numOutcomes + "\n");
        display("\t  Number of Predicates: " + numPreds + "\n");

        this.params = new MutableContext[numPreds];
        if (useAverage) {
            this.averageParams = new MutableContext[numPreds];
        }
        // Every predicate carries a weight for every outcome.
        this.allOutcomesPattern = new int[numOutcomes];
        for (int oi = 0; oi < numOutcomes; ++oi) {
            allOutcomesPattern[oi] = oi;
        }
        for (int pi = 0; pi < numPreds; ++pi) {
            params[pi] = new MutableContext(allOutcomesPattern, new double[numOutcomes]);
            if (useAverage) {
                averageParams[pi] = new MutableContext(allOutcomesPattern, new double[numOutcomes]);
            }
            for (int oi = 0; oi < numOutcomes; ++oi) {
                params[pi].setParameter(oi, 0.0);
                if (useAverage) {
                    averageParams[pi].setParameter(oi, 0.0);
                }
            }
        }
        this.modelDistribution = new double[numOutcomes];

        display("Computing model parameters...\n");
        findParameters(iterations);
        display("...done.\n");

        if (useAverage) {
            return new PerceptronModel(averageParams, predLabels, outcomeLabels);
        }
        return new PerceptronModel(params, predLabels, outcomeLabels);
    }

    /**
     * Runs {@code numIterations} training passes, then reports accuracy of the parameters that
     * will be returned (averaged or final).
     */
    private void findParameters(int numIterations) {
        display("Performing " + numIterations + " iterations.\n");
        for (int i = 1; i <= numIterations; ++i) {
            // Right-align the pass number to a width of three characters.
            if (i < 10) {
                display("  " + i + ":  ");
            } else if (i < 100) {
                display(" " + i + ":  ");
            } else {
                display(i + ":  ");
            }
            nextIteration(i);
        }
        trainingStats(useAverage ? averageParams : params);
    }

    /** Prints {@code message} to standard output when progress messages are enabled. */
    private void display(String message) {
        if (printMessages) {
            System.out.print(message);
        }
    }

    /**
     * Performs one training pass over the sequence stream.
     *
     * <p>Each sequence is tagged with the current model. On any mistake, the parameters receive
     * (gold feature counts - predicted feature counts) and the model is rebuilt. When averaging
     * and this is the last configured pass, the averaged parameters are finalized afterwards.
     *
     * @param iteration 1-based pass number
     */
    public void nextIteration(int iteration) {
        // The averaging bookkeeping works with 0-based iteration numbers.
        final int iter = iteration - 1;
        int numCorrect = 0;
        int sequenceIndex = 0;
        @SuppressWarnings("unchecked")
        Map<String, Float>[] featureCounts = new Map[numOutcomes];
        for (int oi = 0; oi < numOutcomes; ++oi) {
            featureCounts[oi] = new HashMap<String, Float>();
        }
        PerceptronModel model = new PerceptronModel((Context[]) params, predLabels, pmap, outcomeLabels);
        for (Sequence sequence : sequenceStream) {
            Event[] predicted = sequenceStream.updateContext(sequence, model);
            Event[] gold = sequence.getEvents();
            boolean mistake = false;
            for (int ei = 0; ei < gold.length; ++ei) {
                if (predicted[ei].getOutcome().equals(gold[ei].getOutcome())) {
                    ++numCorrect;
                } else {
                    mistake = true;
                }
            }
            if (mistake) {
                for (Map<String, Float> counts : featureCounts) {
                    counts.clear();
                }
                for (Event event : gold) {
                    addFeatureCounts(featureCounts, event, 1.0f);
                }
                for (Event event : predicted) {
                    addFeatureCounts(featureCounts, event, -1.0f);
                }
                applyUpdates(featureCounts, iter, sequenceIndex);
                model = new PerceptronModel((Context[]) params, predLabels, pmap, outcomeLabels);
            }
            ++sequenceIndex;
        }
        if (useAverage && iter == iterations - 1) {
            finishAverage((double) iterations * (double) sequenceIndex);
        }
        display(". (" + numCorrect + "/" + numEvents + ") " + (double) numCorrect / (double) numEvents + "\n");
    }

    /**
     * Adds {@code sign * value} for every context feature of {@code event} to the count map of
     * the event's outcome. Entries whose net count becomes zero are dropped.
     */
    private void addFeatureCounts(Map<String, Float>[] featureCounts, Event event, float sign) {
        String[] context = event.getContext();
        float[] values = event.getValues();
        Map<String, Float> counts = featureCounts[omap.get(event.getOutcome())];
        for (int ci = 0; ci < context.length; ++ci) {
            // Without an explicit values array every feature counts as 1.
            float delta = sign * (values == null ? 1.0f : values[ci]);
            Float previous = counts.get(context[ci]);
            float total = previous == null ? delta : previous.floatValue() + delta;
            if (total == 0.0f) {
                counts.remove(context[ci]);
            } else {
                counts.put(context[ci], total);
            }
        }
    }

    /**
     * Applies the accumulated feature-count differences to {@link #params}. When averaging, it
     * also banks the previous value of each touched parameter into {@link #averageParams},
     * weighted by the number of steps it was in effect.
     *
     * @param iter 0-based iteration number of the current pass
     * @param sequenceIndex index of the current sequence within the pass
     */
    private void applyUpdates(Map<String, Float>[] featureCounts, int iter, int sequenceIndex) {
        for (int oi = 0; oi < numOutcomes; ++oi) {
            for (Map.Entry<String, Float> entry : featureCounts[oi].entrySet()) {
                int pi = pmap.get(entry.getKey());
                if (pi == -1) {
                    // Not an indexed predicate (presumably pruned by the cutoff).
                    continue;
                }
                params[pi].updateParameter(oi, entry.getValue().floatValue());
                if (!useAverage) {
                    continue;
                }
                double[] last = updates[pi][oi];
                if (last[VALUE] != 0) {
                    double stepsHeld = numSequences * (iter - last[ITER]) + (sequenceIndex - last[EVENT]);
                    averageParams[pi].updateParameter(oi, last[VALUE] * stepsHeld);
                }
                last[VALUE] = params[pi].getParameters()[oi];
                last[ITER] = iter;
                last[EVENT] = sequenceIndex;
            }
        }
    }

    /**
     * Adds each parameter's contribution from its last change up to the end of training. It then
     * divides every non-zero averaged parameter by {@code totalSteps}.
     *
     * @param totalSteps total number of (iteration, sequence) steps in the whole training run
     */
    private void finishAverage(double totalSteps) {
        for (int pi = 0; pi < numPreds; ++pi) {
            double[] predParams = averageParams[pi].getParameters();
            for (int oi = 0; oi < numOutcomes; ++oi) {
                double[] last = updates[pi][oi];
                if (last[VALUE] != 0) {
                    predParams[oi] += last[VALUE] * (numSequences * (iterations - last[ITER]) - last[EVENT]);
                }
                if (predParams[oi] != 0.0) {
                    predParams[oi] /= totalSteps;
                    averageParams[pi].setParameter(oi, predParams[oi]);
                }
            }
        }
    }

    /**
     * Tags every sequence with a model built from {@code evalParams} and prints how many
     * predicted outcomes match the gold outcomes.
     *
     * <p>Predictions are compared against each sequence's own events, as in
     * {@link #nextIteration}. This avoids depending on the indexer's outcome list staying
     * aligned with the stream.
     */
    private void trainingStats(MutableContext[] evalParams) {
        int numCorrect = 0;
        PerceptronModel model = new PerceptronModel((Context[]) evalParams, predLabels, pmap, outcomeLabels);
        for (Sequence sequence : sequenceStream) {
            Event[] predicted = sequenceStream.updateContext(sequence, model);
            Event[] gold = sequence.getEvents();
            for (int ei = 0; ei < gold.length; ++ei) {
                if (predicted[ei].getOutcome().equals(gold[ei].getOutcome())) {
                    ++numCorrect;
                }
            }
        }
        display(". (" + numCorrect + "/" + numEvents + ") " + (double) numCorrect / (double) numEvents + "\n");
    }
}

