/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.types;

import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.types.AbstractTableFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.CPT;
import cc.mallet.grmm.types.DiscreteFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.SkeletonFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Matrix;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Maths;
import gnu.trove.TDoubleArrayList;
import gnu.trove.TIntArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;

/**
 * Static helper routines for normalizing, comparing, combining and inspecting
 * {@link Factor}s.
 */
public class Factors {
    /**
     * Normalizes {@code ptl} <em>in place</em>, in log space, so that it can be
     * wrapped as a conditional table for {@code var}, and returns that wrapper.
     *
     * <p>For every assignment, the log value is reduced by the accumulated log
     * sum of all entries that share the same assignment to the remaining
     * variables (the assignment with {@code var} marginalized out).
     *
     * @param ptl the table factor to normalize; its log values are overwritten
     * @param var the variable to treat as the child of the resulting CPT
     * @return a new {@link CPT} built from the (now modified) {@code ptl} and {@code var}
     */
    public static CPT normalizeAsCpt(AbstractTableFactor ptl, Variable var) {
        // One accumulator per possible index of the assignment to the other
        // variables; sized by ptl.numLocations() exactly as before.
        double[] logSums = new double[ptl.numLocations()];
        Arrays.fill(logSums, Double.NEGATIVE_INFINITY);

        // Pass 1: accumulate the log-domain sum for each assignment to the other variables.
        AssignmentIterator it = ptl.assignmentIterator();
        while (it.hasNext()) {
            Assignment assn = it.assignment();
            Assignment nbrAssn = (Assignment)assn.marginalizeOut(var);
            int idx = nbrAssn.singleIndex();
            logSums[idx] = Maths.sumLogProb(ptl.logValue(assn), logSums[idx]);
            it.advance();
        }

        // Pass 2: subtract the matching normalizer from every entry.
        it = ptl.assignmentIterator();
        while (it.hasNext()) {
            Assignment assn = it.assignment();
            double oldVal = ptl.logValue(assn);
            Assignment nbrAssn = (Assignment)assn.marginalizeOut(var);
            double logZ = logSums[nbrAssn.singleIndex()];
            if (Double.isInfinite(oldVal) && Double.isInfinite(logZ)) {
                // Subtracting two infinities would give NaN; store -infinity instead.
                ptl.setLogValue(assn, Double.NEGATIVE_INFINITY);
            } else {
                ptl.setLogValue(assn, oldVal - logZ);
            }
            it.advance();
        }
        return new CPT(ptl, var);
    }

    /**
     * Mixes two factors by delegating to {@code TableFactor.hackyMixture}.
     *
     * @param ptl1 first factor; must be a {@link TableFactor}
     * @param ptl2 second factor; must be a {@link TableFactor}
     * @param weight mixing weight, passed through unchanged
     * @return the mixture produced by {@code TableFactor.hackyMixture}
     * @throws ClassCastException if either argument is not a {@link TableFactor}
     */
    public static Factor average(Factor ptl1, Factor ptl2, double weight) {
        TableFactor mptl1 = (TableFactor)ptl1;
        TableFactor mptl2 = (TableFactor)ptl2;
        return TableFactor.hackyMixture(mptl1, mptl2, weight);
    }

    /**
     * Sums the absolute differences between the values of two factors over
     * every assignment enumerated by {@code bel1}.
     *
     * @param bel1 first factor
     * @param bel2 second factor; must have a variable set equal to that of {@code bel1}
     * @return the accumulated absolute difference
     * @throws IllegalArgumentException if the two variable sets are not equal
     */
    public static double oneDistance(Factor bel1, Factor bel2) {
        VarSet vs1 = bel1.varSet();
        VarSet vs2 = bel2.varSet();
        if (!vs1.equals(vs2)) {
            throw new IllegalArgumentException("Attempt to take distance between mismatching potentials " + bel1 + " and " + bel2);
        }
        double dist = 0.0;
        AssignmentIterator it = bel1.assignmentIterator();
        while (it.hasNext()) {
            Assignment assn = it.assignment();
            dist += Math.abs(bel1.value(assn) - bel2.value(assn));
            it.advance();
        }
        return dist;
    }

    /**
     * Builds a sparse {@link TableFactor} that keeps only the top-ranked
     * entries of {@code ptl}: entries are added in rank order until the
     * accumulated log mass exceeds {@code log(alpha)}.
     *
     * <p>NOTE(review): the rank order comes from {@link RankedFeatureVector};
     * presumably highest log value first -- confirm against that class.
     *
     * @param ptl the factor to truncate; not modified by this method
     * @param alpha mass threshold; compared in log space
     * @return a new factor over the same variables holding only the retained entries
     */
    public static TableFactor retainMass(DiscreteFactor ptl, double alpha) {
        int numLocs = ptl.numLocations();
        int[] idxs = new int[numLocs];
        double[] vals = new double[numLocs];
        for (int i = 0; i < idxs.length; ++i) {
            idxs[i] = ptl.indexAtLocation(i);
            vals[i] = ptl.logValue(i);
        }
        RankedFeatureVector rfv = new RankedFeatureVector(new Alphabet(), idxs, vals);

        TIntArrayList idxList = new TIntArrayList();
        TDoubleArrayList valList = new TDoubleArrayList();
        double mass = Double.NEGATIVE_INFINITY;
        double logAlpha = Math.log(alpha);
        for (int rank = 0; rank < rfv.numLocations(); ++rank) {
            int idx = rfv.getIndexAtRank(rank);
            double val = rfv.value(idx);
            mass = Maths.sumLogProb(mass, val);
            idxList.add(idx);
            valList.add(val);
            // The entry that pushes the mass past the threshold is itself retained.
            if (mass > logAlpha) {
                break;
            }
        }

        int[] szs = Factors.computeSizes(ptl);
        SparseMatrixn m = new SparseMatrixn(szs, idxList.toNativeArray(), valList.toNativeArray());
        TableFactor result = new TableFactor(Factors.computeVars(ptl));
        result.setValues(m);
        return result;
    }

    /**
     * Returns the number of outcomes of each variable of {@code result}, in
     * the order given by {@code result.getVariable(i)}.
     *
     * @param result the factor to inspect
     * @return a new array with one outcome count per variable
     */
    public static int[] computeSizes(Factor result) {
        int nv = result.varSet().size();
        int[] szs = new int[nv];
        for (int i = 0; i < nv; ++i) {
            szs[i] = result.getVariable(i).getNumOutcomes();
        }
        return szs;
    }

    /**
     * Returns the variables of {@code result}, in the order given by
     * {@code result.getVariable(i)}.
     *
     * @param result the factor to inspect
     * @return a new array holding the factor's variables
     */
    public static Variable[] computeVars(Factor result) {
        int nv = result.varSet().size();
        Variable[] vars = new Variable[nv];
        for (int i = 0; i < nv; ++i) {
            vars[i] = result.getVariable(i);
        }
        return vars;
    }

    /**
     * Computes the sum over all assignments of
     * {@code value * (logValue - logMarginal1 - logMarginal2)} for a factor
     * over exactly two variables, where the marginals are those of the first
     * and second variable of the factor's variable set.
     *
     * <p>Entries whose value is exactly zero contribute nothing; without that
     * guard {@code 0 * -infinity} would turn the whole result into NaN.
     *
     * @param factor a factor over exactly two variables
     * @return the accumulated sum
     * @throws IllegalArgumentException if the factor is not over exactly two variables
     */
    public static double mutualInformation(Factor factor) {
        VarSet vs = factor.varSet();
        if (vs.size() != 2) {
            throw new IllegalArgumentException("Factor must have size 2");
        }
        Factor marg1 = factor.marginalize(vs.get(0));
        Factor marg2 = factor.marginalize(vs.get(1));
        double result = 0.0;
        AssignmentIterator it = factor.assignmentIterator();
        while (it.hasNext()) {
            Assignment assn = (Assignment)it.next();
            double p = factor.value(assn);
            if (p == 0.0) {
                // Zero-probability entry: treat 0 * log(0) as 0 rather than NaN.
                continue;
            }
            result += p * (factor.logValue(assn) - marg1.logValue(assn) - marg2.logValue(assn));
        }
        return result;
    }

    /**
     * Sums {@code val1 * log(val1 / val2)} over the locations of {@code f1},
     * where {@code val2} is the value of {@code f2} at the same index.
     * Locations where {@code val1} is not greater than {@code 1.0E-5} are skipped.
     *
     * @param f1 the factor whose locations are enumerated
     * @param f2 the factor looked up by index
     * @return the accumulated sum
     */
    public static double KL(AbstractTableFactor f1, AbstractTableFactor f2) {
        double result = 0.0;
        for (int loc = 0; loc < f1.numLocations(); ++loc) {
            double val1 = f1.valueAtLocation(loc);
            double val2 = f2.value(f1.indexAtLocation(loc));
            if (val1 > 1.0E-5) {
                result += val1 * Math.log(val1 / val2);
            }
        }
        return result;
    }

    /**
     * Mixes two table factors by delegating to
     * {@code AbstractTableFactor.hackyMixture}.
     *
     * @param f1 first factor
     * @param f2 second factor
     * @param alpha mixing weight, passed through unchanged
     * @return the mixture produced by {@code AbstractTableFactor.hackyMixture}
     */
    public static Factor mix(AbstractTableFactor f1, AbstractTableFactor f2, double alpha) {
        return AbstractTableFactor.hackyMixture(f1, f2, alpha);
    }

    /**
     * Returns the square root of the summed squared differences between the
     * values of {@code f1} at each of its locations and the values of
     * {@code f2} at the corresponding indices.
     *
     * @param f1 the factor whose locations are enumerated
     * @param f2 the factor looked up by index
     * @return the Euclidean distance over the locations of {@code f1}
     */
    public static double euclideanDistance(AbstractTableFactor f1, AbstractTableFactor f2) {
        double sumSq = 0.0;
        for (int loc = 0; loc < f1.numLocations(); ++loc) {
            double val1 = f1.valueAtLocation(loc);
            double val2 = f2.value(f1.indexAtLocation(loc));
            sumSq += (val1 - val2) * (val1 - val2);
        }
        return Math.sqrt(sumSq);
    }

    /**
     * Returns the summed absolute differences between the values of
     * {@code f1} at each of its locations and the values of {@code f2} at the
     * corresponding indices.
     *
     * @param f1 the factor whose locations are enumerated
     * @param f2 the factor looked up by index
     * @return the L1 distance over the locations of {@code f1}
     */
    public static double l1Distance(AbstractTableFactor f1, AbstractTableFactor f2) {
        double result = 0.0;
        for (int loc = 0; loc < f1.numLocations(); ++loc) {
            double val1 = f1.valueAtLocation(loc);
            double val2 = f2.value(f1.indexAtLocation(loc));
            result += Math.abs(val1 - val2);
        }
        return result;
    }

    /**
     * Wraps an {@link Inferencer} as a {@link Factor} whose values and
     * marginals are answered by {@code inf.lookupMarginal}.
     *
     * <p>The returned factor does not support {@code marginalizeOut} or
     * {@code varSet}; both throw {@link UnsupportedOperationException}.
     *
     * @param inf the inferencer to query
     * @return a factor view backed by {@code inf}
     */
    public static Factor asFactor(final Inferencer inf) {
        return new SkeletonFactor(){

            /** Looks up the marginal over the assignment's variables and evaluates it. */
            @Override
            public double value(Assignment assn) {
                Factor factor = inf.lookupMarginal(assn.varSet());
                return factor.value(assn);
            }

            /** Returns the inferencer's marginal over the given variables. */
            @Override
            public Factor marginalize(Variable[] vars) {
                return inf.lookupMarginal(new HashVarSet(vars));
            }

            /** Returns the inferencer's marginal over the given variables. */
            @Override
            public Factor marginalize(Collection vars) {
                return inf.lookupMarginal(new HashVarSet(vars));
            }

            /** Returns the inferencer's marginal over the single given variable. */
            @Override
            public Factor marginalize(Variable var) {
                return inf.lookupMarginal(new HashVarSet(new Variable[]{var}));
            }

            /** Not supported by this view. */
            @Override
            public Factor marginalizeOut(Variable var) {
                throw new UnsupportedOperationException();
            }

            /** Not supported by this view. */
            @Override
            public Factor marginalizeOut(VarSet varset) {
                throw new UnsupportedOperationException();
            }

            /** Not supported by this view. */
            @Override
            public VarSet varSet() {
                throw new UnsupportedOperationException();
            }
        };
    }

    /**
     * Returns the variables of {@code fg} for which
     * {@code isContinuous()} is false, in variable-set order.
     *
     * @param fg the factor to inspect
     * @return a new array of the non-continuous variables
     */
    public static Variable[] discreteVarsOf(Factor fg) {
        return Factors.selectVars(fg, false);
    }

    /**
     * Returns the variables of {@code fg} for which
     * {@code isContinuous()} is true, in variable-set order.
     *
     * @param fg the factor to inspect
     * @return a new array of the continuous variables
     */
    public static Variable[] continuousVarsOf(Factor fg) {
        return Factors.selectVars(fg, true);
    }

    /** Collects the variables of {@code fg} whose {@code isContinuous()} equals {@code continuous}. */
    private static Variable[] selectVars(Factor fg, boolean continuous) {
        ArrayList<Variable> vars = new ArrayList<Variable>();
        VarSet vs = fg.varSet();
        for (int vi = 0; vi < vs.size(); ++vi) {
            Variable var = vs.get(vi);
            if (var.isContinuous() == continuous) {
                vars.add(var);
            }
        }
        return vars.toArray(new Variable[vars.size()]);
    }

    /**
     * Returns {@code E[XY] - E[X] * E[Y]} for a factor over exactly two
     * variables, using the integer returned by {@code assn.get(variable)} as
     * the numeric value of each variable.
     *
     * <p>Despite the method name, the result is not divided by any standard
     * deviation, so it is a covariance rather than a normalized correlation.
     *
     * @param factor a factor over exactly two variables
     * @return the covariance-style statistic described above
     * @throws IllegalArgumentException if the factor is not over exactly two variables
     */
    public static double corr(Factor factor) {
        if (factor.varSet().size() != 2) {
            throw new IllegalArgumentException("corr() only works on Factors of size 2, tried " + factor);
        }
        Variable v0 = factor.varSet().get(0);
        Variable v1 = factor.varSet().get(1);
        double eXY = 0.0;
        AssignmentIterator it = factor.assignmentIterator();
        while (it.hasNext()) {
            Assignment assn = (Assignment)it.next();
            int val0 = assn.get(v0);
            int val1 = assn.get(v1);
            eXY += factor.value(assn) * (double)val0 * (double)val1;
        }
        double eX = Factors.mean(factor.marginalize(v0));
        double eY = Factors.mean(factor.marginalize(v1));
        return eXY - eX * eY;
    }

    /**
     * Returns the value-weighted sum of {@code assn.get(variable)} over all
     * assignments of a single-variable factor.
     *
     * @throws IllegalArgumentException if the factor is not over exactly one variable
     */
    private static double mean(Factor factor) {
        if (factor.varSet().size() != 1) {
            throw new IllegalArgumentException("mean() only works on Factors of size 1, tried " + factor);
        }
        Variable v0 = factor.varSet().get(0);
        double mean = 0.0;
        AssignmentIterator it = factor.assignmentIterator();
        while (it.hasNext()) {
            Assignment assn = (Assignment)it.next();
            int val0 = assn.get(v0);
            mean += factor.value(assn) * (double)val0;
        }
        return mean;
    }

    /**
     * Multiplies together all factors in {@code factors}.
     *
     * <p>The result starts as a duplicate of the first factor in iteration
     * order and is then multiplied, via {@code multiplyBy}, by each of the
     * remaining factors exactly once. None of the input factors is modified
     * by this method itself.
     *
     * @param factors a non-empty collection whose elements are all {@link Factor}s
     * @return a new factor holding the product
     * @throws java.util.NoSuchElementException if {@code factors} is empty
     * @throws ClassCastException if an element is not a {@link Factor}
     */
    public static Factor multiplyAll(Collection factors) {
        Factor result = null;
        // The collection is raw, so elements arrive as Object and need a cast.
        for (Object element : factors) {
            Factor phi = (Factor)element;
            if (result == null) {
                // Seed with a copy of the first factor; it must not be multiplied in again.
                result = phi.duplicate();
            } else {
                result.multiplyBy(phi);
            }
        }
        if (result == null) {
            throw new java.util.NoSuchElementException("multiplyAll() requires at least one factor");
        }
        return result;
    }

    /**
     * Returns the largest absolute entry-wise difference between the log
     * value matrices of the two factors.
     *
     * @param f1 first factor
     * @param f2 second factor
     * @return the maximum difference, or {@code Double.POSITIVE_INFINITY} if the matrices differ in size
     */
    public static double distLinf(AbstractTableFactor f1, AbstractTableFactor f2) {
        return Factors.matrixDistLinf(f1.getLogValueMatrix(), f2.getLogValueMatrix());
    }

    /**
     * Returns the largest absolute entry-wise difference between the value
     * matrices of the two factors.
     *
     * @param f1 first factor
     * @param f2 second factor
     * @return the maximum difference, or {@code Double.POSITIVE_INFINITY} if the matrices differ in size
     */
    public static double distValueLinf(AbstractTableFactor f1, AbstractTableFactor f2) {
        return Factors.matrixDistLinf(f1.getValueMatrix(), f2.getValueMatrix());
    }

    /**
     * Returns the maximum absolute difference between corresponding single
     * values of two matrices, or positive infinity when their single sizes differ.
     */
    private static double matrixDistLinf(Matrix m1, Matrix m2) {
        int nl1 = m1.singleSize();
        int nl2 = m2.singleSize();
        if (nl1 != nl2) {
            return Double.POSITIVE_INFINITY;
        }
        double max = 0.0;
        for (int l = 0; l < nl1; ++l) {
            double val1 = m1.singleValue(l);
            double val2 = m2.singleValue(l);
            double diff = val1 > val2 ? val1 - val2 : val2 - val1;
            // Deliberately a plain comparison rather than Math.max: a NaN
            // difference (e.g. -inf minus -inf) is ignored instead of propagated.
            if (diff > max) {
                max = diff;
            }
        }
        return max;
    }

    /**
     * Returns the spread between the largest and smallest absolute entry-wise
     * difference of the two factors' log value matrices.
     *
     * <p>NaN differences are ignored. If there is no comparable entry at all
     * (empty matrices, or every difference is NaN) the result is {@code 0.0}.
     *
     * @param f1 first factor
     * @param f2 second factor
     * @return {@code maxDiff - minDiff}, or {@code Double.POSITIVE_INFINITY} if the matrices differ in size
     */
    public static double logErrorRange(AbstractTableFactor f1, AbstractTableFactor f2) {
        Matrix m1 = f1.getLogValueMatrix();
        Matrix m2 = f2.getLogValueMatrix();
        int nl1 = m1.singleSize();
        int nl2 = m2.singleSize();
        if (nl1 != nl2) {
            return Double.POSITIVE_INFINITY;
        }
        double errorMin = Double.MAX_VALUE;
        double errorMax = 0.0;
        boolean sawDiff = false;
        for (int l = 0; l < nl1; ++l) {
            double val1 = m1.singleValue(l);
            double val2 = m2.singleValue(l);
            double diff = val1 > val2 ? val1 - val2 : val2 - val1;
            if (Double.isNaN(diff)) {
                // NaN never won either comparison below, so it has always been skipped.
                continue;
            }
            sawDiff = true;
            if (diff > errorMax) {
                errorMax = diff;
            }
            if (diff < errorMin) {
                errorMin = diff;
            }
        }
        if (!sawDiff) {
            // Avoid returning 0 - Double.MAX_VALUE when nothing was compared.
            return 0.0;
        }
        return errorMax - errorMin;
    }
}

