/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.logging.Logger;

public class HierarchicalPAM {
    /** Logger shared by all instances; used for per-iteration timing, top-word and likelihood output. */
    protected static Logger logger = MalletLogger.getLogger(HierarchicalPAM.class.getName());

    // ---- Command-line options read by main(); the help strings describe each option. ----
    static CommandOption.String inputFile = new CommandOption.String(HierarchicalPAM.class, "input", "FILENAME", true, null, "The filename from which to read the list of training instances.  Use - for stdin.  The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
    static CommandOption.String stateFile = new CommandOption.String(HierarchicalPAM.class, "output-state", "FILENAME", true, null, "The filename in which to write the Gibbs sampling state after at the end of the iterations.  By default this is null, indicating that no file will be written.", null);
    static CommandOption.Double superTopicBalanceOption = new CommandOption.Double(HierarchicalPAM.class, "super-topic-balance", "DECIMAL", true, 1.0, "Weight (in \"words\") of the shared distribution over super-topics, relative to the document-specific distribution", null);
    static CommandOption.Double subTopicBalanceOption = new CommandOption.Double(HierarchicalPAM.class, "sub-topic-balance", "DECIMAL", true, 1.0, "Weight (in \"words\") of the shared distribution over sub-topics for each super-topic, relative to the document-specific distribution", null);
    static CommandOption.Integer numSuperTopicsOption = new CommandOption.Integer(HierarchicalPAM.class, "num-super-topics", "INTEGER", true, 10, "The number of super-topics", null);
    static CommandOption.Integer numSubTopicsOption = new CommandOption.Integer(HierarchicalPAM.class, "num-sub-topics", "INTEGER", true, 20, "The number of sub-topics", null);

    // ---- Hierarchy levels. The values match the level codes drawn per token during
    // initialization in estimate(): 0 = root, 1 = super-topic, anything else = sub-topic. ----
    public static final int NUM_LEVELS = 3;
    public static final int ROOT_TOPIC = 0;
    public static final int SUPER_TOPIC = 1;
    public static final int SUB_TOPIC = 2;

    /** Number of super-topics; the value {@code numSuperTopics} itself is used as the index of the root. */
    int numSuperTopics;
    /** Number of sub-topics; the value {@code numSubTopics} itself is used as the "no sub-topic" index. */
    int numSubTopics;

    /** Weight of the shared super-topic prior relative to the per-document counts. */
    double superTopicBalance = 1.0;
    /** Additive smoothing applied to super-topic document frequencies when caching the prior. */
    double superTopicSmoothing = 1.0;
    /** Weight of the shared sub-topic prior relative to the per-document counts. */
    double subTopicBalance = 1.0;
    /** Additive smoothing applied to sub-topic document frequencies when caching the prior. */
    double subTopicSmoothing = 1.0;

    /** Smoothing added to every type/topic count (set to 0.01 in the constructor). */
    double beta;
    /** {@code beta * numTypes}; computed in estimate(). */
    double betaSum;

    /** Training documents, as passed to estimate(). */
    InstanceList instances;
    /** Size of the data alphabet of {@link #instances}. */
    int numTypes;
    /** Total number of tokens over all documents, accumulated in estimate(). */
    int numTokens;

    /** {@code superTopics[doc][position]}: super-topic of each token; {@code numSuperTopics} means root. */
    int[][] superTopics;
    /** {@code subTopics[doc][position]}: sub-topic of each token; {@code numSubTopics} means none. */
    int[][] subTopics;

    // Per-document scratch state, rebuilt at the start of sampleTopicsForOneDoc().
    /** Tokens of the current document per (super-topic, sub-topic) pair. */
    int[][] superSubCounts;
    /** Tokens of the current document per super-topic (last slot = root). */
    int[] superCounts;
    /** Per-super-topic sampling factor for the current document. */
    double[] superWeights;
    /** Buffer of length {@code numSubTopics} allocated in estimate(). */
    double[] subWeights;
    /** Unnormalized sampling weight of each (super-topic, sub-topic) pair for the current token. */
    double[][] superSubWeights;
    /** Running sum of {@link #superSubWeights} up to and including each super-topic. */
    double[] cumulativeSuperWeights;

    // Corpus-level document frequencies that define the shared priors.
    /** Number of documents with at least one token in each super-topic (last slot = root). */
    int[] superTopicDocumentFrequencies;
    /** Number of documents with at least one token in each (super-topic, sub-topic) pair. */
    int[][] superSubTopicDocumentFrequencies;
    /** Sum of {@link #superTopicDocumentFrequencies}. */
    int sumDocumentFrequencies;
    /** Per super-topic, the sum of its row of {@link #superSubTopicDocumentFrequencies}. */
    int[] sumSuperTopicDocumentFrequencies;
    /** Cached smoothed, normalized super-topic prior; see cacheSuperTopicPrior(). */
    double[] superTopicPriorWeights;
    /** Cached smoothed, normalized sub-topic prior per super-topic; see cacheSuperSubTopicPrior(). */
    double[][] superSubTopicPriorWeights;

    /**
     * {@code typeTopicCounts[type][topic]}: tokens of a word type emitted by a topic. Topic index
     * layout: 0 = root, 1..numSuperTopics = super-topics, then the sub-topics.
     */
    int[][] typeTopicCounts;
    /** Tokens emitted by each topic, using the same index layout as {@link #typeTopicCounts}. */
    int[] tokensPerTopic;
    /** Tokens assigned to each super-topic path, regardless of emitting level (last slot = root). */
    int[] tokensPerSuperTopic;
    /** Tokens assigned to each (super-topic, sub-topic) path. */
    int[][] tokensPerSuperSubTopic;

    /** JVM runtime handle captured in the constructor. */
    Runtime runtime;
    /** Number formatter for log output; limited to 5 fraction digits in the constructor. */
    NumberFormat formatter = NumberFormat.getInstance();

    public HierarchicalPAM(int superTopics, int subTopics, double superTopicBalance, double subTopicBalance) {
        this.formatter.setMaximumFractionDigits(5);
        this.superTopicBalance = superTopicBalance;
        this.subTopicBalance = subTopicBalance;
        this.numSuperTopics = superTopics;
        this.numSubTopics = subTopics;
        this.superTopicDocumentFrequencies = new int[this.numSuperTopics + 1];
        this.superSubTopicDocumentFrequencies = new int[this.numSuperTopics + 1][this.numSubTopics + 1];
        this.sumSuperTopicDocumentFrequencies = new int[this.numSuperTopics];
        this.beta = 0.01;
        this.runtime = Runtime.getRuntime();
    }

    /**
     * Randomly initializes the topic assignment of every token and then runs Gibbs sampling.
     *
     * <p>Each token is assigned a random super-topic, sub-topic and emitting level
     * (root, super-topic or sub-topic). The shared priors are cached from the resulting
     * document frequencies before sampling starts.
     *
     * <p>Exactly {@code numIterations} sweeps over the corpus are performed (numbered
     * 1..numIterations), so the last sweep can coincide with the show-topics and
     * log-likelihood intervals. All corpus-level accumulators are reset on entry, so the
     * method may be called more than once on the same object.
     *
     * @param documents training instances; each instance's data must be a {@code FeatureSequence}
     * @param testing unused
     * @param numIterations number of Gibbs sweeps to run
     * @param showTopicsInterval log the top words every this many sweeps; 0 disables
     * @param outputModelInterval unused
     * @param optimizeInterval unused
     * @param outputModelFilename unused
     * @param r random number source
     */
    public void estimate(InstanceList documents, InstanceList testing, int numIterations, int showTopicsInterval, int outputModelInterval, int optimizeInterval, String outputModelFilename, Randoms r) {
        this.instances = documents;
        this.numTypes = this.instances.getDataAlphabet().size();
        int numDocs = this.instances.size();
        int numTopics = 1 + this.numSuperTopics + this.numSubTopics;

        this.superTopics = new int[numDocs][];
        this.subTopics = new int[numDocs][];
        this.superSubCounts = new int[this.numSuperTopics + 1][this.numSubTopics + 1];
        this.superCounts = new int[this.numSuperTopics + 1];
        this.superWeights = new double[this.numSuperTopics + 1];
        this.subWeights = new double[this.numSubTopics];
        this.superSubWeights = new double[this.numSuperTopics + 1][this.numSubTopics + 1];
        this.cumulativeSuperWeights = new double[this.numSuperTopics];
        this.typeTopicCounts = new int[this.numTypes][numTopics];
        this.tokensPerTopic = new int[numTopics];
        this.tokensPerSuperTopic = new int[this.numSuperTopics + 1];
        this.tokensPerSuperSubTopic = new int[this.numSuperTopics + 1][this.numSubTopics + 1];
        this.betaSum = this.beta * (double) this.numTypes;

        // These accumulators outlive a single call (the frequency tables are allocated in the
        // constructor), so clear them to keep a repeated estimate() from double counting.
        this.numTokens = 0;
        this.sumDocumentFrequencies = 0;
        Arrays.fill(this.superTopicDocumentFrequencies, 0);
        Arrays.fill(this.sumSuperTopicDocumentFrequencies, 0);
        for (int[] row : this.superSubTopicDocumentFrequencies) {
            Arrays.fill(row, 0);
        }

        for (int doc = 0; doc < numDocs; doc++) {
            // Per-document token counts, used to detect the first occurrence of a
            // super-topic or (super, sub) pair in this document.
            int[] localTokensPerSuperTopic = new int[this.numSuperTopics + 1];
            int[][] localTokensPerSuperSubTopic = new int[this.numSuperTopics + 1][this.numSubTopics + 1];

            FeatureSequence fs = (FeatureSequence) ((Instance) this.instances.get(doc)).getData();
            int seqLen = fs.getLength();
            this.numTokens += seqLen;
            this.superTopics[doc] = new int[seqLen];
            this.subTopics[doc] = new int[seqLen];

            for (int position = 0; position < seqLen; position++) {
                // Always draw all three values, in this order, so the random stream is
                // consumed identically whatever level is chosen.
                int superTopic = r.nextInt(this.numSuperTopics);
                int subTopic = r.nextInt(this.numSubTopics);
                int level = r.nextInt(NUM_LEVELS);

                // Map the level to the sentinel indices and to the emitting topic's index.
                int emittingTopic;
                if (level == ROOT_TOPIC) {
                    superTopic = this.numSuperTopics;
                    subTopic = this.numSubTopics;
                    emittingTopic = 0;
                } else if (level == SUPER_TOPIC) {
                    subTopic = this.numSubTopics;
                    emittingTopic = 1 + superTopic;
                } else {
                    emittingTopic = 1 + this.numSuperTopics + subTopic;
                }

                this.superTopics[doc][position] = superTopic;
                this.subTopics[doc][position] = subTopic;

                this.typeTopicCounts[fs.getIndexAtPosition(position)][emittingTopic]++;
                this.tokensPerTopic[emittingTopic]++;
                this.tokensPerSuperTopic[superTopic]++;
                this.tokensPerSuperSubTopic[superTopic][subTopic]++;

                if (localTokensPerSuperTopic[superTopic] == 0) {
                    this.superTopicDocumentFrequencies[superTopic]++;
                    this.sumDocumentFrequencies++;
                }
                localTokensPerSuperTopic[superTopic]++;

                // The root has no sub-topic distribution, so it has no pair frequencies.
                if (level != ROOT_TOPIC) {
                    if (localTokensPerSuperSubTopic[superTopic][subTopic] == 0) {
                        this.superSubTopicDocumentFrequencies[superTopic][subTopic]++;
                        this.sumSuperTopicDocumentFrequencies[superTopic]++;
                    }
                    localTokensPerSuperSubTopic[superTopic][subTopic]++;
                }
            }
        }

        this.superTopicPriorWeights = new double[this.numSuperTopics + 1];
        this.superSubTopicPriorWeights = new double[this.numSuperTopics][this.numSubTopics + 1];
        this.cacheSuperTopicPrior();
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            this.cacheSuperSubTopicPrior(superTopic);
        }

        // Sweeps are numbered from 1 for the interval checks; "<=" runs all numIterations.
        for (int iteration = 1; iteration <= numIterations; iteration++) {
            long iterationStart = System.currentTimeMillis();
            for (int doc = 0; doc < this.superTopics.length; doc++) {
                this.sampleTopicsForOneDoc((FeatureSequence) ((Instance) this.instances.get(doc)).getData(), this.superTopics[doc], this.subTopics[doc], r);
            }
            if (showTopicsInterval != 0 && iteration % showTopicsInterval == 0) {
                logger.info(this.printTopWords(8, false));
            }
            logger.fine(String.valueOf(System.currentTimeMillis() - iterationStart) + " ");
            if (iteration % 10 == 0) {
                logger.info("<" + iteration + "> LL: " + this.formatter.format(this.modelLogLikelihood() / (double) this.numTokens));
            }
        }
    }

    /**
     * Recomputes the cached prior weight of every super-topic and of the root (last slot)
     * as the smoothed document frequency divided by the smoothed total.
     */
    private void cacheSuperTopicPrior() {
        int slots = this.numSuperTopics + 1;
        // The normalizer is the same for every slot, including the root.
        double normalizer = (double) this.sumDocumentFrequencies + (double) slots * this.superTopicSmoothing;
        for (int slot = 0; slot < slots; slot++) {
            double smoothedFrequency = (double) this.superTopicDocumentFrequencies[slot] + this.superTopicSmoothing;
            this.superTopicPriorWeights[slot] = smoothedFrequency / normalizer;
        }
    }

    /**
     * Recomputes the cached prior weights over sub-topics for one super-topic, including
     * the "no sub-topic" slot at index {@code numSubTopics}.
     *
     * @param superTopic index of the super-topic whose row is refreshed
     */
    private void cacheSuperSubTopicPrior(int superTopic) {
        int[] documentFrequencies = this.superSubTopicDocumentFrequencies[superTopic];
        double[] priorRow = this.superSubTopicPriorWeights[superTopic];
        int slots = this.numSubTopics + 1;
        // Shared by every entry of the row.
        double normalizer = (double) this.sumSuperTopicDocumentFrequencies[superTopic] + (double) slots * this.subTopicSmoothing;
        for (int slot = 0; slot < slots; slot++) {
            priorRow[slot] = ((double) documentFrequencies[slot] + this.subTopicSmoothing) / normalizer;
        }
    }

    /**
     * Resamples the (super-topic, sub-topic) assignment of every token in one document.
     *
     * <p>For each token the current assignment is removed from all counts, a new assignment
     * is drawn from the unnormalized weights of every (super-topic, sub-topic) pair plus the
     * root, and the counts are updated again. Whenever a count for this document crosses
     * between 0 and 1 the corresponding document frequency changes and the affected cached
     * prior is recomputed.
     *
     * <p>Index conventions: super-topic {@code numSuperTopics} is the root; sub-topic
     * {@code numSubTopics} means the token is emitted by the super-topic itself.
     *
     * @param oneDocTokens the document's token sequence
     * @param superTopics per-token super-topic assignments; updated in place
     * @param subTopics per-token sub-topic assignments; updated in place
     * @param r random number source
     */
    private void sampleTopicsForOneDoc(FeatureSequence oneDocTokens, int[] superTopics, int[] subTopics, Randoms r) {
        double[] wordWeights = new double[1 + this.numSuperTopics + this.numSubTopics];
        int docLen = oneDocTokens.getLength();

        // Rebuild the per-document counts. Every row is cleared, including the root row,
        // which is written below for root tokens and would otherwise grow without bound
        // across documents and sweeps.
        Arrays.fill(this.superCounts, 0);
        for (int[] row : this.superSubCounts) {
            Arrays.fill(row, 0);
        }
        for (int position = 0; position < docLen; position++) {
            this.superSubCounts[superTopics[position]][subTopics[position]]++;
            this.superCounts[superTopics[position]]++;
        }
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            this.refreshSuperWeight(superTopic);
            assert (this.superWeights[superTopic] != 0.0);
        }

        for (int position = 0; position < docLen; position++) {
            int type = oneDocTokens.getIndexAtPosition(position);
            int[] currentTypeTopicCounts = this.typeTopicCounts[type];
            int superTopic = superTopics[position];
            int subTopic = subTopics[position];

            // ---- Remove the token's current assignment from all counts. ----
            int oldTopic;
            if (superTopic == this.numSuperTopics) {
                oldTopic = 0;
            } else if (subTopic == this.numSubTopics) {
                oldTopic = 1 + superTopic;
            } else {
                oldTopic = 1 + this.numSuperTopics + subTopic;
            }
            currentTypeTopicCounts[oldTopic]--;
            this.tokensPerTopic[oldTopic]--;

            this.superCounts[superTopic]--;
            this.superSubCounts[superTopic][subTopic]--;
            if (this.superCounts[superTopic] == 0) {
                // The document no longer uses this super-topic at all.
                this.superTopicDocumentFrequencies[superTopic]--;
                this.sumDocumentFrequencies--;
                this.cacheSuperTopicPrior();
            }
            if (superTopic != this.numSuperTopics && this.superSubCounts[superTopic][subTopic] == 0) {
                // The document no longer uses this (super, sub) pair.
                this.superSubTopicDocumentFrequencies[superTopic][subTopic]--;
                this.sumSuperTopicDocumentFrequencies[superTopic]--;
                this.cacheSuperSubTopicPrior(superTopic);
            }
            this.tokensPerSuperTopic[superTopic]--;
            this.tokensPerSuperSubTopic[superTopic][subTopic]--;
            this.refreshSuperWeight(superTopic);

            // ---- Smoothed probability of this word type under every topic. ----
            for (int topic = 0; topic < wordWeights.length; topic++) {
                wordWeights[topic] = (this.beta + (double) currentTypeTopicCounts[topic]) / (this.betaSum + (double) this.tokensPerTopic[topic]);
                assert (wordWeights[topic] != 0.0);
            }

            // ---- Unnormalized weight of every (super, sub) pair, with running totals. ----
            Arrays.fill(this.cumulativeSuperWeights, 0.0);
            double cumulativeWeight = 0.0;
            for (int s = 0; s < this.numSuperTopics; s++) {
                double[] currentSuperSubWeights = this.superSubWeights[s];
                int[] currentSuperSubCounts = this.superSubCounts[s];
                double currentSuperWeight = this.superWeights[s];
                double[] priorCache = this.superSubTopicPriorWeights[s];
                for (int k = 0; k < this.numSubTopics; k++) {
                    currentSuperSubWeights[k] = currentSuperWeight * wordWeights[1 + this.numSuperTopics + k] * ((double) currentSuperSubCounts[k] + this.subTopicBalance * priorCache[k]);
                    cumulativeWeight += currentSuperSubWeights[k];
                }
                // Last slot: the word is emitted by the super-topic itself.
                currentSuperSubWeights[this.numSubTopics] = currentSuperWeight * wordWeights[1 + s] * ((double) currentSuperSubCounts[this.numSubTopics] + this.subTopicBalance * priorCache[this.numSubTopics]);
                cumulativeWeight += currentSuperSubWeights[this.numSubTopics];
                this.cumulativeSuperWeights[s] = cumulativeWeight;
                assert (this.cumulativeSuperWeights[s] != 0.0);
            }
            double rootWeight = wordWeights[0] * ((double) this.superCounts[this.numSuperTopics] + this.superTopicBalance * this.superTopicPriorWeights[this.numSuperTopics]);

            // ---- Draw the new assignment. ----
            double sample = r.nextUniform() * (cumulativeWeight + rootWeight);
            int newTopic;
            if (sample > cumulativeWeight) {
                superTopic = this.numSuperTopics;
                subTopic = this.numSubTopics;
                newTopic = 0;
            } else {
                // First super-topic whose running total reaches the sample.
                superTopic = 0;
                while (sample > this.cumulativeSuperWeights[superTopic]) {
                    superTopic++;
                }
                // Walk back down from that super-topic's running total, peeling off one
                // sub-topic weight at a time. The bound on subTopic guards against
                // round-off leaving the remainder above the sample after the final slot,
                // which would otherwise index past the end of the weight array.
                double[] chosenWeights = this.superSubWeights[superTopic];
                double remaining = this.cumulativeSuperWeights[superTopic] - chosenWeights[0];
                subTopic = 0;
                while (subTopic < this.numSubTopics && sample < remaining) {
                    subTopic++;
                    remaining -= chosenWeights[subTopic];
                }
                newTopic = subTopic == this.numSubTopics ? 1 + superTopic : 1 + this.numSuperTopics + subTopic;
            }
            currentTypeTopicCounts[newTopic]++;
            this.tokensPerTopic[newTopic]++;

            // ---- Record the new assignment in all counts. ----
            superTopics[position] = superTopic;
            subTopics[position] = subTopic;
            this.superSubCounts[superTopic][subTopic]++;
            this.superCounts[superTopic]++;
            if (this.superCounts[superTopic] == 1) {
                // First token of this super-topic in the document.
                this.superTopicDocumentFrequencies[superTopic]++;
                this.sumDocumentFrequencies++;
                this.cacheSuperTopicPrior();
            }
            if (superTopic != this.numSuperTopics && this.superSubCounts[superTopic][subTopic] == 1) {
                // First token of this (super, sub) pair in the document.
                this.superSubTopicDocumentFrequencies[superTopic][subTopic]++;
                this.sumSuperTopicDocumentFrequencies[superTopic]++;
                this.cacheSuperSubTopicPrior(superTopic);
            }
            this.tokensPerSuperTopic[superTopic]++;
            this.tokensPerSuperSubTopic[superTopic][subTopic]++;
            this.refreshSuperWeight(superTopic);
        }
    }

    /**
     * Recomputes {@code superWeights[superTopic]} from the current document's count and the
     * cached prior: (count + superTopicBalance * prior) / (count + subTopicBalance).
     */
    private void refreshSuperWeight(int superTopic) {
        double count = (double) this.superCounts[superTopic];
        this.superWeights[superTopic] = (count + this.superTopicBalance * this.superTopicPriorWeights[superTopic]) / (count + this.subTopicBalance);
    }

    /**
     * Builds a textual summary of the hierarchy: the root, then every super-topic followed
     * by its (at most 10) most used sub-topics, each with token counts, document
     * frequencies and the top words of the topic.
     *
     * @param numWords maximum number of words to list per topic; capped at the vocabulary
     *        size so a small alphabet no longer causes an out-of-bounds access
     * @param useNewLines unused
     * @return the formatted summary
     */
    public String printTopWords(int numWords, boolean useNewLines) {
        StringBuilder output = new StringBuilder();
        Alphabet alphabet = this.instances.getDataAlphabet();
        int wordsPerTopic = Math.min(numWords, this.numTypes);

        // Top words for every topic (root, super-topics, sub-topics).
        IDSorter[] sortedTypes = new IDSorter[this.numTypes];
        String[] topicTerms = new String[1 + this.numSuperTopics + this.numSubTopics];
        for (int topic = 0; topic < topicTerms.length; topic++) {
            for (int type = 0; type < this.numTypes; type++) {
                sortedTypes[type] = new IDSorter(type, (double) this.typeTopicCounts[type][topic] / (double) this.tokensPerTopic[topic]);
            }
            Arrays.sort(sortedTypes);
            StringBuilder terms = new StringBuilder();
            for (int i = 0; i < wordsPerTopic; i++) {
                terms.append(alphabet.lookupObject(sortedTypes[i].getID())).append(" ");
            }
            topicTerms[topic] = terms.toString();
        }

        output.append("Root: [")
                .append(this.tokensPerSuperTopic[this.numSuperTopics])
                .append("/")
                .append(this.superTopicDocumentFrequencies[this.numSuperTopics])
                .append("]")
                .append(topicTerms[0])
                .append("\n");

        int maxSubTopics = Math.min(10, this.numSubTopics);
        IDSorter[] sortedSubTopics = new IDSorter[this.numSubTopics];
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            // Rank this super-topic's sub-topics by the number of tokens on that path.
            for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                sortedSubTopics[subTopic] = new IDSorter(subTopic, this.tokensPerSuperSubTopic[superTopic][subTopic]);
            }
            Arrays.sort(sortedSubTopics);

            output.append("\nSuper-topic ")
                    .append(superTopic)
                    .append(" [")
                    .append(this.tokensPerSuperTopic[superTopic])
                    .append("/")
                    .append(this.superTopicDocumentFrequencies[superTopic])
                    .append(" ")
                    .append(this.tokensPerSuperSubTopic[superTopic][this.numSubTopics])
                    .append("/")
                    .append(this.superSubTopicDocumentFrequencies[superTopic][this.numSubTopics])
                    .append("]\t")
                    .append(topicTerms[1 + superTopic])
                    .append("\n");

            for (int i = 0; i < maxSubTopics; i++) {
                int subTopic = sortedSubTopics[i].getID();
                output.append(subTopic)
                        .append(":\t")
                        .append(this.tokensPerSuperSubTopic[superTopic][subTopic])
                        .append("/")
                        .append(this.formatter.format(this.superSubTopicDocumentFrequencies[superTopic][subTopic]))
                        .append("\t")
                        .append(topicTerms[1 + this.numSuperTopics + subTopic])
                        .append("\n");
            }
        }
        return output.toString();
    }

    /**
     * Writes the current sampling state to a file, one line per token.
     *
     * <p>The writer is closed even if writing fails part-way. Because {@link PrintWriter}
     * does not throw on write errors, its error flag is checked after writing and reported
     * as an {@link IOException}.
     *
     * @param f destination file; overwritten if it exists
     * @throws IOException if the file cannot be opened or a write error occurred
     */
    public void printState(File f) throws IOException {
        try (PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(f)))) {
            this.printState(out);
            // checkError() flushes and reports any error swallowed by the PrintWriter.
            if (out.checkError()) {
                throw new IOException("Error writing sampling state to " + f);
            }
        }
    }

    /**
     * Prints a header line followed by one space-separated line per token:
     * document index, position, type index, type, super-topic, sub-topic.
     *
     * <p>The writer is neither flushed nor closed here.
     *
     * @param out destination writer
     */
    public void printState(PrintWriter out) {
        Alphabet alphabet = this.instances.getDataAlphabet();
        out.println("#doc pos typeindex type super-topic sub-topic");

        for (int doc = 0; doc < this.superTopics.length; doc++) {
            // Rows of one document are buffered and written with a single print call.
            StringBuilder rows = new StringBuilder();
            FeatureSequence tokens = (FeatureSequence) ((Instance) this.instances.get(doc)).getData();
            int docLength = this.superTopics[doc].length;
            for (int position = 0; position < docLength; position++) {
                int type = tokens.getIndexAtPosition(position);
                rows.append(doc).append(' ')
                        .append(position).append(' ')
                        .append(type).append(' ')
                        .append(alphabet.lookupObject(type)).append(' ')
                        .append(this.superTopics[doc][position]).append(' ')
                        .append(this.subTopics[doc][position]).append("\n");
            }
            out.print(rows);
        }
    }

    /**
     * Computes the log likelihood of the current topic assignments and the words given
     * those assignments, using the cached shared priors.
     *
     * <p>The topic part sums, per document, Dirichlet-multinomial terms over the super-topics
     * (including the root) with concentration {@code superTopicBalance}, and, for each
     * super-topic used in the document, over its sub-topics with concentration
     * {@code subTopicBalance}. The word part sums a Dirichlet-multinomial term per topic over
     * the vocabulary with symmetric parameter {@code beta}.
     *
     * <p>Two normalizers differ from the earlier implementation: the cached log-gammas of
     * the super-topic prior now include {@code superTopicBalance} (matching the terms they
     * are subtracted from), and the word part uses {@code betaSum = beta * numTypes},
     * counted once per topic.
     *
     * @return the log likelihood (not divided by the number of tokens)
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0;

        // logGamma of every prior pseudo-count, cached because it is reused for each document.
        double[] superTopicLogGammas = new double[this.numSuperTopics + 1];
        double[][] superSubTopicLogGammas = new double[this.numSuperTopics][this.numSubTopics + 1];
        for (int superTopic = 0; superTopic <= this.numSuperTopics; superTopic++) {
            superTopicLogGammas[superTopic] = Dirichlet.logGamma(this.superTopicBalance * this.superTopicPriorWeights[superTopic]);
        }
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            for (int subTopic = 0; subTopic <= this.numSubTopics; subTopic++) {
                superSubTopicLogGammas[superTopic][subTopic] = Dirichlet.logGamma(this.subTopicBalance * this.superSubTopicPriorWeights[superTopic][subTopic]);
            }
        }

        // ---- Topic assignments, one document at a time. ----
        int[] superTopicCounts = new int[this.numSuperTopics + 1];
        int[][] superSubTopicCounts = new int[this.numSuperTopics][this.numSubTopics + 1];
        for (int doc = 0; doc < this.superTopics.length; doc++) {
            int[] docSuperTopics = this.superTopics[doc];
            int[] docSubTopics = this.subTopics[doc];

            for (int token = 0; token < docSuperTopics.length; token++) {
                int superTopic = docSuperTopics[token];
                superTopicCounts[superTopic]++;
                // Root tokens have no sub-topic entry.
                if (superTopic != this.numSuperTopics) {
                    superSubTopicCounts[superTopic][docSubTopics[token]]++;
                }
            }

            for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
                int superCount = superTopicCounts[superTopic];
                if (superCount == 0) {
                    continue;
                }
                logLikelihood += Dirichlet.logGamma(this.superTopicBalance * this.superTopicPriorWeights[superTopic] + (double) superCount) - superTopicLogGammas[superTopic];

                // Zero counts contribute exactly nothing, so they are skipped.
                int[] subCounts = superSubTopicCounts[superTopic];
                for (int subTopic = 0; subTopic <= this.numSubTopics; subTopic++) {
                    if (subCounts[subTopic] > 0) {
                        logLikelihood += Dirichlet.logGamma(this.subTopicBalance * this.superSubTopicPriorWeights[superTopic][subTopic] + (double) subCounts[subTopic]) - superSubTopicLogGammas[superTopic][subTopic];
                    }
                }
                // Normalizer of the sub-topic distribution of this (document, super-topic).
                logLikelihood += Dirichlet.logGamma(this.subTopicBalance) - Dirichlet.logGamma(this.subTopicBalance + (double) superCount);
                Arrays.fill(subCounts, 0);
            }

            // Root term and the per-document normalizer of the super-topic distribution.
            logLikelihood += Dirichlet.logGamma(this.superTopicBalance * this.superTopicPriorWeights[this.numSuperTopics] + (double) superTopicCounts[this.numSuperTopics]) - superTopicLogGammas[this.numSuperTopics];
            logLikelihood -= Dirichlet.logGamma(this.superTopicBalance + (double) docSuperTopics.length);
            Arrays.fill(superTopicCounts, 0);
        }
        logLikelihood += (double) this.superTopics.length * Dirichlet.logGamma(this.superTopicBalance);

        // ---- Words given topics. ----
        int numTopics = 1 + this.numSuperTopics + this.numSubTopics;
        int nonZeroTypeTopics = 0;
        for (int type = 0; type < this.numTypes; type++) {
            int[] topicCounts = this.typeTopicCounts[type];
            for (int topic = 0; topic < numTopics; topic++) {
                if (topicCounts[topic] > 0) {
                    nonZeroTypeTopics++;
                    logLikelihood += Dirichlet.logGamma(this.beta + (double) topicCounts[topic]);
                }
            }
        }
        for (int topic = 0; topic < numTopics; topic++) {
            logLikelihood -= Dirichlet.logGamma(this.betaSum + (double) this.tokensPerTopic[topic]);
        }
        // One logGamma(betaSum) per topic, one logGamma(beta) per non-zero (type, topic) cell.
        logLikelihood += (double) numTopics * Dirichlet.logGamma(this.betaSum) - Dirichlet.logGamma(this.beta) * (double) nonZeroTypeTopics;
        return logLikelihood;
    }

    /**
     * Command-line entry point: loads the instance list named by {@code --input}, trains a
     * model with the configured topic counts and balances, and writes the sampling state
     * if {@code --output-state} was given.
     *
     * @param args command-line arguments, parsed by {@code CommandOption}
     * @throws IOException if the state file cannot be written
     */
    public static void main(String[] args) throws IOException {
        CommandOption.setSummary(HierarchicalPAM.class, "Train a three level hierarchy of topics");
        CommandOption.process(HierarchicalPAM.class, args);

        File trainingFile = new File(HierarchicalPAM.inputFile.value);
        InstanceList training = InstanceList.load(trainingFile);

        HierarchicalPAM model = new HierarchicalPAM(
                HierarchicalPAM.numSuperTopicsOption.value,
                HierarchicalPAM.numSubTopicsOption.value,
                HierarchicalPAM.superTopicBalanceOption.value,
                HierarchicalPAM.subTopicBalanceOption.value);

        // No testing set and no model output file are supplied.
        InstanceList noTestingSet = null;
        String noModelFile = null;
        model.estimate(training, noTestingSet, 1000, 100, 0, 250, noModelFile, new Randoms());

        if (!stateFile.wasInvoked()) {
            return;
        }
        model.printState(new File(HierarchicalPAM.stateFile.value));
    }
}

