/*
 * Decompiled with CFR 0.152.
 */
package es.uam.eps.ir.ranksys.mf.als;

import cern.colt.function.DoubleFunction;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import es.uam.eps.ir.ranksys.fast.preference.FastPreferenceData;
import es.uam.eps.ir.ranksys.mf.Factorization;
import es.uam.eps.ir.ranksys.mf.Factorizer;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;

public abstract class ALSFactorizer<U, I>
extends Factorizer<U, I> {
    private static final Logger LOG = Logger.getLogger(ALSFactorizer.class.getName());
    private final int numIter;

    public ALSFactorizer(int numIter) {
        this.numIter = numIter;
    }

    @Override
    public double error(Factorization<U, I> factorization, FastPreferenceData<U, I> data) {
        DenseDoubleMatrix2D p = factorization.getUserMatrix();
        DenseDoubleMatrix2D q = factorization.getItemMatrix();
        return this.error(p, q, data);
    }

    @Override
    public Factorization<U, I> factorize(int K, FastPreferenceData<U, I> data) {
        DoubleFunction init = x -> Math.sqrt(1.0 / (double)K) * Math.random();
        Factorization factorization = new Factorization(data, data, K, init);
        this.factorize(factorization, data);
        return factorization;
    }

    @Override
    public void factorize(Factorization<U, I> factorization, FastPreferenceData<U, I> data) {
        DenseDoubleMatrix2D p = factorization.getUserMatrix();
        DenseDoubleMatrix2D q = factorization.getItemMatrix();
        IntOpenHashSet uidxs = new IntOpenHashSet(data.getUidxWithPreferences().toArray());
        IntStream.range(0, p.rows()).filter(arg_0 -> ALSFactorizer.lambda$factorize$1((IntSet)uidxs, arg_0)).forEach(uidx -> p.viewRow(uidx).assign(0.0));
        IntOpenHashSet iidxs = new IntOpenHashSet(data.getIidxWithPreferences().toArray());
        IntStream.range(0, q.rows()).filter(arg_0 -> ALSFactorizer.lambda$factorize$3((IntSet)iidxs, arg_0)).forEach(iidx -> q.viewRow(iidx).assign(0.0));
        int t = 1;
        while (t <= this.numIter) {
            long time0 = System.nanoTime();
            this.set_minQ(q, p, data);
            this.set_minP(p, q, data);
            int iter = t++;
            long time1 = System.nanoTime() - time0;
            LOG.log(Level.INFO, String.format("iteration n = %3d t = %.2fs", iter, (double)time1 / 1.0E9));
            LOG.log(Level.FINE, () -> String.format("iteration n = %3d e = %.6f", iter, this.error(factorization, data)));
        }
    }

    protected abstract double error(DenseDoubleMatrix2D var1, DenseDoubleMatrix2D var2, FastPreferenceData<U, I> var3);

    protected abstract void set_minP(DenseDoubleMatrix2D var1, DenseDoubleMatrix2D var2, FastPreferenceData<U, I> var3);

    protected abstract void set_minQ(DenseDoubleMatrix2D var1, DenseDoubleMatrix2D var2, FastPreferenceData<U, I> var3);

    private static /* synthetic */ boolean lambda$factorize$3(IntSet iidxs, int iidx) {
        return !iidxs.contains(iidx);
    }

    private static /* synthetic */ boolean lambda$factorize$1(IntSet uidxs, int uidx) {
        return !uidxs.contains(uidx);
    }
}

