/*
 * Decompiled with CFR 0.152.
 */
package opennlp.maxent.quasinewton;

import java.util.Arrays;
import opennlp.maxent.quasinewton.ArrayMath;
import opennlp.maxent.quasinewton.DifferentiableFunction;
import opennlp.maxent.quasinewton.LineSearch;
import opennlp.maxent.quasinewton.LineSearchResult;
import opennlp.maxent.quasinewton.LogLikelihoodFunction;
import opennlp.maxent.quasinewton.QNModel;
import opennlp.model.DataIndexer;

/**
 * Trains a {@link QNModel} with a limited-memory BFGS (L-BFGS) quasi-Newton method.
 *
 * <p>Each iteration builds a search direction from the current gradient and the
 * last {@code m} correction pairs (two-loop recursion). It then delegates step
 * selection to {@link LineSearch#doLineSearch} and records the new correction
 * pair. The direction is the negated quasi-Newton product, so the objective's
 * value is being minimised (presumably the negative log-likelihood — confirm
 * in {@code LogLikelihoodFunction}).
 *
 * <p>Instances keep per-run state ({@code dimension}, {@code updateInfo}) in
 * fields. One instance should therefore not train concurrently from several
 * threads.
 */
public class QNTrainer {
    /** Training stops once the absolute change in objective value is below this. */
    private static final double CONVERGE_TOLERANCE = 1.0E-10;
    /** Upper bound on the number of stored correction pairs. */
    private static final int MAX_M = 15;
    /** Default number of stored correction pairs. */
    public static final int DEFAULT_M = 7;
    /** Upper bound accepted for the function-evaluation budget. */
    public static final int MAX_FCT_EVAL = 3000;
    /** Budget used when a negative function-evaluation budget is supplied. */
    public static final int DEFAULT_MAX_FCT_EVAL = 300;

    /** Length of the parameter vector of the function being trained. */
    private int dimension;
    /** Number of correction pairs kept in the L-BFGS history. */
    private final int m;
    /** Training stops once the line search reports more evaluations than this. */
    private final int maxFctEval;
    /** Correction-pair history for the current training run. */
    private QNInfo updateInfo;
    /** Whether to print progress to {@code System.out}. */
    private final boolean verbose;

    /** Creates a verbose trainer with {@link #DEFAULT_M} corrections. */
    public QNTrainer() {
        this(true);
    }

    /**
     * Creates a trainer with {@link #DEFAULT_M} corrections.
     *
     * @param verbose whether to print progress to {@code System.out}
     */
    public QNTrainer(boolean verbose) {
        this(DEFAULT_M, verbose);
    }

    /**
     * Creates a verbose trainer.
     *
     * @param m number of correction pairs to keep (see
     *          {@link #QNTrainer(int, int, boolean)} for clamping)
     */
    public QNTrainer(int m) {
        this(m, true);
    }

    /**
     * Creates a trainer with the {@link #DEFAULT_MAX_FCT_EVAL} evaluation budget.
     *
     * @param m       number of correction pairs to keep
     * @param verbose whether to print progress to {@code System.out}
     */
    public QNTrainer(int m, boolean verbose) {
        this(m, DEFAULT_MAX_FCT_EVAL, verbose);
    }

    /**
     * Creates a trainer.
     *
     * @param m          number of correction pairs to keep. Values above 15 are
     *                   clamped to 15. Negative values select {@link #DEFAULT_M}.
     *                   Zero disables the history, so every step follows the
     *                   negated gradient.
     * @param maxFctEval function-evaluation budget. Values above
     *                   {@link #MAX_FCT_EVAL} are clamped. Negative values select
     *                   {@link #DEFAULT_MAX_FCT_EVAL}.
     * @param verbose    whether to print progress to {@code System.out}; also
     *                   forwarded to the line search
     */
    public QNTrainer(int m, int maxFctEval, boolean verbose) {
        this.verbose = verbose;
        // A negative history size would make the correction arrays
        // unallocatable; fall back to the default, as done for maxFctEval.
        this.m = m < 0 ? DEFAULT_M : Math.min(m, MAX_M);
        this.maxFctEval = maxFctEval < 0
                ? DEFAULT_MAX_FCT_EVAL
                : Math.min(maxFctEval, MAX_FCT_EVAL);
    }

    /**
     * Runs L-BFGS on the log-likelihood function built from {@code indexer}.
     *
     * @param indexer indexed training data
     * @return a model wrapping the objective function and the final point
     */
    public QNModel trainModel(DataIndexer indexer) {
        LogLikelihoodFunction objectiveFunction = this.generateFunction(indexer);
        this.dimension = objectiveFunction.getDomainDimension();
        this.updateInfo = new QNInfo(this.m, this.dimension);

        double[] initialPoint = objectiveFunction.getInitialPoint();
        double initialValue = objectiveFunction.valueAt(initialPoint);
        double[] initialGrad = objectiveFunction.gradientAt(initialPoint);
        LineSearchResult lsr = LineSearchResult.getInitialObject(initialValue, initialGrad, initialPoint, 0);

        int iteration = 0;
        do {
            if (this.verbose) {
                // Only the iteration number is printed here; any further
                // progress output comes from the line search.
                System.out.print(iteration++);
            }
            double[] direction = this.computeDirection(objectiveFunction, lsr);
            lsr = LineSearch.doLineSearch(objectiveFunction, direction, lsr, this.verbose);
            this.updateInfo.updateInfo(lsr);
        } while (!this.isConverged(lsr));
        return new QNModel(objectiveFunction, lsr.getNextPoint());
    }

    /** Builds the objective function for the given training data. */
    private LogLikelihoodFunction generateFunction(DataIndexer indexer) {
        return new LogLikelihoodFunction(indexer);
    }

    /**
     * Computes the search direction {@code -H * g} with the L-BFGS two-loop
     * recursion. {@code g} is the gradient at the line search's latest point, and
     * {@code H} is the implicit inverse-Hessian approximation built from the
     * stored pairs, applied to an identity initial matrix.
     *
     * @param monitor currently unused
     * @param lsr     latest line-search result; its gradient is not modified
     * @return a freshly allocated direction vector
     */
    private double[] computeDirection(DifferentiableFunction monitor, LineSearchResult lsr) {
        double[] direction = lsr.getGradAtNext().clone();
        int k = this.updateInfo.kCounter;
        double[] alpha = new double[k];

        // First loop: newest pair to oldest (the newest pair is at index k - 1).
        for (int i = k - 1; i >= 0; --i) {
            double[] s = this.updateInfo.getS(i);
            double[] y = this.updateInfo.getY(i);
            alpha[i] = this.updateInfo.getRho(i) * ArrayMath.innerProduct(s, direction);
            for (int j = 0; j < this.dimension; ++j) {
                direction[j] -= alpha[i] * y[j];
            }
        }

        // Second loop: oldest pair to newest.
        for (int i = 0; i < k; ++i) {
            double[] s = this.updateInfo.getS(i);
            double[] y = this.updateInfo.getY(i);
            double beta = this.updateInfo.getRho(i) * ArrayMath.innerProduct(y, direction);
            for (int j = 0; j < this.dimension; ++j) {
                direction[j] += (alpha[i] - beta) * s[j];
            }
        }

        // Negate H * g to obtain a descent direction.
        for (int j = 0; j < this.dimension; ++j) {
            direction[j] = -direction[j];
        }
        return direction;
    }

    /**
     * Stops when the absolute change in objective value between the line
     * search's current and next points is below {@link #CONVERGE_TOLERANCE}.
     * It also stops when the evaluation count reported by the line-search result
     * exceeds {@code maxFctEval}.
     */
    private boolean isConverged(LineSearchResult lsr) {
        double change = Math.abs(lsr.getValueAtNext() - lsr.getValueAtCurr());
        return change < CONVERGE_TOLERANCE || lsr.getFctEvalCount() > this.maxFctEval;
    }

    /**
     * Bounded history of L-BFGS correction pairs. Each pair is
     * {@code s = x_next - x_curr} and {@code y = g_next - g_curr}, stored with
     * {@code rho = 1 / (y . s)}. Index 0 holds the oldest pair and index
     * {@code kCounter - 1} the newest.
     */
    private static final class QNInfo {
        private final double[][] S;
        private final double[][] Y;
        private final double[] rho;
        /** Maximum number of stored pairs. */
        private final int m;
        /** Length of each stored vector. */
        private final int dimension;
        /** Number of pairs currently stored (at most {@code m}). */
        private int kCounter;

        QNInfo(int numCorrection, int dimension) {
            this.m = numCorrection;
            this.dimension = dimension;
            this.kCounter = 0;
            this.S = new double[this.m][];
            this.Y = new double[this.m][];
            this.rho = new double[this.m];
            Arrays.fill(this.rho, Double.NaN);
        }

        /** Derives a correction pair from {@code lsr} and records it. */
        public void updateInfo(LineSearchResult lsr) {
            double[] nextPoint = lsr.getNextPoint();
            double[] currPoint = lsr.getCurrPoint();
            double[] gradNext = lsr.getGradAtNext();
            double[] gradCurr = lsr.getGradAtCurr();
            double[] sK = new double[this.dimension];
            double[] yK = new double[this.dimension];
            for (int i = 0; i < this.dimension; ++i) {
                sK[i] = nextPoint[i] - currPoint[i];
                yK[i] = gradNext[i] - gradCurr[i];
            }
            this.updateSYRoh(sK, yK);
        }

        /**
         * Appends the pair and evicts the oldest one when the history is full.
         * Pairs that fail the curvature condition {@code y . s > 0} are skipped.
         * Storing them would make {@code rho} infinite, negative or NaN; the
         * implicit inverse Hessian would then no longer be positive definite, and
         * the next direction could be NaN or non-descending.
         * The caller passes freshly allocated arrays, so they are stored without copying.
         */
        private void updateSYRoh(double[] sK, double[] yK) {
            if (this.m == 0) {
                return;
            }
            double sy = ArrayMath.innerProduct(yK, sK);
            if (!(sy > 0.0)) { // also rejects NaN
                return;
            }
            double newRho = 1.0 / sy;
            if (this.kCounter < this.m) {
                this.S[this.kCounter] = sK;
                this.Y[this.kCounter] = yK;
                this.rho[this.kCounter] = newRho;
                this.kCounter++;
            } else {
                // Shift everything one slot towards index 0, dropping the oldest pair.
                System.arraycopy(this.S, 1, this.S, 0, this.m - 1);
                System.arraycopy(this.Y, 1, this.Y, 0, this.m - 1);
                System.arraycopy(this.rho, 1, this.rho, 0, this.m - 1);
                this.S[this.m - 1] = sK;
                this.Y[this.m - 1] = yK;
                this.rho[this.m - 1] = newRho;
            }
        }

        public double getRho(int updateIndex) {
            return this.rho[updateIndex];
        }

        public double[] getS(int updateIndex) {
            return this.S[updateIndex];
        }

        public double[] getY(int updateIndex) {
            return this.Y[updateIndex];
        }
    }
}

