/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference;

import cc.mallet.grmm.inference.AbstractInferencer;
import cc.mallet.grmm.inference.JunctionTree;
import cc.mallet.grmm.inference.JunctionTreePropagation;
import cc.mallet.grmm.types.AbstractTableFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.BitVarSet;
import cc.mallet.grmm.types.ConstantFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.Graphs;
import cc.mallet.types.Alphabet;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import java.util.logging.Level;
import java.util.logging.Logger;
import org._3pq.jgrapht.GraphHelper;
import org._3pq.jgrapht.UndirectedGraph;
import org._3pq.jgrapht.alg.ConnectivityInspector;
import org._3pq.jgrapht.graph.ListenableUndirectedGraph;
import org._3pq.jgrapht.graph.SimpleGraph;
import org._3pq.jgrapht.traverse.BreadthFirstIterator;

/**
 * Exact inference over a {@link FactorGraph} using a junction tree.
 *
 * <p>The model graph is triangulated by greedy variable elimination. At each step it
 * removes the vertex needing the fewest fill-in edges, breaking ties by the smallest
 * product of neighbor outcome counts. The resulting elimination cliques are joined by a
 * maximum spanning tree built with Kruskal's algorithm. That tree prefers large separators
 * and, among equal separators, clique pairs of smaller combined weight.
 *
 * <p>The tree structure is stored on the model through
 * {@link FactorGraph#setInferenceCache} and reused on later calls. Message passing itself
 * is delegated to a {@link JunctionTreePropagation}.
 */
public class JunctionTreeInferencer
extends AbstractInferencer {
    private static final Logger logger = MalletLogger.getLogger(JunctionTreeInferencer.class.getName());
    private static final long serialVersionUID = 1L;

    /**
     * Orders candidate clique pairs for Kruskal's algorithm.
     *
     * <p>Pairs with a larger separator (clique intersection) come first. Among those, pairs
     * with a smaller combined clique weight come first. Pairs that tie on both compare as 0,
     * so this comparator must be used to sort a list (which keeps duplicates). It must not
     * back a sorted set, which would silently discard one of two distinct tied pairs.
     */
    private static final Comparator<BitVarSet[]> SEPSET_CHOOSER = new Comparator<BitVarSet[]>() {
        public int compare(BitVarSet[] pair1, BitVarSet[] pair2) {
            int bySize = -cmp(sepsetSize(pair1), sepsetSize(pair2));
            if (bySize != 0) {
                return bySize;
            }
            return cmp(sepsetCost(pair1), sepsetCost(pair2));
        }
    };

    /** Whether the last model passed to computeMarginals(FactorGraph) used log-space tables. */
    private boolean inLogSpace;
    /** Message-passing strategy (sum-product or max-product) run over the junction tree. */
    private final JunctionTreePropagation propagator;
    /** Junction tree built or reused by the most recent inference call. */
    protected transient JunctionTree jtCurrent;
    /** Elimination cliques found by the most recent triangulation. */
    private transient ArrayList<BitVarSet> cliques;
    /** Running sum of the propagator's message counts over all inference calls. */
    private transient int totalMessagesSent = 0;

    /** Creates a sum-product junction tree inferencer. */
    public JunctionTreeInferencer() {
        this(JunctionTreePropagation.createSumProductInferencer());
    }

    /**
     * Creates an inferencer that delegates message passing to the given propagator.
     *
     * @param propagator message-passing strategy to run over the constructed junction tree
     */
    public JunctionTreeInferencer(JunctionTreePropagation propagator) {
        this.propagator = propagator;
    }

    /** Creates a max-product junction tree inferencer. */
    public static JunctionTreeInferencer createForMaxProduct() {
        return new JunctionTreeInferencer(JunctionTreePropagation.createMaxProductInferencer());
    }

    /** Returns whether an edge joins {@code v1} and {@code v2} in {@code g}. */
    private boolean isAdjacent(UndirectedGraph g, Object v1, Object v2) {
        return g.getEdge(v1, v2) != null;
    }

    /**
     * Returns a private copy of {@code v}'s neighbor list in {@code g}.
     *
     * <p>Callers can keep using the copy while they add edges to {@code g}.
     */
    private List<Object> neighborsOf(UndirectedGraph g, Object v) {
        return new ArrayList<Object>(GraphHelper.neighborListOf(g, v));
    }

    /**
     * Counts the fill-in edges that eliminating {@code v} would add.
     *
     * <p>This is the number of unordered pairs of {@code v}'s neighbors that are not
     * already adjacent. Only the relative order of these counts matters to
     * {@link #pickVertexToRemove}.
     */
    private int newEdgesRequired(UndirectedGraph mdl, Variable v) {
        List<Object> neighbors = neighborsOf(mdl, v);
        int fill = 0;
        for (int i = 0; i < neighbors.size(); i++) {
            Object n1 = neighbors.get(i);
            for (int j = i + 1; j < neighbors.size(); j++) {
                if (!isAdjacent(mdl, n1, neighbors.get(j))) {
                    fill++;
                }
            }
        }
        return fill;
    }

    /**
     * Returns the product of the outcome counts of {@code v}'s neighbors.
     *
     * <p>The product saturates at {@link Long#MAX_VALUE} instead of overflowing, so that very
     * wide neighborhoods still rank as expensive.
     */
    private long weightRequired(UndirectedGraph mdl, Variable v) {
        long rating = 1L;
        for (Object neighbor : neighborsOf(mdl, v)) {
            int outcomes = ((Variable) neighbor).getNumOutcomes();
            if (outcomes > 0 && rating > Long.MAX_VALUE / outcomes) {
                return Long.MAX_VALUE;
            }
            rating *= outcomes;
        }
        return rating;
    }

    /**
     * Adds an edge between every pair of {@code v}'s neighbors that are not yet adjacent.
     *
     * <p>This makes {@code v}'s neighborhood a clique before {@code v} is eliminated. Any
     * failure from the graph is rethrown as a RuntimeException that names both endpoints.
     */
    private void connectNeighbors(UndirectedGraph mdl, Variable v) {
        // Adding edges among v's neighbors never changes v's own neighbor set,
        // so a single snapshot is valid for the whole loop.
        List<Object> neighbors = neighborsOf(mdl, v);
        for (int i = 0; i < neighbors.size(); i++) {
            Object n1 = neighbors.get(i);
            for (int j = i + 1; j < neighbors.size(); j++) {
                Object n2 = neighbors.get(j);
                if (isAdjacent(mdl, n1, n2)) {
                    continue;
                }
                try {
                    mdl.addEdge(n1, n2);
                } catch (Exception e) {
                    throw new RuntimeException("Could not add fill-in edge " + n1 + " -- " + n2, e);
                }
            }
        }
    }

    /** Returns whether some clique in {@code l} contains every variable of {@code c}. */
    private boolean findSuperClique(List<? extends VarSet> l, VarSet c) {
        for (VarSet candidate : l) {
            if (candidate.containsAll(c)) {
                return true;
            }
        }
        return false;
    }

    /** Three-way integer comparison that returns -1, 0, or 1. */
    private static int cmp(int i1, int i2) {
        if (i1 < i2) {
            return -1;
        }
        if (i1 > i2) {
            return 1;
        }
        return 0;
    }

    /**
     * Chooses the next vertex to eliminate during triangulation.
     *
     * <p>The chosen vertex needs the fewest fill-in edges. Ties go to the vertex with the
     * smaller product of neighbor outcome counts. On a full tie, the vertex listed first
     * wins.
     *
     * @param mdl current (partially eliminated) model graph
     * @param lst candidate vertices; must be non-empty
     * @return the vertex to eliminate next
     * @throws java.util.NoSuchElementException if {@code lst} is empty
     */
    public Variable pickVertexToRemove(UndirectedGraph mdl, ArrayList lst) {
        Iterator it = lst.iterator();
        Variable best = (Variable) it.next();
        int bestFill = newEdgesRequired(mdl, best);
        long bestWeight = weightRequired(mdl, best);
        while (it.hasNext()) {
            Variable v = (Variable) it.next();
            int fill = newEdgesRequired(mdl, v);
            if (fill > bestFill) {
                continue;
            }
            long weight = weightRequired(mdl, v);
            if (fill < bestFill || weight < bestWeight) {
                best = v;
                bestFill = fill;
                bestWeight = weight;
            }
        }
        return best;
    }

    /**
     * Triangulates a copy of {@code mdl} by greedy elimination and records the maximal
     * elimination cliques in {@link #cliques}. {@code mdl} itself is not modified.
     */
    private void triangulate(UndirectedGraph mdl) {
        UndirectedGraph mdl2 = dupGraph(mdl);
        ArrayList<Object> vars = new ArrayList<Object>(mdl.vertexSet());
        this.cliques = new ArrayList<BitVarSet>();

        if (logger.isLoggable(Level.FINER)) {
            logger.finer("Triangulating model: " + mdl);
            StringBuilder ret = new StringBuilder();
            for (Object var : vars) {
                ret.append(var.toString()).append("\n");
            }
            logger.finer(ret.toString());
        }

        while (!vars.isEmpty()) {
            Variable v = pickVertexToRemove(mdl2, vars);
            if (logger.isLoggable(Level.FINER)) {
                logger.finer("Triangulating vertex " + v);
            }
            BitVarSet varSet = new BitVarSet(v.getUniverse(), GraphHelper.neighborListOf(mdl2, v));
            varSet.add(v);
            // A later elimination clique cannot contain an earlier one, because the earlier
            // clique holds an already-removed vertex. So a one-directional check is enough.
            if (!findSuperClique(this.cliques, varSet)) {
                this.cliques.add(varSet);
                if (logger.isLoggable(Level.FINER)) {
                    logger.finer("  Elim clique " + varSet + " size " + varSet.size() + " weight " + varSet.weight());
                }
            }
            connectNeighbors(mdl2, v);
            vars.remove(v);
            mdl2.removeVertex(v);
        }

        if (logger.isLoggable(Level.FINE)) {
            logCliqueSummary();
        }
    }

    /** Logs each elimination clique and the average and maximum clique size and weight. */
    private void logCliqueSummary() {
        logger.fine("Triangulation done. Cliques are: ");
        long totSize = 0;
        long totWeight = 0;
        int maxSize = 0;
        int maxWeight = 0;
        for (BitVarSet c : this.cliques) {
            logger.finer(c.toString());
            totSize += c.size();
            maxSize = Math.max(c.size(), maxSize);
            totWeight += c.weight();
            maxWeight = Math.max(c.weight(), maxWeight);
        }
        double sz = this.cliques.size();
        logger.fine("Jt created " + sz + " cliques. Size: avg " + (double) totSize / sz + " max " + maxSize
                + " Weight: avg " + (double) totWeight / sz + " max " + maxWeight);
    }

    /** Returns the number of variables shared by the two cliques of {@code pair}. */
    private static int sepsetSize(BitVarSet[] pair) {
        assert pair.length == 2;
        return pair[0].intersectionSize(pair[1]);
    }

    /** Returns the combined weight of the two cliques of {@code pair}. */
    private static int sepsetCost(VarSet[] pair) {
        assert pair.length == 2;
        return pair[0].weight() + pair[1].weight();
    }

    /**
     * Converts the undirected clique tree {@code g} into a rooted {@link JunctionTree}.
     *
     * <p>The root is the first vertex of {@code g}'s vertex set. Nodes are attached in
     * breadth-first order. Assumes {@code g} is a non-empty tree; an empty graph makes
     * {@code iterator().next()} throw.
     */
    private JunctionTree graphToJt(UndirectedGraph g) {
        JunctionTree jt = new JunctionTree(g.vertexSet().size());
        Object root = g.vertexSet().iterator().next();
        jt.add(root);
        for (BreadthFirstIterator bfs = new BreadthFirstIterator(g, root); bfs.hasNext();) {
            Object node = bfs.next();
            for (Object neighbor : GraphHelper.neighborListOf(g, node)) {
                // In a tree, every neighbor except the parent is a child.
                if (jt.getParent(node) != neighbor) {
                    jt.addNode(node, neighbor);
                }
            }
        }
        return jt;
    }

    /**
     * Joins the current {@link #cliques} into a maximum-sepset spanning tree (Kruskal).
     *
     * <p>Every unordered clique pair is a candidate edge. Candidates are sorted with
     * {@link #SEPSET_CHOOSER}; the sort is stable, so ties follow insertion order and the
     * result is deterministic. Candidates are then added greedily as long as they do not
     * close a cycle.
     *
     * @throws IllegalStateException if the cliques could not be connected into a tree
     */
    private JunctionTree buildJtStructure() {
        int numCliques = this.cliques.size();
        List<BitVarSet[]> candidates = new ArrayList<BitVarSet[]>();
        for (int i = 0; i < numCliques; i++) {
            BitVarSet c1 = this.cliques.get(i);
            for (int j = 0; j < i; j++) {
                candidates.add(new BitVarSet[]{c1, this.cliques.get(j)});
            }
        }
        Collections.sort(candidates, SEPSET_CHOOSER);

        ListenableUndirectedGraph g = new ListenableUndirectedGraph(new SimpleGraph());
        for (BitVarSet c : this.cliques) {
            g.addVertex(c);
        }
        // The inspector listens to g so its connectivity answers track added edges.
        ConnectivityInspector inspector = new ConnectivityInspector(g);
        g.addGraphListener(inspector);

        int edgesNeeded = numCliques - 1;
        int edgesAdded = 0;
        for (BitVarSet[] pair : candidates) {
            if (edgesAdded >= edgesNeeded) {
                break;
            }
            if (inspector.pathExists(pair[0], pair[1])) {
                continue;
            }
            g.addEdge(pair[0], pair[1]);
            edgesAdded++;
        }
        if (edgesAdded < edgesNeeded) {
            throw new IllegalStateException("Could only add " + edgesAdded + " of " + edgesNeeded
                    + " junction tree edges for " + numCliques + " cliques");
        }

        JunctionTree jt = graphToJt(g);
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("  jt structure was " + jt);
        }
        return jt;
    }

    /**
     * Initializes every cluster potential of {@code jt} to the constant 1.
     *
     * <p>Each factor of {@code mdl} is then multiplied into the potential of the cluster
     * returned by {@link JunctionTree#findParentCluster}.
     */
    private void initJtCpts(FactorGraph mdl, JunctionTree jt) {
        for (Iterator it = jt.getVerticesIterator(); it.hasNext();) {
            VarSet c = (VarSet) it.next();
            jt.setCPF(c, new ConstantFactor(1.0));
        }
        for (Object o : mdl.factors()) {
            Factor ptl = (Factor) o;
            VarSet parent = jt.findParentCluster(ptl.varSet());
            assert parent != null : "Unable to find parent cluster for ptl " + ptl + "in jt " + jt;
            Factor cpf = jt.getCPF(parent);
            jt.setCPF(parent, cpf.multiply(ptl));
        }
    }

    /** Returns an empty table factor over {@code c}, in log space if {@link #inLogSpace} is set. */
    private AbstractTableFactor createBlankFactor(VarSet c) {
        if (this.inLogSpace) {
            return new LogTableFactor(c);
        }
        return new TableFactor(c);
    }

    /**
     * Builds (or reuses) the junction tree for {@code mdl} and runs propagation on it.
     *
     * <p>Log-space mode is decided from the model's first factor. NOTE(review):
     * {@code getFactor(0)} presumably fails on a model with no factors.
     */
    @Override
    public void computeMarginals(FactorGraph mdl) {
        this.inLogSpace = mdl.getFactor(0) instanceof LogTableFactor;
        buildJunctionTree(mdl);
        this.propagator.computeMarginals(this.jtCurrent);
        this.totalMessagesSent += this.propagator.getTotalMessagesSent();
    }

    /** Runs propagation directly on an already-built junction tree {@code jt}. */
    public void computeMarginals(JunctionTree jt) {
        this.inLogSpace = false;
        this.jtCurrent = jt;
        this.propagator.computeMarginals(this.jtCurrent);
        this.totalMessagesSent += this.propagator.getTotalMessagesSent();
    }

    /**
     * Returns a junction tree for {@code mdl} with its potentials initialized from the
     * model's factors.
     *
     * <p>A structure cached on {@code mdl} is reused after clearing its potentials.
     * Otherwise the model graph is triangulated, a new tree is built, and the tree is
     * cached on {@code mdl}.
     */
    public JunctionTree buildJunctionTree(FactorGraph mdl) {
        this.jtCurrent = (JunctionTree) mdl.getInferenceCache(JunctionTreeInferencer.class);
        if (this.jtCurrent != null) {
            this.jtCurrent.clearCPFs();
        } else {
            UndirectedGraph g = Graphs.mdlToGraph(mdl);
            triangulate(g);
            this.jtCurrent = buildJtStructure();
            mdl.setInferenceCache(JunctionTreeInferencer.class, this.jtCurrent);
        }
        initJtCpts(mdl, this.jtCurrent);
        return this.jtCurrent;
    }

    /** Returns a new {@link SimpleGraph} containing the vertices and edges of {@code original}. */
    private UndirectedGraph dupGraph(UndirectedGraph original) {
        SimpleGraph copy = new SimpleGraph();
        GraphHelper.addGraph(copy, original);
        return copy;
    }

    /** Looks up the marginal of {@code var} in the current junction tree. */
    @Override
    public Factor lookupMarginal(Variable var) {
        return this.propagator.lookupMarginal(this.jtCurrent, var);
    }

    /** Looks up the marginal of {@code varSet} in the current junction tree. */
    @Override
    public Factor lookupMarginal(VarSet varSet) {
        return this.propagator.lookupMarginal(this.jtCurrent, varSet);
    }

    /** Delegates to the current junction tree's {@code lookupLogJoint}. */
    @Override
    public double lookupLogJoint(Assignment assn) {
        return this.jtCurrent.lookupLogJoint(assn);
    }

    /** Delegates to the current junction tree's {@code dumpLogJoint}. */
    public double dumpLogJoint(Assignment assn) {
        return this.jtCurrent.dumpLogJoint(assn);
    }

    /** Returns the junction tree from the most recent inference call, or {@code null}. */
    public JunctionTree lookupJunctionTree() {
        return this.jtCurrent;
    }

    /** Prints the current junction tree to standard output, if there is one. */
    @Override
    public void dump() {
        if (this.jtCurrent != null) {
            System.out.println("Current junction tree");
            this.jtCurrent.dump();
        } else {
            System.out.println("NO current junction tree");
        }
    }

    /** Returns the accumulated message count over all inference calls on this instance. */
    public int getTotalMessagesSent() {
        return this.totalMessagesSent;
    }

    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
    }
}

