/*
 * Decompiled with CFR 0.152.
 */
package es.uam.eps.ir.ranksys.mf.als;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.EigenvalueDecomposition;
import es.uam.eps.ir.ranksys.fast.preference.FastPreferenceData;
import es.uam.eps.ir.ranksys.fast.preference.IdxPref;
import es.uam.eps.ir.ranksys.fast.preference.TransposedPreferenceData;
import es.uam.eps.ir.ranksys.mf.als.ALSFactorizer;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Stream;

/**
 * Alternating-least-squares factorizer that weights every preference with a
 * caller-supplied confidence function and solves each row update with a
 * coordinate-wise ridge regression (see {@code doRR1}).
 *
 * <p>NOTE(review): the class name presumably refers to the fast ALS method of
 * Pil&aacute;szy, Zibriczky and Tikk (RecSys 2010) &mdash; confirm against the
 * project documentation.
 *
 * <p>Known numerical hazards, visible in {@code prepareRR1}: targets are
 * computed as {@code c * r / (c - 1)}, so a confidence function that returns
 * exactly {@code 1.0} for an observed preference causes a division by zero
 * (Inf/NaN in the factors).
 *
 * @param <U> type of the users
 * @param <I> type of the items
 */
public class PZTFactorizer<U, I> extends ALSFactorizer<U, I> {

    /** Regularization weight used when updating the user matrix. */
    private final double lambdaP;
    /** Regularization weight used when updating the item matrix. */
    private final double lambdaQ;
    /** Maps a preference value to its confidence weight; also evaluated at 0 for unobserved pairs. */
    private final DoubleUnaryOperator confidence;

    /**
     * Creates a factorizer that uses the same regularization weight for both
     * the user and the item matrices.
     *
     * @param lambda     regularization weight for both matrices
     * @param confidence confidence function applied to preference values
     * @param numIter    number of iterations, forwarded to {@link ALSFactorizer}
     */
    public PZTFactorizer(double lambda, DoubleUnaryOperator confidence, int numIter) {
        this(lambda, lambda, confidence, numIter);
    }

    /**
     * Creates a factorizer with separate regularization weights.
     *
     * @param lambdaP    regularization weight for the user matrix
     * @param lambdaQ    regularization weight for the item matrix
     * @param confidence confidence function applied to preference values
     * @param numIter    number of iterations, forwarded to {@link ALSFactorizer}
     */
    public PZTFactorizer(double lambdaP, double lambdaQ, DoubleUnaryOperator confidence, int numIter) {
        super(numIter);
        this.lambdaP = lambdaP;
        this.lambdaQ = lambdaQ;
        this.confidence = confidence;
    }

    /**
     * Computes the confidence-weighted squared error of the factorization.
     *
     * <p>For each user with preferences the error is
     * {@code sum_observed c(r)*(r - s)^2 + c(0) * sum_unobserved s^2}, divided by
     * {@code data.numItems()}; the per-user values are summed and divided by
     * {@code data.numUsers()}. No regularization term is included.
     *
     * @param p    user factor matrix (one row per user index)
     * @param q    item factor matrix (one row per item index)
     * @param data preference data
     * @return the averaged weighted squared error
     */
    @Override
    public double error(DenseDoubleMatrix2D p, DenseDoubleMatrix2D q, FastPreferenceData<U, I> data) {
        // The confidence of a zero preference is the same for every pair: evaluate it once.
        final double c0 = confidence.applyAsDouble(0.0);

        double total = data.getUidxWithPreferences().parallel().mapToDouble(uidx -> {
            DoubleMatrix1D pu = p.viewRow(uidx);
            // Predicted scores of this user for every item.
            DoubleMatrix1D su = q.zMult(pu, null);

            // Observed items: weighted squared error, minus the c0 * s^2 term that the
            // all-items sum below adds for the same items.
            double observed = data.getUidxPreferences(uidx).mapToDouble(iv -> {
                double rui = iv.v2;
                double sui = su.getQuick(iv.v1);
                double cui = confidence.applyAsDouble(rui);
                double diff = rui - sui;
                return cui * diff * diff - c0 * sui * sui;
            }).sum();

            // All items treated as zero preferences. assign() squares su in place, so this
            // must run after the observed-items pass has finished reading su.
            double allItems = c0 * su.assign(x -> x * x).zSum();

            return (observed + allItems) / (double) data.numItems();
        }).sum();

        return total / (double) data.numUsers();
    }

    /**
     * Updates the user matrix {@code p} in place while keeping {@code q} fixed.
     *
     * @param p    user factor matrix, modified in place
     * @param q    item factor matrix, read only
     * @param data preference data
     */
    @Override
    public void set_minP(DenseDoubleMatrix2D p, DenseDoubleMatrix2D q, FastPreferenceData<U, I> data) {
        set_min(p, q, confidence, lambdaP, data);
    }

    /**
     * Updates the item matrix {@code q} in place while keeping {@code p} fixed,
     * by running the same row update on the transposed preference data.
     *
     * @param q    item factor matrix, modified in place
     * @param p    user factor matrix, read only
     * @param data preference data (user-major; transposed internally)
     */
    @Override
    public void set_minQ(DenseDoubleMatrix2D q, DenseDoubleMatrix2D p, FastPreferenceData<U, I> data) {
        set_min(q, p, confidence, lambdaQ, new TransposedPreferenceData<>(data));
    }

    /**
     * Updates, in parallel, every row of {@code p} whose index has preferences
     * in {@code data}; rows without preferences are left untouched.
     *
     * @param p          matrix whose rows are updated in place
     * @param q          fixed matrix of the other side
     * @param confidence confidence function
     * @param lambda     regularization weight
     * @param data       preference data whose "user" side corresponds to the rows of {@code p}
     */
    private static <U, I> void set_min(DenseDoubleMatrix2D p, DenseDoubleMatrix2D q, DoubleUnaryOperator confidence, double lambda, FastPreferenceData<U, I> data) {
        // Shared by all rows: depends only on q and lambda.
        DoubleMatrix2D gt = getGt(p, q, lambda);

        data.getUidxWithPreferences().parallel().forEach(uidx ->
                prepareRR1(1, p.viewRow(uidx), gt, q, data.numItems(uidx), data.getUidxPreferences(uidx), confidence, lambda));
    }

    /**
     * Builds {@code V * diag(sqrt(d))} from the eigendecomposition of
     * {@code A1 = q^T q + lambda * I}.
     *
     * <p>{@code prepareRR1} uses the columns of the result as K synthetic
     * examples; for the symmetric {@code A1} their Gram matrix equals
     * {@code A1}, so they stand in for the contribution of all items.
     *
     * @param p      only used for its column count (the number of factors K)
     * @param q      fixed factor matrix
     * @param lambda value added to the diagonal of {@code q^T q}
     * @return the K x K matrix whose k-th column is the k-th eigenvector scaled
     *         by the square root of its eigenvalue
     */
    private static DoubleMatrix2D getGt(DenseDoubleMatrix2D p, DenseDoubleMatrix2D q, double lambda) {
        final int K = p.columns();

        // A1 = q^T * q
        DenseDoubleMatrix2D A1 = new DenseDoubleMatrix2D(K, K);
        q.zMult(q, A1, 1.0, 0.0, true, false);
        // A1 += lambda * I
        for (int k = 0; k < K; k++) {
            A1.setQuick(k, k, lambda + A1.getQuick(k, k));
        }

        EigenvalueDecomposition eig = new EigenvalueDecomposition(A1);
        DoubleMatrix1D d = eig.getRealEigenvalues();
        DoubleMatrix2D gt = eig.getV();
        for (int k = 0; k < K; k++) {
            final double scale = Math.sqrt(d.get(k));
            gt.viewColumn(k).assign(v -> scale * v);
        }

        return gt;
    }

    /**
     * Assembles the weighted regression problem for one row {@code w} and
     * solves it with {@code doRR1}.
     *
     * <p>The design matrix has {@code K + N} rows: first the K columns of
     * {@code gt} (target 0, weight 1), then one row per preference holding the
     * corresponding row of {@code q}, with weight {@code c - 1} and target
     * {@code c * r / (c - 1)}, where {@code c = confidence(r)}.
     *
     * <p>Hazard: if {@code confidence(r)} is exactly {@code 1.0} the target is
     * a division by zero.
     *
     * @param L          number of coordinate-descent passes
     * @param w          row of factors to update in place; its size is K
     * @param gt         matrix produced by {@code getGt}
     * @param q          fixed factor matrix
     * @param N          number of preferences in {@code prefs}; the stream must not yield more
     * @param prefs      preferences of the row being updated
     * @param confidence confidence function
     * @param lambda     regularization weight forwarded to {@code doRR1}
     */
    private static void prepareRR1(int L, DoubleMatrix1D w, DoubleMatrix2D gt, DoubleMatrix2D q, int N, Stream<? extends IdxPref> prefs, DoubleUnaryOperator confidence, double lambda) {
        final int K = w.size();
        final int rows = K + N;

        double[][] x = new double[rows][K];
        double[] y = new double[rows]; // synthetic rows keep the default target 0.0
        double[] c = new double[rows];

        // Synthetic examples standing in for "every item with preference 0".
        for (int k = 0; k < K; k++) {
            gt.viewColumn(k).toArray(x[k]);
            c[k] = 1.0;
        }

        // One example per observed preference. forEachOrdered (rather than forEach)
        // guarantees each action happens-before the next, which the shared row
        // cursor relies on even if the caller hands over a parallel stream.
        int[] next = {K};
        prefs.forEachOrdered(iv -> {
            int row = next[0]++;
            q.viewRow(iv.v1).toArray(x[row]);
            double cui = confidence.applyAsDouble(iv.v2);
            y[row] = cui * iv.v2 / (cui - 1.0);
            c[row] = cui - 1.0;
        });

        doRR1(L, w, x, y, c, lambda);
    }

    /**
     * Weighted ridge regression solved one coefficient at a time.
     *
     * <p>Starting from the current {@code w}, performs {@code L} passes over
     * the coefficients; each coefficient is replaced by
     * {@code sum_i c_i x_ik e_i / (lambda + sum_i c_i x_ik^2)}, where
     * {@code e} is the residual with coefficient k's contribution removed.
     *
     * <p>NOTE(review): when called through {@code prepareRR1}, {@code lambda}
     * is also present on the diagonal built in {@code getGt}; confirm that
     * applying it in both places is intended.
     *
     * @param L      number of passes over all coefficients
     * @param w      coefficients, read as the starting point and updated in place
     * @param x      examples, one row each; must contain at least one row
     * @param y      targets, one per example
     * @param c      weights, one per example
     * @param lambda regularization weight added to each denominator
     */
    private static void doRR1(int L, DoubleMatrix1D w, double[][] x, double[] y, double[] c, double lambda) {
        final int N = x.length;
        final int K = x[0].length;

        // Residuals for the current coefficients.
        double[] e = new double[N];
        for (int i = 0; i < N; i++) {
            double pred = 0.0;
            for (int k = 0; k < K; k++) {
                pred += w.getQuick(k) * x[i][k];
            }
            e[i] = y[i] - pred;
        }

        for (int l = 0; l < L; l++) {
            for (int k = 0; k < K; k++) {
                // Take coefficient k's contribution out of the residuals.
                final double oldWk = w.getQuick(k);
                for (int i = 0; i < N; i++) {
                    e[i] += oldWk * x[i][k];
                }

                // Weighted one-dimensional least squares for coefficient k.
                double a = 0.0;
                double d = 0.0;
                for (int i = 0; i < N; i++) {
                    a += c[i] * x[i][k] * x[i][k];
                    d += c[i] * x[i][k] * e[i];
                }
                final double newWk = d / (lambda + a);
                w.setQuick(k, newWk);

                // Put the updated contribution back into the residuals.
                for (int i = 0; i < N; i++) {
                    e[i] -= newWk * x[i][k];
                }
            }
        }
    }
}

