/*
 * Decompiled with CFR 0.152.
 */
package nak.perceptron;

import nak.core.AbstractModel;
import nak.core.Context;
import nak.core.EvalParameters;
import nak.core.MutableContext;
import nak.data.DataIndexer;
import nak.perceptron.PerceptronModel;

/**
 * Trains a multi-class perceptron from the events supplied by a {@link DataIndexer}.
 *
 * <p>Each iteration makes one pass over every unique event, repeating it as many times as the
 * event was seen. When the predicted outcome differs from the gold outcome, every predicate in
 * the event's context has its weight for the gold outcome increased and its weight for the
 * predicted outcome decreased. The change is the current step size, multiplied by the
 * predicate's value when the indexer supplies values.
 *
 * <p>Training stops after the requested number of iterations, or earlier once the training-set
 * accuracy differs by less than the tolerance from each of the three previous iterations'
 * accuracies (which start at 0).
 *
 * <p>{@code trainModel} stores the indexed data in instance fields, so a single instance should
 * not be used to train concurrently.
 */
public class PerceptronTrainer {
    /** Default convergence tolerance on the change in training-set accuracy. */
    public static final double TOLERANCE_DEFAULT = 1.0E-5;

    private int numUniqueEvents;
    private int numEvents;
    private int numPreds;
    private int numOutcomes;
    private int[][] contexts;
    private float[][] values;
    private int[] outcomeList;
    private int[] numTimesEventsSeen;
    private String[] outcomeLabels;
    private String[] predLabels;
    private boolean printMessages = true;
    private double tolerance = TOLERANCE_DEFAULT;
    /** Per-iteration step size decrease in percent, or {@code null} to keep the step size at 1. */
    private Double stepSizeDecrease;
    private boolean useSkippedAveraging;

    /**
     * Sets the convergence tolerance. Training stops once the training-set accuracy differs
     * from each of the three previous iterations' accuracies by less than this value.
     *
     * @param d the tolerance; must be non-negative
     * @throws IllegalArgumentException if {@code d} is negative or NaN
     */
    public void setTolerance(double d) {
        // Written as !(d >= 0) so that NaN is rejected too (every comparison with NaN is false).
        if (!(d >= 0.0)) {
            throw new IllegalArgumentException("tolerance must be a positive number but is " + d + "!");
        }
        this.tolerance = d;
    }

    /**
     * Enables step size decrease. Before every iteration the step size (initially 1) is
     * multiplied by {@code 1 - d / 100}.
     *
     * @param d the decrease in percent, between 0 and 100 inclusive
     * @throws IllegalArgumentException if {@code d} is outside [0, 100] or NaN
     */
    public void setStepSizeDecrease(double d) {
        if (!(d >= 0.0 && d <= 100.0)) {
            throw new IllegalArgumentException("decrease must be between 0 and 100 but is " + d + "!");
        }
        this.stepSizeDecrease = d;
    }

    /**
     * Enables or disables skipped averaging. This only matters when averaging is requested in
     * {@code trainModel}. With skipped averaging, the weights are added to the running sum only
     * on iterations 1-19 and on iterations whose number is a perfect square. Without it, they
     * are summed after every iteration.
     *
     * @param bl {@code true} to enable skipped averaging
     */
    public void setSkippedAveraging(boolean bl) {
        this.useSkippedAveraging = bl;
    }

    /**
     * Trains an averaged perceptron model; equivalent to {@code trainModel(n, dataIndexer, n2, true)}.
     *
     * @param n maximum number of training iterations
     * @param dataIndexer source of the indexed training events
     * @param n2 not used by this implementation
     * @return the trained {@link PerceptronModel}
     */
    public AbstractModel trainModel(int n, DataIndexer dataIndexer, int n2) {
        return this.trainModel(n, dataIndexer, n2, true);
    }

    /**
     * Trains a perceptron model.
     *
     * @param n maximum number of training iterations
     * @param dataIndexer source of the indexed training events
     * @param n2 not used by this implementation
     * @param bl if {@code true}, the returned weights are averaged over the summed iterations;
     *     otherwise the weights after the last iteration are returned
     * @return the trained {@link PerceptronModel}
     */
    public AbstractModel trainModel(int n, DataIndexer dataIndexer, int n2, boolean bl) {
        this.display("Incorporating indexed data for training...  \n");
        this.contexts = dataIndexer.getContexts();
        this.values = dataIndexer.getValues();
        this.numTimesEventsSeen = dataIndexer.getNumTimesEventsSeen();
        this.numEvents = dataIndexer.getNumEvents();
        this.numUniqueEvents = this.contexts.length;
        this.outcomeLabels = dataIndexer.getOutcomeLabels();
        this.outcomeList = dataIndexer.getOutcomeList();
        this.predLabels = dataIndexer.getPredLabels();
        this.numPreds = this.predLabels.length;
        this.numOutcomes = this.outcomeLabels.length;
        this.display("done.\n");
        this.display("\tNumber of Event Tokens: " + this.numUniqueEvents + "\n");
        this.display("\t    Number of Outcomes: " + this.numOutcomes + "\n");
        this.display("\t  Number of Predicates: " + this.numPreds + "\n");
        this.display("Computing model parameters...\n");
        Context[] contextArray = this.findParameters(n, bl);
        this.display("...done.\n");
        return new PerceptronModel(contextArray, this.predLabels, this.outcomeLabels);
    }

    /**
     * Runs the perceptron training loop over the data loaded by {@code trainModel}.
     *
     * @param n maximum number of iterations
     * @param bl whether to return averaged weights
     * @return one context per predicate holding its per-outcome weights
     */
    private MutableContext[] findParameters(int n, boolean bl) {
        this.display("Performing " + n + " iterations.\n");
        int[] allOutcomes = new int[this.numOutcomes];
        for (int oi = 0; oi < this.numOutcomes; ++oi) {
            allOutcomes[oi] = oi;
        }
        MutableContext[] params = this.newZeroedContexts(allOutcomes);
        // The loop never rebuilds evalParams. It relies on EvalParameters reading these same
        // MutableContext objects, so weight updates are visible to later predictions.
        EvalParameters evalParams = new EvalParameters(params, this.numOutcomes);
        MutableContext[] summedParams = bl ? this.newZeroedContexts(allOutcomes) : null;

        // Accuracies of the three previous iterations, oldest first.
        double prevAccuracy1 = 0.0;
        double prevAccuracy2 = 0.0;
        double prevAccuracy3 = 0.0;
        int numTimesSummed = 0;
        double stepSize = 1.0;
        for (int iteration = 1; iteration <= n; ++iteration) {
            if (this.stepSizeDecrease != null) {
                // stepSizeDecrease is a percentage (validated to [0, 100]).
                stepSize *= 1.0 - this.stepSizeDecrease / 100.0;
            }
            this.displayIteration(iteration);
            int numCorrect = 0;
            for (int ei = 0; ei < this.numUniqueEvents; ++ei) {
                int goldOutcome = this.outcomeList[ei];
                for (int rep = 0; rep < this.numTimesEventsSeen[ei]; ++rep) {
                    // Re-predict on every repetition: the weights may have changed since the last one.
                    int predicted = this.predictOutcome(ei, evalParams);
                    if (predicted == goldOutcome) {
                        ++numCorrect;
                        continue;
                    }
                    int[] context = this.contexts[ei];
                    for (int ci = 0; ci < context.length; ++ci) {
                        double delta = this.values == null
                                ? stepSize
                                : stepSize * (double) this.values[ei][ci];
                        params[context[ci]].updateParameter(goldOutcome, delta);
                        params[context[ci]].updateParameter(predicted, -delta);
                    }
                }
            }
            double accuracy = (double) numCorrect / (double) this.numEvents;
            if (iteration < 10 || iteration % 10 == 0) {
                this.display(". (" + numCorrect + "/" + this.numEvents + ") " + accuracy + "\n");
            }
            if (this.shouldSumIteration(bl, iteration)) {
                ++numTimesSummed;
                for (int pi = 0; pi < this.numPreds; ++pi) {
                    double[] current = params[pi].getParameters();
                    for (int oi = 0; oi < this.numOutcomes; ++oi) {
                        summedParams[pi].updateParameter(oi, current[oi]);
                    }
                }
            }
            if (Math.abs(prevAccuracy1 - accuracy) < this.tolerance
                    && Math.abs(prevAccuracy2 - accuracy) < this.tolerance
                    && Math.abs(prevAccuracy3 - accuracy) < this.tolerance) {
                this.display("Stopping: change in training set accuracy less than " + this.tolerance + "\n");
                break;
            }
            prevAccuracy1 = prevAccuracy2;
            prevAccuracy2 = prevAccuracy3;
            prevAccuracy3 = accuracy;
        }
        this.trainingStats(evalParams);
        if (!bl) {
            return params;
        }
        // If no iteration was summed (n <= 0), the sums are all zero. Dividing would give 0/0 = NaN.
        if (numTimesSummed > 0) {
            for (int pi = 0; pi < this.numPreds; ++pi) {
                double[] summed = summedParams[pi].getParameters();
                for (int oi = 0; oi < this.numOutcomes; ++oi) {
                    summedParams[pi].setParameter(oi, summed[oi] / (double) numTimesSummed);
                }
            }
        }
        return summedParams;
    }

    /** Creates one all-zero-weight context per predicate, each covering {@code outcomes}. */
    private MutableContext[] newZeroedContexts(int[] outcomes) {
        MutableContext[] result = new MutableContext[this.numPreds];
        for (int pi = 0; pi < this.numPreds; ++pi) {
            result[pi] = new MutableContext(outcomes, new double[this.numOutcomes]);
            for (int oi = 0; oi < this.numOutcomes; ++oi) {
                result[pi].setParameter(oi, 0.0);
            }
        }
        return result;
    }

    /**
     * Decides whether this iteration's weights are added to the running sum used for averaging.
     * When skipped averaging is enabled, only iterations below 20 and perfect-square iterations
     * are summed.
     */
    private boolean shouldSumIteration(boolean useAverage, int iteration) {
        if (!useAverage) {
            return false;
        }
        if (!this.useSkippedAveraging) {
            return true;
        }
        return iteration < 20 || isPerfectSquare(iteration);
    }

    /** Returns the index of the highest-scoring outcome for the unique event {@code eventIndex}. */
    private int predictOutcome(int eventIndex, EvalParameters evalParams) {
        double[] scores = new double[this.numOutcomes];
        float[] eventValues = this.values != null ? this.values[eventIndex] : null;
        PerceptronModel.eval(this.contexts[eventIndex], eventValues, scores, evalParams, false);
        return this.maxIndex(scores);
    }

    /**
     * Computes and displays the training-set accuracy of {@code evalParameters}, counting each
     * event once per time it was seen.
     *
     * @return the number of correct predictions divided by the total number of events
     */
    private double trainingStats(EvalParameters evalParameters) {
        int numCorrect = 0;
        for (int ei = 0; ei < this.numUniqueEvents; ++ei) {
            for (int rep = 0; rep < this.numTimesEventsSeen[ei]; ++rep) {
                if (this.predictOutcome(ei, evalParameters) == this.outcomeList[ei]) {
                    ++numCorrect;
                }
            }
        }
        double accuracy = (double) numCorrect / (double) this.numEvents;
        this.display("Stats: (" + numCorrect + "/" + this.numEvents + ") " + accuracy + "\n");
        return accuracy;
    }

    /** Returns the index of the largest element; ties resolve to the lowest index. */
    private int maxIndex(double[] dArray) {
        int best = 0;
        for (int i = 1; i < dArray.length; ++i) {
            if (dArray[i] > dArray[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Prints {@code string} to standard output when messages are enabled. */
    private void display(String string) {
        if (this.printMessages) {
            System.out.print(string);
        }
    }

    /** Prints a right-aligned iteration label for iterations below 10 and every tenth iteration. */
    private void displayIteration(int n) {
        if (n > 10 && n % 10 != 0) {
            return;
        }
        if (n < 10) {
            this.display("  " + n + ":  ");
        } else if (n < 100) {
            this.display(" " + n + ":  ");
        } else {
            this.display(n + ":  ");
        }
    }

    /** Returns whether the non-negative {@code n} is a perfect square. */
    private static boolean isPerfectSquare(int n) {
        int root = (int) Math.sqrt(n);
        return root * root == n;
    }
}

