/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.examples;

import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.CharSequenceLowercase;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.TokenSequenceRemoveStopwords;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelSequence;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.Iterator;
import java.util.Locale;
import java.util.TreeSet;
import java.util.regex.Pattern;

/**
 * Example driver that trains a MALLET {@link ParallelTopicModel} on a text file and prints:
 * <ol>
 *   <li>the per-token topic assignments of the first document ({@code word-topic} pairs),</li>
 *   <li>one line per topic with the first document's proportion of that topic and the topic's
 *       top words with their weights,</li>
 *   <li>the inferred probability of topic 0 for a synthetic document made of topic 0's top words.</li>
 * </ol>
 *
 * <p>Usage: {@code TopicModel <input-file>}. Each input line is split by {@link #LINE_PATTERN}
 * into three whitespace/comma separated parts (name, label, text); the text is the document.
 */
public class TopicModel {
    /**
     * Token regex: a letter, then one or more letters/punctuation, then a letter, so every token
     * is at least three characters long and starts and ends with a letter.
     */
    private static final Pattern TOKEN_PATTERN = Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}");

    /**
     * Splits an input line into groups (1) name, (2) label, (3) rest-of-line text. Group numbers
     * are passed to {@link CsvIterator} as data=3, target=2, uri=1.
     */
    private static final Pattern LINE_PATTERN = Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$");

    /** Stop-word list, resolved relative to the working directory. */
    private static final String STOPLIST_PATH = "stoplists/en.txt";

    private static final int NUM_TOPICS = 100;
    /** Passed as the model's alpha; presumably the sum over topics (see MALLET docs). */
    private static final double ALPHA_SUM = 1.0;
    private static final double BETA = 0.01;
    private static final int NUM_THREADS = 2;
    private static final int NUM_ITERATIONS = 50;

    /** How many top-weighted words are shown per topic (and used for the test document). */
    private static final int TOP_WORDS = 5;

    /**
     * Trains the model on {@code args[0]} and prints the reports described on the class.
     *
     * @param args {@code args[0]} is the path of the UTF-8 input file
     * @throws IllegalArgumentException if no input file is given
     * @throws IllegalStateException if the input file yields no instances
     * @throws Exception propagated from file I/O or model estimation
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: TopicModel <input-file>");
        }
        InstanceList instances = loadInstances(new File(args[0]));
        if (instances.size() == 0) {
            // Every report below reads document 0; fail with a clear message instead of an
            // IndexOutOfBoundsException deep inside the model.
            throw new IllegalStateException("No instances read from " + args[0]);
        }

        ParallelTopicModel model = new ParallelTopicModel(NUM_TOPICS, ALPHA_SUM, BETA);
        model.addInstances(instances);
        model.setNumThreads(NUM_THREADS);
        model.setNumIterations(NUM_ITERATIONS);
        model.estimate();

        Alphabet dataAlphabet = instances.getDataAlphabet();
        printFirstDocumentAssignments(model, dataAlphabet);

        ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();
        printTopicSummaries(model, topicSortedWords, dataAlphabet);

        // Build a tiny document out of topic 0's top words and check how strongly the
        // trained model assigns it back to topic 0. It must go through the same pipe.
        String topicZeroText = topWordsText(topicSortedWords.get(0), dataAlphabet);
        InstanceList testing = new InstanceList(instances.getPipe());
        testing.addThruPipe(new Instance(topicZeroText, null, "test instance", null));

        TopicInferencer inferencer = model.getInferencer();
        // Arguments presumably are (instance, iterations, thinning, burn-in) — see MALLET docs.
        double[] testProbabilities = inferencer.getSampledDistribution(testing.get(0), 10, 1, 5);
        System.out.println("0\t" + testProbabilities[0]);
    }

    /** Builds the import pipe: lowercase, tokenize, drop stop words, map tokens to feature ids. */
    private static SerialPipes buildPipe() {
        ArrayList<Pipe> pipeList = new ArrayList<>();
        pipeList.add(new CharSequenceLowercase());
        pipeList.add(new CharSequence2TokenSequence(TOKEN_PATTERN));
        // Trailing booleans left as in the original example; presumably
        // (includeDefault, caseSensitive, markDeletions) — confirm against MALLET docs.
        pipeList.add(new TokenSequenceRemoveStopwords(
                new File(STOPLIST_PATH), "UTF-8", false, false, false));
        pipeList.add(new TokenSequence2FeatureSequence());
        return new SerialPipes(pipeList);
    }

    /**
     * Reads every line of {@code input} (UTF-8) through the import pipe.
     *
     * @param input the line-oriented input file
     * @return the resulting instances (possibly empty)
     * @throws IOException if the file cannot be opened or read
     */
    private static InstanceList loadInstances(File input) throws IOException {
        InstanceList instances = new InstanceList(buildPipe());
        // addThruPipe drains the iterator before returning, so the reader can be closed here.
        try (Reader fileReader =
                new InputStreamReader(new FileInputStream(input), StandardCharsets.UTF_8)) {
            instances.addThruPipe(new CsvIterator(fileReader, LINE_PATTERN, 3, 2, 1));
        }
        return instances;
    }

    /** Prints the first document's tokens as space-separated {@code word-topic} pairs. */
    private static void printFirstDocumentAssignments(
            ParallelTopicModel model, Alphabet dataAlphabet) {
        FeatureSequence tokens = (FeatureSequence) model.getData().get(0).instance.getData();
        LabelSequence topics = model.getData().get(0).topicSequence;

        Formatter out = new Formatter(new StringBuilder(), Locale.US);
        for (int position = 0; position < tokens.getLength(); position++) {
            out.format("%s-%d ",
                    dataAlphabet.lookupObject(tokens.getIndexAtPosition(position)),
                    topics.getIndexAtPosition(position));
        }
        System.out.println(out);
    }

    /**
     * Prints one line per topic: topic index, the first document's probability for that topic
     * (three decimals) and up to {@link #TOP_WORDS} words with their weights.
     */
    private static void printTopicSummaries(ParallelTopicModel model,
            ArrayList<TreeSet<IDSorter>> topicSortedWords, Alphabet dataAlphabet) {
        double[] topicDistribution = model.getTopicProbabilities(0);
        for (int topic = 0; topic < NUM_TOPICS; topic++) {
            Formatter out = new Formatter(new StringBuilder(), Locale.US);
            out.format("%d\t%.3f\t", topic, topicDistribution[topic]);

            int rank = 0;
            for (IDSorter idCountPair : topicSortedWords.get(topic)) {
                if (rank == TOP_WORDS) {
                    break;
                }
                out.format("%s (%.0f) ",
                        dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
                rank++;
            }
            System.out.println(out);
        }
    }

    /**
     * Joins up to {@link #TOP_WORDS} words from a topic's sorted word set, each followed by a
     * single space (so the result has a trailing space when non-empty).
     */
    private static String topWordsText(TreeSet<IDSorter> sortedWords, Alphabet dataAlphabet) {
        StringBuilder text = new StringBuilder();
        int rank = 0;
        for (IDSorter idCountPair : sortedWords) {
            if (rank == TOP_WORDS) {
                break;
            }
            text.append(dataAlphabet.lookupObject(idCountPair.getID())).append(' ');
            rank++;
        }
        return text.toString();
    }
}

