/*
 * Decompiled with CFR 0.152.
 */
package es.uam.eps.ir.ranksys.mf.plsa;

import cern.colt.function.DoubleFunction;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.jet.math.Functions;
import es.uam.eps.ir.ranksys.fast.preference.FastPreferenceData;
import es.uam.eps.ir.ranksys.fast.preference.IdxPref;
import es.uam.eps.ir.ranksys.mf.Factorization;
import es.uam.eps.ir.ranksys.mf.Factorizer;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.ranksys.fast.preference.StreamsAbstractFastPreferenceData;

public class PLSAFactorizer<U, I>
extends Factorizer<U, I> {
    private static final Logger LOG = Logger.getLogger(PLSAFactorizer.class.getName());
    private final int numIter;

    public PLSAFactorizer(int numIter) {
        this.numIter = numIter;
    }

    @Override
    public double error(Factorization<U, I> factorization, FastPreferenceData<U, I> data) {
        DenseDoubleMatrix2D pu_z = factorization.getUserMatrix();
        DenseDoubleMatrix2D piz = factorization.getItemMatrix();
        return data.getUidxWithPreferences().parallel().mapToDouble(uidx -> {
            DoubleMatrix1D pU_z = pu_z.viewRow(uidx);
            DoubleMatrix1D pUi = piz.zMult(pU_z, null);
            return data.getUidxPreferences(uidx).mapToDouble(iv -> -iv.v2 * pUi.getQuick(iv.v1)).sum();
        }).sum();
    }

    @Override
    public Factorization<U, I> factorize(int K, FastPreferenceData<U, I> data) {
        DoubleFunction init = x -> Math.sqrt(1.0 / (double)K) * Math.random();
        Factorization factorization = new Factorization(data, data, K, init);
        this.factorize(factorization, data);
        return factorization;
    }

    @Override
    public void factorize(Factorization<U, I> factorization, FastPreferenceData<U, I> data) {
        DenseDoubleMatrix2D pu_z = factorization.getUserMatrix();
        DenseDoubleMatrix2D piz = factorization.getItemMatrix();
        IntOpenHashSet uidxs = new IntOpenHashSet(data.getUidxWithPreferences().toArray());
        IntStream.range(0, pu_z.rows()).filter(arg_0 -> PLSAFactorizer.lambda$factorize$17((IntSet)uidxs, arg_0)).forEach(uidx -> pu_z.viewRow(uidx).assign(0.0));
        IntOpenHashSet iidxs = new IntOpenHashSet(data.getIidxWithPreferences().toArray());
        IntStream.range(0, piz.rows()).filter(arg_0 -> PLSAFactorizer.lambda$factorize$19((IntSet)iidxs, arg_0)).forEach(iidx -> piz.viewRow(iidx).assign(0.0));
        PLSAPreferenceData<U, I> plsaData = new PLSAPreferenceData<U, I>(data, pu_z.columns());
        for (int z = 0; z < pu_z.columns(); ++z) {
            DoubleMatrix1D pu_Z = pu_z.viewColumn(z);
            pu_Z.assign(Functions.mult((double)(1.0 / pu_Z.aggregate(Functions.plus, Functions.identity))));
        }
        piz.assign(Functions.mult((double)(1.0 / piz.aggregate(Functions.plus, Functions.identity))));
        int t = 1;
        while (t <= this.numIter) {
            long time0 = System.nanoTime();
            this.expectation(pu_z, piz, plsaData);
            this.maximization(pu_z, piz, plsaData);
            int iter = t++;
            long time1 = System.nanoTime() - time0;
            LOG.log(Level.INFO, String.format("iteration n = %3d t = %.2fs", iter, (double)time1 / 1.0E9));
            LOG.log(Level.FINE, () -> String.format("iteration n = %3d e = %.6f", iter, this.error(factorization, data)));
        }
    }

    private void expectation(DenseDoubleMatrix2D pz_u, DenseDoubleMatrix2D piz, PLSAPreferenceData<U, I> qzData) {
        qzData.getUidxWithPreferences().parallel().forEach(uidx -> qzData.getUidxPreferences(uidx).forEach(iqz -> {
            int iidx = iqz.v1;
            double[] qz = ((PLSAPreferenceData.PLSAIdxPref)iqz).qz;
            for (int z = 0; z < qz.length; ++z) {
                qz[z] = piz.getQuick(iidx, z) * pz_u.getQuick(uidx, z);
            }
            PLSAFactorizer.normalizeQz(qz);
        }));
    }

    private void maximization(DenseDoubleMatrix2D pu_z, DenseDoubleMatrix2D piz, PLSAPreferenceData<U, I> qzData) {
        Int2ObjectOpenHashMap lockMap = new Int2ObjectOpenHashMap();
        qzData.getIidxWithPreferences().forEach(arg_0 -> PLSAFactorizer.lambda$maximization$24((Int2ObjectMap)lockMap, arg_0));
        pu_z.assign(0.0);
        piz.assign(0.0);
        qzData.getUidxWithPreferences().parallel().forEach(arg_0 -> PLSAFactorizer.lambda$maximization$26(pu_z, qzData, (Int2ObjectMap)lockMap, piz, arg_0));
        for (int z = 0; z < pu_z.columns(); ++z) {
            DoubleMatrix1D pZ_u = pu_z.viewColumn(z);
            pZ_u.assign(Functions.mult((double)(1.0 / pZ_u.aggregate(Functions.plus, Functions.identity))));
        }
        piz.assign(Functions.mult((double)(1.0 / piz.aggregate(Functions.plus, Functions.identity))));
    }

    private static void normalizeQz(double[] qz) {
        int i;
        double norm = 0.0;
        for (i = 0; i < qz.length; ++i) {
            norm += qz[i];
        }
        i = 0;
        while (i < qz.length) {
            int n = i++;
            qz[n] = qz[n] / norm;
        }
    }

    private static /* synthetic */ void lambda$maximization$26(DenseDoubleMatrix2D pu_z, PLSAPreferenceData qzData, Int2ObjectMap lockMap, DenseDoubleMatrix2D piz, int uidx) {
        DoubleMatrix1D pz_U = pu_z.viewRow(uidx);
        qzData.getUidxPreferences(uidx).forEach(iqz -> {
            double r;
            int z;
            int iidx = iqz.v1;
            double v = iqz.v2;
            double[] qz = ((PLSAPreferenceData.PLSAIdxPref)iqz).qz;
            Lock lock = (Lock)lockMap.get(iidx);
            for (z = 0; z < qz.length; ++z) {
                r = qz[z] * v;
                pz_U.setQuick(z, pz_U.getQuick(z) + r);
            }
            lock.lock();
            try {
                for (z = 0; z < qz.length; ++z) {
                    r = qz[z] * v;
                    piz.setQuick(iidx, z, piz.getQuick(iidx, z) + r);
                }
            }
            finally {
                lock.unlock();
            }
        });
    }

    private static /* synthetic */ void lambda$maximization$24(Int2ObjectMap lockMap, int iidx) {
        Lock cfr_ignored_0 = (Lock)lockMap.put(iidx, (Object)new ReentrantLock());
    }

    private static /* synthetic */ boolean lambda$factorize$19(IntSet iidxs, int iidx) {
        return !iidxs.contains(iidx);
    }

    private static /* synthetic */ boolean lambda$factorize$17(IntSet uidxs, int uidx) {
        return !uidxs.contains(uidx);
    }

    private static class PLSAPreferenceData<U, I>
    extends StreamsAbstractFastPreferenceData<U, I> {
        private final FastPreferenceData<U, I> data;
        private final Long2ObjectOpenHashMap<double[]> qz;

        public PLSAPreferenceData(FastPreferenceData<U, I> data, int K) {
            super(data, data);
            this.data = data;
            this.qz = new Long2ObjectOpenHashMap();
            data.getUidxWithPreferences().forEach(uidx -> data.getUidxPreferences(uidx).forEach(pref -> this.putQz(uidx, pref.v1, new double[K])));
        }

        private double[] getQz(int uidx, int iidx) {
            return (double[])this.qz.get((long)(uidx * this.data.numItems() + iidx));
        }

        private double[] putQz(int uidx, int iidx, double[] v) {
            return (double[])this.qz.put((long)(uidx * this.data.numItems() + iidx), (Object)v);
        }

        public int numUsers(int iidx) {
            return this.data.numUsers(iidx);
        }

        public int numItems(int uidx) {
            return this.data.numItems(uidx);
        }

        public IntStream getUidxWithPreferences() {
            return this.data.getUidxWithPreferences();
        }

        public IntStream getIidxWithPreferences() {
            return this.data.getIidxWithPreferences();
        }

        public Stream<IdxPref> getUidxPreferences(int uidx) {
            return this.data.getUidxPreferences(uidx).map(pref -> new PLSAIdxPref(pref.v1, pref.v2, this.getQz(uidx, pref.v1)));
        }

        public Stream<IdxPref> getIidxPreferences(int iidx) {
            return this.data.getIidxPreferences(iidx).map(pref -> new PLSAIdxPref(pref.v1, pref.v2, this.getQz(pref.v1, iidx)));
        }

        public int numPreferences() {
            return this.data.numPreferences();
        }

        public class PLSAIdxPref
        extends IdxPref {
            public double[] qz;

            public PLSAIdxPref(int idx, double value, double[] qz) {
                super(idx, value);
                this.qz = qz;
            }
        }
    }
}

