/*
 * Decompiled with CFR 0.152.
 */
package es.uam.eps.ir.ranksys.mf.als;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.LUDecompositionQuick;
import es.uam.eps.ir.ranksys.fast.preference.FastPreferenceData;
import es.uam.eps.ir.ranksys.fast.preference.IdxPref;
import es.uam.eps.ir.ranksys.fast.preference.TransposedPreferenceData;
import es.uam.eps.ir.ranksys.mf.als.ALSFactorizer;
import java.util.function.DoubleUnaryOperator;

public class HKVFactorizer<U, I>
extends ALSFactorizer<U, I> {
    private static final Algebra ALG = new Algebra();
    private final double lambdaP;
    private final double lambdaQ;
    private final DoubleUnaryOperator confidence;

    public HKVFactorizer(double lambda, DoubleUnaryOperator confidence, int numIter) {
        this(lambda, lambda, confidence, numIter);
    }

    public HKVFactorizer(double lambdaP, double lambdaQ, DoubleUnaryOperator confidence, int numIter) {
        super(numIter);
        this.lambdaP = lambdaP;
        this.lambdaQ = lambdaQ;
        this.confidence = confidence;
    }

    @Override
    public double error(DenseDoubleMatrix2D p, DenseDoubleMatrix2D q, FastPreferenceData<U, I> data) {
        return data.getUidxWithPreferences().parallel().mapToDouble(uidx -> {
            DoubleMatrix1D pu = p.viewRow(uidx);
            DoubleMatrix1D su = q.zMult(pu, null);
            double err1 = data.getUidxPreferences(uidx).mapToDouble(iv -> {
                double rui = iv.v2;
                double sui = su.getQuick(iv.v1);
                double cui = this.confidence.applyAsDouble(rui);
                return cui * (rui - sui) * (rui - sui) - this.confidence.applyAsDouble(0.0) * sui * sui;
            }).sum();
            double err2 = this.confidence.applyAsDouble(0.0) * su.assign(x -> x * x).zSum();
            return (err1 + err2) / (double)data.numItems();
        }).sum() / (double)data.numUsers();
    }

    @Override
    public void set_minP(DenseDoubleMatrix2D p, DenseDoubleMatrix2D q, FastPreferenceData<U, I> data) {
        HKVFactorizer.set_min(p, q, this.confidence, this.lambdaP, data);
    }

    @Override
    public void set_minQ(DenseDoubleMatrix2D q, DenseDoubleMatrix2D p, FastPreferenceData<U, I> data) {
        HKVFactorizer.set_min(q, p, this.confidence, this.lambdaQ, new TransposedPreferenceData(data));
    }

    private static <U, I, O> void set_min(DenseDoubleMatrix2D p, DenseDoubleMatrix2D q, DoubleUnaryOperator confidence, double lambda, FastPreferenceData<U, I> data) {
        int K = p.columns();
        DenseDoubleMatrix2D A1P = new DenseDoubleMatrix2D(K, K);
        q.zMult((DoubleMatrix2D)q, (DoubleMatrix2D)A1P, 1.0, 0.0, true, false);
        for (int k = 0; k < K; ++k) {
            A1P.setQuick(k, k, lambda + A1P.getQuick(k, k));
        }
        DenseDoubleMatrix2D[] A2P = new DenseDoubleMatrix2D[q.rows()];
        data.getIidxWithPreferences().parallel().forEach(iidx -> {
            A2P[iidx] = new DenseDoubleMatrix2D(K, K);
            DoubleMatrix1D qi = q.viewRow(iidx);
            ALG.multOuter(qi, qi, (DoubleMatrix2D)A2P[iidx]);
        });
        data.getUidxWithPreferences().parallel().forEach(uidx -> {
            DenseDoubleMatrix2D A = new DenseDoubleMatrix2D(K, K);
            DenseDoubleMatrix1D b = new DenseDoubleMatrix1D(K);
            A.assign((DoubleMatrix2D)A1P);
            b.assign(0.0);
            data.getUidxPreferences(uidx).forEach(arg_0 -> HKVFactorizer.lambda$null$12(confidence, q, (DoubleMatrix2D)A, A2P, (DoubleMatrix1D)b, arg_0));
            LUDecompositionQuick lu = new LUDecompositionQuick(0.0);
            lu.decompose((DoubleMatrix2D)A);
            lu.solve((DoubleMatrix1D)b);
            p.viewRow(uidx).assign((DoubleMatrix1D)b);
        });
    }

    private static /* synthetic */ void lambda$null$12(DoubleUnaryOperator confidence, DenseDoubleMatrix2D q, DoubleMatrix2D A, DenseDoubleMatrix2D[] A2P, DoubleMatrix1D b, IdxPref iv) {
        int iidx = iv.v1;
        double rui = iv.v2;
        double cui = confidence.applyAsDouble(rui);
        DoubleMatrix1D qi = q.viewRow(iidx);
        A.assign((DoubleMatrix2D)A2P[iidx], (x, y) -> x + y * (cui - 1.0));
        b.assign(qi, (x, y) -> x + y * rui * cui);
    }
}

