/*
 * Decompiled with CFR 0.152.
 */
package ciir.umass.edu.learning;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.KeyValuePair;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.Arrays;
import java.util.List;

/**
 * Ranker that fits a linear model to the relevance labels by regularized least squares.
 *
 * <p>Weight layout (as used by {@link #eval(DataPoint)} and {@link #loadFromString(String)}):
 * {@code weight[i]} is the coefficient paired with {@code features[i]}, and the LAST element,
 * {@code weight[weight.length - 1]}, is the intercept. In the textual model the intercept is
 * stored under key {@code 0}.
 */
public class LinearRegRank extends Ranker {
    /** Value added to every diagonal entry of X^T X before solving (skipped when exactly 0). */
    public static double lambda = 1.0E-10;

    /** Learned coefficients; the last element is the intercept. {@code null} until learned/loaded. */
    protected double[] weight = null;

    public LinearRegRank() {
    }

    public LinearRegRank(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    /** Nothing to prepare for this ranker; only reports progress. */
    @Override
    public void init() {
        this.PRINTLN("Initializing... [Done]");
    }

    /**
     * Builds the normal equations (X^T X) w = X^T y over all training data points, adds
     * {@link #lambda} to the diagonal, solves for {@code w} and reports the training
     * (and, if available, validation) score.
     *
     * <p>The design matrix has {@code nVar} columns: feature ids {@code 1..nVar-1} followed by a
     * constant {@code 1} column for the intercept, where {@code nVar} is the largest
     * {@code getNumberOfKnownFeatures()} seen in the training data.
     *
     * <p>NOTE(review): the fitted columns are always feature ids {@code 1..nVar-1}; this method
     * does not consult {@code this.features}, while {@link #eval(DataPoint)} pairs
     * {@code weight[i]} with {@code features[i]}. Confirm that callers always pass
     * {@code features = 1..nVar-1}, and that leaving feature id {@code nVar} out of the fit is
     * intended.
     */
    @Override
    public void learn() {
        this.PRINTLN("--------------------------------");
        this.PRINTLN("Training starts...");
        this.PRINTLN("--------------------------------");
        this.PRINT("Learning the least square model... ");

        // Number of unknowns = largest known-feature count over all training points.
        int nVar = 0;
        for (RankList rl : this.samples) {
            for (int i = 0; i < rl.size(); ++i) {
                nVar = Math.max(nVar, rl.get(i).getNumberOfKnownFeatures());
            }
        }

        // Freshly allocated double arrays are already zero-filled.
        double[][] xTx = new double[nVar][nVar];
        double[] xTy = new double[nVar];
        int bias = nVar - 1; // index of the constant (intercept) column

        for (RankList rl : this.samples) {
            for (int i = 0; i < rl.size(); ++i) {
                DataPoint dp = rl.get(i);

                // Intercept column: x = 1, so it contributes the bare label.
                xTy[bias] += (double)dp.getLabel();

                for (int j = 0; j < bias; ++j) {
                    // The product is deliberately formed before widening, as in the original.
                    xTy[j] += (double)(dp.getFeatureValue(j + 1) * dp.getLabel());
                    for (int k = 0; k < nVar; ++k) {
                        double t = k < bias ? (double)dp.getFeatureValue(k + 1) : 1.0;
                        xTx[j][k] += (double)dp.getFeatureValue(j + 1) * t;
                    }
                }

                // Last row of X^T X: (constant column) x (every column).
                for (int k = 0; k < bias; ++k) {
                    xTx[bias][k] += (double)dp.getFeatureValue(k + 1);
                }
                xTx[bias][bias] += 1.0;
            }
        }

        // Regularization: add lambda to the diagonal.
        if (lambda != 0.0) {
            for (int i = 0; i < xTx.length; ++i) {
                xTx[i][i] += lambda;
            }
        }

        this.weight = this.solve(xTx, xTy);
        this.PRINTLN("[Done]");
        this.scoreOnTrainingData = SimpleMath.round(this.scorer.score(this.rank(this.samples)), 4);
        this.PRINTLN("---------------------------------");
        this.PRINTLN("Finished sucessfully.");
        this.PRINTLN(this.scorer.name() + " on training data: " + this.scoreOnTrainingData);
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(this.rank(this.validationSamples));
            this.PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        this.PRINTLN("---------------------------------");
    }

    /**
     * Scores a data point as {@code intercept + sum_i weight[i] * p.getFeatureValue(features[i])}.
     *
     * @param p the data point to score
     * @return the linear model's output for {@code p}
     */
    @Override
    public double eval(DataPoint p) {
        double score = this.weight[this.weight.length - 1];
        for (int i = 0; i < this.features.length; ++i) {
            score += this.weight[i] * (double)p.getFeatureValue(this.features[i]);
        }
        return score;
    }

    @Override
    public Ranker createNew() {
        return new LinearRegRank();
    }

    /**
     * Serializes the weights as space-separated {@code key:value} pairs: first {@code 0:<intercept>},
     * then {@code <featureId>:<weight>} for each entry of {@code features}.
     *
     * <p>The intercept is read from {@code weight[weight.length - 1]}, the same slot that
     * {@link #eval(DataPoint)} uses and that {@link #loadFromString(String)} fills from key
     * {@code 0}, so that a saved model reloads with the same intercept.
     */
    @Override
    public String toString() {
        StringBuilder output = new StringBuilder();
        output.append("0:").append(this.weight[this.weight.length - 1]);
        for (int i = 0; i < this.features.length; ++i) {
            output.append(' ').append(this.features[i]).append(':').append(this.weight[i]);
        }
        return output.toString();
    }

    /**
     * Returns the textual model: two {@code ##} header lines (ranker name and lambda) followed by
     * the weight line produced by {@link #toString()}.
     */
    @Override
    public String model() {
        String output = "## " + this.name() + "\n";
        output = output + "## Lambda = " + lambda + "\n";
        output = output + this.toString();
        return output;
    }

    /**
     * Restores {@code features} and {@code weight} from a textual model.
     *
     * <p>The first line that is neither blank nor starts with {@code ##} is parsed as
     * {@code key:value} pairs. Pairs with a key greater than zero become feature/weight entries
     * in the order they appear; any other key sets the intercept (last element of {@code weight}).
     *
     * @param fullText the complete model text
     * @throws RuntimeException as created by {@code RankLibError.create} when the text contains
     *         no weight line or cannot be parsed
     */
    @Override
    public void loadFromString(String fullText) {
        try {
            BufferedReader in = new BufferedReader(new StringReader(fullText));
            KeyValuePair kvp = null;
            String content;
            while ((content = in.readLine()) != null) {
                content = content.trim();
                if (content.length() == 0 || content.indexOf("##") == 0) {
                    continue; // skip blank lines and "##" header lines
                }
                kvp = new KeyValuePair(content);
                break;
            }
            in.close();
            // Explicit check instead of `assert`, which is a no-op unless the JVM runs with -ea.
            if (kvp == null) {
                throw new IllegalStateException("No weight line found in the model text.");
            }
            List<String> keys = kvp.keys();
            List<String> values = kvp.values();
            this.weight = new double[keys.size()];
            this.features = new int[keys.size() - 1];
            int idx = 0;
            for (int i = 0; i < keys.size(); ++i) {
                int fid = Integer.parseInt(keys.get(i));
                double value = Double.parseDouble(values.get(i));
                if (fid > 0) {
                    this.features[idx] = fid;
                    this.weight[idx] = value;
                    ++idx;
                } else {
                    this.weight[this.weight.length - 1] = value;
                }
            }
        }
        catch (Exception ex) {
            throw RankLibError.create("Error in LinearRegRank::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        this.PRINTLN("L2-norm regularization: lambda = " + lambda);
    }

    @Override
    public String name() {
        return "Linear Regression";
    }

    /**
     * Solves {@code A x = B} by forward elimination followed by back-substitution, working on
     * copies so the arguments are left untouched.
     *
     * <p>No row exchanges are performed and pivots are not checked, so a zero pivot yields
     * non-finite values in the result rather than an error. On malformed input (empty arrays,
     * {@code A.length != B.length}, rows of differing length) a message is printed to standard
     * output and the JVM is terminated via {@code System.exit(1)}.
     *
     * @param A coefficient matrix
     * @param B right-hand side
     * @return the solution vector, of length {@code B.length}
     */
    protected double[] solve(double[][] A, double[] B) {
        if (A.length == 0 || B.length == 0) {
            System.out.println("Error: some of the input arrays is empty.");
            System.exit(1);
        }
        if (A[0].length == 0) {
            System.out.println("Error: some of the input arrays is empty.");
            System.exit(1);
        }
        if (A.length != B.length) {
            System.out.println("Error: Solving Ax=B: A and B have different dimension.");
            System.exit(1);
        }

        // Work on private copies of the inputs.
        double[] b = B.clone();
        double[][] a = new double[A.length][];
        for (int i = 0; i < a.length; ++i) {
            a[i] = A[i].clone();
            if (i > 0 && a[i].length != a[i - 1].length) {
                System.out.println("Error: Solving Ax=B: A is NOT a square matrix.");
                System.exit(1);
            }
        }

        int n = b.length;

        // Forward elimination: zero out the entries below each diagonal element.
        for (int j = 0; j < n - 1; ++j) {
            double pivot = a[j][j];
            for (int i = j + 1; i < n; ++i) {
                double multiplier = a[i][j] / pivot;
                for (int k = j + 1; k < n; ++k) {
                    a[i][k] -= a[j][k] * multiplier;
                }
                b[i] -= b[j] * multiplier;
            }
        }

        // Back-substitution, from the last unknown upwards.
        double[] x = new double[n];
        x[n - 1] = b[n - 1] / a[n - 1][n - 1];
        for (int i = n - 2; i >= 0; --i) {
            double val = b[i];
            for (int j = i + 1; j < n; ++j) {
                val -= a[i][j] * x[j];
            }
            x[i] = val / a[i][i];
        }
        return x;
    }
}

