/*
 * Decompiled with CFR 0.152.
 */
package nak.quasinewton;

import java.util.Arrays;
import nak.data.DataIndexer;
import nak.quasinewton.ArrayMath;
import nak.quasinewton.DifferentiableFunction;
import nak.quasinewton.LineSearch;
import nak.quasinewton.LineSearchResult;
import nak.quasinewton.LogLikelihoodFunction;
import nak.quasinewton.QNModel;

/**
 * Trains a {@link QNModel} with limited-memory BFGS (L-BFGS).
 *
 * <p>Each iteration does three things:
 * <ol>
 *   <li>Builds a search direction from the most recent {@code m} correction pairs, using the
 *       two-loop recursion with an identity initial inverse-Hessian.</li>
 *   <li>Runs a {@link LineSearch} along that direction.</li>
 *   <li>Records the new correction pair.</li>
 * </ol>
 * The direction is the negated quasi-Newton step, so the trainer drives the value returned by
 * {@link LogLikelihoodFunction#valueAt} downward. NOTE(review): presumably that value is the
 * negative log-likelihood; confirm in {@code LogLikelihoodFunction}.
 *
 * <p>Instances are not safe for concurrent use: {@link #trainModel} stores per-run state in
 * fields.
 */
public class QNTrainer {
    /** Training stops once the absolute change in objective value drops below this amount. */
    private static final double CONVERGE_TOLERANCE = 1.0E-10;
    /** Upper bound on the number of stored correction pairs. */
    private static final int MAX_M = 15;
    /** Default number of stored correction pairs. */
    public static final int DEFAULT_M = 7;
    /** Upper bound on the function-evaluation budget. */
    public static final int MAX_FCT_EVAL = 3000;
    /** Default function-evaluation budget. */
    public static final int DEFAULT_MAX_FCT_EVAL = 300;

    /** Dimension of the current training problem; set by {@link #trainModel}. */
    private int dimension;
    /** Number of correction pairs kept, in [0, MAX_M]. */
    private final int m;
    /** Training stops once the line search reports more evaluations than this. */
    private final int maxFctEval;
    /** Correction-pair history for the current training run. */
    private QNInfo updateInfo;
    /** Whether to print progress to standard output. */
    private final boolean verbose;

    /** Creates a verbose trainer with {@link #DEFAULT_M} and {@link #DEFAULT_MAX_FCT_EVAL}. */
    public QNTrainer() {
        this(true);
    }

    /**
     * Creates a trainer with {@link #DEFAULT_M} and {@link #DEFAULT_MAX_FCT_EVAL}.
     *
     * @param verbose whether to print progress
     */
    public QNTrainer(boolean verbose) {
        this(DEFAULT_M, verbose);
    }

    /**
     * Creates a verbose trainer with {@link #DEFAULT_MAX_FCT_EVAL}.
     *
     * @param m number of correction pairs to keep (see the three-argument constructor)
     */
    public QNTrainer(int m) {
        this(m, true);
    }

    /**
     * Creates a trainer with {@link #DEFAULT_MAX_FCT_EVAL}.
     *
     * @param m number of correction pairs to keep (see the three-argument constructor)
     * @param verbose whether to print progress
     */
    public QNTrainer(int m, boolean verbose) {
        this(m, DEFAULT_MAX_FCT_EVAL, verbose);
    }

    /**
     * Creates a trainer.
     *
     * @param m number of correction pairs to keep. Negative values fall back to
     *     {@link #DEFAULT_M}; values above 15 are clamped to 15; 0 keeps no history, so every
     *     direction is the negative gradient.
     * @param maxFctEval function-evaluation budget. Negative values fall back to
     *     {@link #DEFAULT_MAX_FCT_EVAL}; values above {@link #MAX_FCT_EVAL} are clamped to it.
     * @param verbose whether to print progress
     */
    public QNTrainer(int m, int maxFctEval, boolean verbose) {
        this.verbose = verbose;
        // Treat a negative m like a negative maxFctEval (use the default). Previously it surfaced
        // later as a NegativeArraySizeException inside trainModel.
        this.m = m < 0 ? DEFAULT_M : Math.min(m, MAX_M);
        this.maxFctEval = maxFctEval < 0 ? DEFAULT_MAX_FCT_EVAL : Math.min(maxFctEval, MAX_FCT_EVAL);
    }

    /**
     * Runs L-BFGS on the function built from {@code dataIndexer} until the convergence test in
     * {@link #isConverged} succeeds.
     *
     * @param dataIndexer training data
     * @return a model wrapping the function and the final point reached by the line search
     */
    public QNModel trainModel(DataIndexer dataIndexer) {
        LogLikelihoodFunction function = this.generateFunction(dataIndexer);
        this.dimension = function.getDomainDimension();
        this.updateInfo = new QNInfo(this.m, this.dimension);

        double[] initialPoint = function.getInitialPoint();
        double initialValue = function.valueAt(initialPoint);
        double[] initialGradient = function.gradientAt(initialPoint);
        LineSearchResult lsr =
                LineSearchResult.getInitialObject(initialValue, initialGradient, initialPoint, 0);

        int iteration = 0;
        do {
            if (this.verbose) {
                // No newline here; NOTE(review): LineSearch receives the verbose flag and may
                // finish the line.
                System.out.print(iteration);
            }
            iteration++;
            double[] direction = this.computeDirection(lsr);
            lsr = LineSearch.doLineSearch(function, direction, lsr, this.verbose);
            this.updateInfo.updateInfo(lsr);
        } while (!this.isConverged(lsr));
        return new QNModel(function, lsr.getNextPoint());
    }

    /** Builds the objective for {@code dataIndexer}. */
    private LogLikelihoodFunction generateFunction(DataIndexer dataIndexer) {
        return new LogLikelihoodFunction(dataIndexer);
    }

    /**
     * Computes the search direction {@code -H*g} with the L-BFGS two-loop recursion.
     *
     * <p>Here {@code g} is the gradient at the line search's next point, and {@code H} is the
     * inverse-Hessian approximation implied by the stored pairs, starting from the identity.
     *
     * @param lsr result of the previous line search (or the initial object)
     * @return a freshly allocated direction vector of length {@code dimension}
     */
    private double[] computeDirection(LineSearchResult lsr) {
        double[] q = lsr.getGradAtNext().clone();
        int stored = this.updateInfo.kCounter;
        double[] alpha = new double[stored];

        // First loop: newest pair to oldest.
        for (int i = stored - 1; i >= 0; --i) {
            double[] y = this.updateInfo.getY(i);
            alpha[i] = this.updateInfo.getRho(i) * ArrayMath.innerProduct(this.updateInfo.getS(i), q);
            for (int j = 0; j < this.dimension; ++j) {
                q[j] -= alpha[i] * y[j];
            }
        }

        // The initial inverse Hessian is the identity, so no scaling is applied between the loops.

        // Second loop: oldest pair to newest.
        for (int i = 0; i < stored; ++i) {
            double[] s = this.updateInfo.getS(i);
            double beta = this.updateInfo.getRho(i) * ArrayMath.innerProduct(this.updateInfo.getY(i), q);
            for (int j = 0; j < this.dimension; ++j) {
                q[j] += (alpha[i] - beta) * s[j];
            }
        }

        // Negate H*g to obtain a descent direction.
        for (int j = 0; j < this.dimension; ++j) {
            q[j] = -q[j];
        }
        return q;
    }

    /**
     * Tests whether training should stop.
     *
     * <p>Training stops when either condition holds:
     * <ul>
     *   <li>the objective changed by less than {@link #CONVERGE_TOLERANCE} (absolute) in the last
     *       line search;</li>
     *   <li>the line search's function-evaluation count exceeds {@code maxFctEval}.</li>
     * </ul>
     * A NaN change never counts as converged; only the evaluation budget stops it.
     */
    private boolean isConverged(LineSearchResult lsr) {
        double change = Math.abs(lsr.getValueAtNext() - lsr.getValueAtCurr());
        return change < CONVERGE_TOLERANCE || lsr.getFctEvalCount() > this.maxFctEval;
    }

    /**
     * Fixed-capacity history of the most recent L-BFGS correction pairs.
     *
     * <p>Each pair consists of:
     * <ul>
     *   <li>{@code s = x_next - x_curr};</li>
     *   <li>{@code y = g_next - g_curr};</li>
     *   <li>{@code rho = 1 / (y . s)}.</li>
     * </ul>
     * Entries {@code [0, kCounter)} are valid, ordered from oldest (index 0) to newest.
     */
    private static final class QNInfo {
        private final double[][] S;
        private final double[][] Y;
        private final double[] rho;
        /** Capacity of the history. */
        private final int m;
        /** Length of every s / y vector. */
        private final int dimension;
        /** Number of valid pairs currently stored; always at most {@code m}. */
        private int kCounter;

        QNInfo(int m, int dimension) {
            this.m = m;
            this.dimension = dimension;
            this.kCounter = 0;
            this.S = new double[m][];
            this.Y = new double[m][];
            this.rho = new double[m];
            Arrays.fill(this.rho, Double.NaN);
        }

        /** Records the correction pair produced by the line-search step in {@code lsr}. */
        public void updateInfo(LineSearchResult lsr) {
            double[] next = lsr.getNextPoint();
            double[] curr = lsr.getCurrPoint();
            double[] gradNext = lsr.getGradAtNext();
            double[] gradCurr = lsr.getGradAtCurr();
            double[] s = new double[this.dimension];
            double[] y = new double[this.dimension];
            for (int i = 0; i < this.dimension; ++i) {
                s[i] = next[i] - curr[i];
                y[i] = gradNext[i] - gradCurr[i];
            }
            this.updateSYRoh(s, y);
        }

        /**
         * Appends the pair {@code (s, y)}, evicting the oldest pair when the history is full.
         *
         * <p>The pair is skipped when the curvature condition {@code y . s > 0} fails, or when
         * {@code 1 / (y . s)} overflows. Storing such a pair would put an infinite or negative
         * {@code rho} into the recursion. That yields NaN directions (e.g. {@code Infinity * 0})
         * or non-descent directions. {@code kCounter} only advances when a pair is stored.
         */
        private void updateSYRoh(double[] s, double[] y) {
            if (this.m <= 0) {
                return;
            }
            double ys = ArrayMath.innerProduct(y, s);
            // The negated comparison also rejects NaN.
            if (!(ys > 0.0) || Double.isInfinite(1.0 / ys)) {
                return;
            }
            double r = 1.0 / ys;
            // s and y are freshly allocated by the caller and not shared, so they are stored
            // without copying.
            if (this.kCounter < this.m) {
                this.S[this.kCounter] = s;
                this.Y[this.kCounter] = y;
                this.rho[this.kCounter] = r;
                this.kCounter++;
            } else {
                System.arraycopy(this.S, 1, this.S, 0, this.m - 1);
                System.arraycopy(this.Y, 1, this.Y, 0, this.m - 1);
                System.arraycopy(this.rho, 1, this.rho, 0, this.m - 1);
                this.S[this.m - 1] = s;
                this.Y[this.m - 1] = y;
                this.rho[this.m - 1] = r;
            }
        }

        public double getRho(int n) {
            return this.rho[n];
        }

        public double[] getS(int n) {
            return this.S[n];
        }

        public double[] getY(int n) {
            return this.Y[n];
        }
    }
}

