/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.MarginalProbEstimator;
import cc.mallet.topics.TopicAssignment;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetFactory;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
import gnu.trove.TIntDoubleHashMap;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import java.util.zip.GZIPOutputStream;

/**
 * Topic model trained by Gibbs sampling in which word types are linked by
 * real-valued weights.
 *
 * <p>Every time a token of some type is assigned to a topic, the weight of each
 * type connected to it (see {@link #readTypeTypeWeights(File)}) is added to
 * that connected type's score for the topic; the weights are subtracted again
 * when the token is unassigned. The sampler scores a topic for a token with
 * {@code (alpha + docTopicCount) * typeTopicWeight / totalTopicWeight}.
 *
 * <p>Typical use, as done by {@link #main(String[])}: construct the model,
 * call {@link #addInstances(InstanceList)}, then
 * {@link #readTypeTypeWeights(File)}, then {@link #sample(int, boolean, int)}.
 *
 * <p>No synchronization is performed anywhere in this class.
 */
public class WeightedTopicModel implements Serializable {
    private static final Logger logger = MalletLogger.getLogger(WeightedTopicModel.class.getName());

    // Command-line options read by main().
    static CommandOption.String inputFile = new CommandOption.String(WeightedTopicModel.class, "input", "FILENAME", true, null, "The filename from which to read the list of training instances.  Use - for stdin.  The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
    static CommandOption.String weightsFile = new CommandOption.String(WeightedTopicModel.class, "weights-filename", "FILENAME", true, null, "The filename for the word-word weights file.", null);
    static CommandOption.String evaluatorFilename = new CommandOption.String(WeightedTopicModel.class, "evaluator-filename", "FILENAME", true, null, "A held-out likelihood evaluator for new documents.  By default this is null, indicating that no file will be written.", null);
    static CommandOption.String stateFile = new CommandOption.String(WeightedTopicModel.class, "state-filename", "FILENAME", true, null, "The filename in which to write the Gibbs sampling state after at the end of the iterations.  By default this is null, indicating that no file will be written.", null);
    static CommandOption.Integer numTopicsOption = new CommandOption.Integer(WeightedTopicModel.class, "num-topics", "INTEGER", true, 10, "The number of topics to fit.", null);
    static CommandOption.Integer numEpochsOption = new CommandOption.Integer(WeightedTopicModel.class, "num-epochs", "INTEGER", true, 1, "The number of cycles of training. Evaluators and state files will be saved after each epoch.", null);
    static CommandOption.Integer numIterationsOption = new CommandOption.Integer(WeightedTopicModel.class, "num-iterations", "INTEGER", true, 1000, "The number of iterations of Gibbs sampling PER EPOCH.", null);
    static CommandOption.Integer randomSeedOption = new CommandOption.Integer(WeightedTopicModel.class, "random-seed", "INTEGER", true, 0, "The random seed for the Gibbs sampler.  Default is 0, which will use the clock.", null);
    static CommandOption.Double alphaOption = new CommandOption.Double(WeightedTopicModel.class, "alpha", "DECIMAL", true, 50.0, "Alpha parameter: smoothing over topic distribution.", null);
    static CommandOption.Double betaOption = new CommandOption.Double(WeightedTopicModel.class, "beta", "DECIMAL", true, 0.01, "Beta parameter: smoothing over topic distribution.", null);

    // Not referenced anywhere else in this class; kept because they are public.
    public static Pattern sourceWordPattern = Pattern.compile("(.*) \\((\\d+)\\)");
    public static Pattern targetWordPattern = Pattern.compile("  (\\d+)\t(\\d+)\t([\\d\\.]+)\t(.*)");

    /** One entry per training document: the instance plus its per-token topic assignments. */
    protected ArrayList<TopicAssignment> data = new ArrayList<TopicAssignment>();
    /** Data alphabet of the training instances; set by {@link #addInstances(InstanceList)}. */
    protected Alphabet alphabet;
    protected LabelAlphabet topicAlphabet;
    protected int numTopics;
    /** Size of {@link #alphabet} at the time {@link #addInstances(InstanceList)} was called. */
    protected int numTypes;
    /** Per-topic document smoothing: {@code alphaSum / numTopics}. */
    protected double alpha;
    protected double alphaSum;
    protected double beta;
    /** {@code beta * numTypes}. */
    protected double betaSum;
    protected int[] oneDocTopicCounts;
    /** Token counts indexed by [type][topic]. */
    protected int[][] typeTopicCounts;
    /** Token counts indexed by topic. */
    protected int[] tokensPerTopic;
    /** For each type, a map from connected type index to weight. */
    protected TIntDoubleHashMap[] typeTypeWeights;
    protected double[][] logTypeTopicWeights;
    /** Accumulated weights indexed by [type][topic]; every cell starts at {@code beta}. */
    protected double[][] typeTopicWeights;
    /** Accumulated weights indexed by topic; every cell starts at {@code betaSum}. */
    protected double[] totalTopicWeights;
    /** Log the top words every this many iterations; 0 disables the report. */
    public int showTopicsInterval = 50;
    public int wordsPerTopic = 10;
    protected Randoms random;
    protected NumberFormat formatter;
    protected boolean printLogLikelihood = false;
    protected double[] logCountRatioCache;

    /**
     * Creates an empty model.
     *
     * @param numberOfTopics number of topics to fit
     * @param alphaSum total document-topic smoothing; each topic receives
     *        {@code alphaSum / numberOfTopics}
     * @param beta initial weight of every (type, topic) cell
     * @param random source of randomness for the sampler
     */
    public WeightedTopicModel(int numberOfTopics, double alphaSum, double beta, Randoms random) {
        this.topicAlphabet = AlphabetFactory.labelAlphabetOfSize(numberOfTopics);
        this.numTopics = this.topicAlphabet.size();
        this.alphaSum = alphaSum;
        this.alpha = alphaSum / (double)this.numTopics;
        this.beta = beta;
        this.random = random;
        this.oneDocTopicCounts = new int[this.numTopics];
        this.tokensPerTopic = new int[this.numTopics];
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        logger.info("Weighted LDA: " + this.numTopics + " topics");
    }

    /** @return the data alphabet, or {@code null} before {@link #addInstances(InstanceList)} */
    public Alphabet getAlphabet() {
        return this.alphabet;
    }

    /** @return the alphabet of topic labels */
    public LabelAlphabet getTopicAlphabet() {
        return this.topicAlphabet;
    }

    /** @return the number of topics */
    public int getNumTopics() {
        return this.numTopics;
    }

    /** @return the live list of documents and their topic assignments (not a copy) */
    public ArrayList<TopicAssignment> getData() {
        return this.data;
    }

    /**
     * Configures the periodic top-words report logged by {@link #sample(int, boolean, int)}.
     *
     * @param interval report every {@code interval} iterations; 0 disables it
     * @param n number of words to list per topic
     */
    public void setTopicDisplay(int interval, int n) {
        this.showTopicsInterval = interval;
        this.wordsPerTopic = n;
    }

    /** Replaces the random number generator with a new one built from {@code seed}. */
    public void setRandomSeed(int seed) {
        this.random = new Randoms(seed);
    }

    /** @return the live [type][topic] count matrix (not a copy) */
    public int[][] getTypeTopicCounts() {
        return this.typeTopicCounts;
    }

    /** @return the live per-topic token totals (not a copy) */
    public int[] getTopicTotals() {
        return this.tokensPerTopic;
    }

    /**
     * Registers the training documents and allocates the count and weight tables.
     *
     * <p>Every (type, topic) weight is initialised to {@code beta} and every
     * topic total to {@code beta * numTypes}. Each document gets an all-zero
     * topic sequence of the same length as its token sequence; real topics are
     * drawn on the first initialising sweep of {@link #sample(int, boolean, int)}.
     *
     * @param training instances whose data must be {@link FeatureSequence}s
     */
    public void addInstances(InstanceList training) {
        this.alphabet = training.getDataAlphabet();
        this.numTypes = this.alphabet.size();
        this.betaSum = this.beta * (double)this.numTypes;
        this.typeTopicCounts = new int[this.numTypes][this.numTopics];
        this.typeTopicWeights = new double[this.numTypes][this.numTopics];
        this.totalTopicWeights = new double[this.numTopics];
        for (int type = 0; type < this.numTypes; type++) {
            Arrays.fill(this.typeTopicWeights[type], this.beta);
        }
        Arrays.fill(this.totalTopicWeights, this.betaSum);
        for (Instance instance : training) {
            FeatureSequence tokenSequence = (FeatureSequence)instance.getData();
            LabelSequence topicSequence = new LabelSequence(this.topicAlphabet, new int[tokenSequence.size()]);
            this.data.add(new TopicAssignment(instance, topicSequence));
        }
    }

    /**
     * Loads the word-to-word weights used by the sampler.
     *
     * <p>Each line is tab-separated:
     * {@code sourceWord <TAB> selfWeight <TAB> targetWord <TAB> weight ...}.
     * All weights on a line are divided by the sum of every weight listed on
     * that line. Types that never appear as a source word keep a single
     * self-weight of 1.0.
     *
     * <p>Words are looked up without adding them to the alphabet: the weight
     * tables are sized to the alphabet as it was in
     * {@link #addInstances(InstanceList)}, so an index outside that range
     * cannot be stored or sampled. Lines whose source word is unknown, target
     * words that are unknown, blank lines, lines without a weight, and lines
     * whose weights sum to zero are skipped; the number of skipped items is
     * logged. A trailing target word with no weight is ignored.
     *
     * @param weightsFile file to read; must be called after {@link #addInstances(InstanceList)}
     * @throws Exception if the file cannot be read or a weight is not a valid number
     */
    public void readTypeTypeWeights(File weightsFile) throws Exception {
        this.typeTypeWeights = new TIntDoubleHashMap[this.numTypes];
        logger.info("num types: " + this.numTypes);
        for (int type = 0; type < this.numTypes; type++) {
            this.typeTypeWeights[type] = new TIntDoubleHashMap();
            this.typeTypeWeights[type].put(type, 1.0);
        }

        int skippedLines = 0;
        int skippedTargets = 0;
        BufferedReader reader = new BufferedReader(new FileReader(weightsFile));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split("\t");
                if (fields.length < 2) {
                    // Blank line, or a word with no weight at all.
                    if (line.trim().length() > 0) {
                        ++skippedLines;
                    }
                    continue;
                }

                // Weights sit at the odd positions: 1, 3, 5, ...
                double sum = 0.0;
                for (int i = 1; i < fields.length; i += 2) {
                    sum += Double.parseDouble(fields[i]);
                }
                if (sum == 0.0) {
                    // Dividing would store NaN/Infinity and poison the sampler.
                    ++skippedLines;
                    continue;
                }

                int sourceType = this.lookupKnownType(fields[0]);
                if (sourceType < 0) {
                    ++skippedLines;
                    continue;
                }
                TIntDoubleHashMap sourceWeights = this.typeTypeWeights[sourceType];
                sourceWeights.put(sourceType, Double.parseDouble(fields[1]) / sum);

                // (word, weight) pairs start at index 2; a dangling word is ignored.
                for (int i = 2; i + 1 < fields.length; i += 2) {
                    int targetType = this.lookupKnownType(fields[i]);
                    if (targetType < 0) {
                        ++skippedTargets;
                        continue;
                    }
                    sourceWeights.put(targetType, Double.parseDouble(fields[i + 1]) / sum);
                }
            }
        }
        finally {
            reader.close();
        }

        if (skippedLines > 0 || skippedTargets > 0) {
            logger.warning("Weights file " + weightsFile + ": skipped " + skippedLines + " lines and " + skippedTargets + " target words (unknown word or malformed entry)");
        }
    }

    /**
     * Looks a word up without growing the alphabet.
     *
     * @return the type index, or -1 if the word is absent or its index falls
     *         outside the tables allocated by {@link #addInstances(InstanceList)}
     */
    private int lookupKnownType(String word) {
        int type = this.alphabet.lookupIndex(word, false);
        return (type >= 0 && type < this.numTypes) ? type : -1;
    }

    /**
     * Runs Gibbs sampling over every document.
     *
     * <p>After each iteration the elapsed time is logged; every
     * {@link #showTopicsInterval} iterations (if non-zero) the top words are
     * logged as well.
     *
     * @param iterations number of sweeps over the corpus
     * @param shouldInitialize if true, the first sweep treats the documents as
     *        having no prior topic assignments
     * @param docCycleCount number of consecutive passes over each document
     *        per sweep; values below 2 give a single pass
     * @throws IOException declared for callers; nothing in this method's own body throws it
     */
    public void sample(int iterations, boolean shouldInitialize, int docCycleCount) throws IOException {
        for (int iteration = 1; iteration <= iterations; iteration++) {
            long iterationStart = System.currentTimeMillis();
            boolean initializing = shouldInitialize && iteration == 1;
            for (int doc = 0; doc < this.data.size(); doc++) {
                TopicAssignment assignment = this.data.get(doc);
                FeatureSequence tokenSequence = (FeatureSequence)assignment.instance.getData();
                LabelSequence topicSequence = assignment.topicSequence;
                this.sampleTopicsForOneDoc(tokenSequence, topicSequence, initializing, false);
                for (int cycle = 1; cycle < docCycleCount; cycle++) {
                    this.sampleTopicsForOneDoc(tokenSequence, topicSequence, false, false);
                }
            }
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            logger.info(iteration + "\t" + elapsedMillis + "ms\t");
            if (this.showTopicsInterval != 0 && iteration % this.showTopicsInterval == 0) {
                logger.info("<" + iteration + ">\n" + this.topWords(this.wordsPerTopic));
            }
        }
    }

    /**
     * Resamples the topic of every token in one document.
     *
     * <p>For each token the current assignment is removed from the counts and
     * weights (skipped when {@code initializing}), a topic is drawn in
     * proportion to
     * {@code (alpha + docTopicCount) * typeTopicWeight / totalTopicWeight},
     * and the new assignment is added back.
     *
     * @param tokenSequence the document's word types
     * @param topicSequence the document's topic assignments; the array returned
     *        by {@code getFeatures()} is written in place (NOTE(review): this
     *        relies on that array being the sequence's backing store - confirm)
     * @param initializing if true, existing assignments are ignored rather than removed
     * @param debugging if true, prints sampling diagnostics to standard output
     */
    protected void sampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence, boolean initializing, boolean debugging) {
        int[] oneDocTopics = topicSequence.getFeatures();
        int docLength = tokenSequence.getLength();

        // Topic histogram for this document.
        int[] localTopicCounts = new int[this.numTopics];
        if (!initializing) {
            for (int position = 0; position < docLength; position++) {
                localTopicCounts[oneDocTopics[position]]++;
            }
        }

        double[] topicTermScores = new double[this.numTopics];
        for (int position = 0; position < docLength; position++) {
            int type = tokenSequence.getIndexAtPosition(position);
            int oldTopic = oneDocTopics[position];
            TIntDoubleHashMap typeFactors = this.typeTypeWeights[type];
            int[] connectedTypes = typeFactors.keys();
            int[] currentTypeTopicCounts = this.typeTopicCounts[type];
            double[] currentTypeTopicWeights = this.typeTopicWeights[type];

            if (!initializing) {
                // Take this token out of every count and weight it contributed to.
                localTopicCounts[oldTopic]--;
                this.tokensPerTopic[oldTopic]--;
                assert (this.tokensPerTopic[oldTopic] >= 0);
                currentTypeTopicCounts[oldTopic]--;
                for (int otherType : connectedTypes) {
                    double factor = typeFactors.get(otherType);
                    this.typeTopicWeights[otherType][oldTopic] -= factor;
                    this.totalTopicWeights[oldTopic] -= factor;
                }
            }

            double sum = 0.0;
            for (int topic = 0; topic < this.numTopics; topic++) {
                double score = (this.alpha + (double)localTopicCounts[topic]) * (currentTypeTopicWeights[topic] / this.totalTopicWeights[topic]);
                sum += score;
                topicTermScores[topic] = score;
                // Hard-coded type index: trace output is limited to this one type.
                if (debugging && type == 68) {
                    System.out.println(type + "\t" + topic + "\t" + localTopicCounts[topic] + "\t" + currentTypeTopicCounts[topic] + "\t" + currentTypeTopicWeights[topic] + "\t" + this.tokensPerTopic[topic] + "\t" + sum);
                }
            }

            double sample = this.random.nextUniform() * sum;
            if (debugging) {
                System.out.println("sample " + sample + " / " + sum);
            }

            // Walk the cumulative scores until the sampled point is covered.
            // The walk is bounded so rounding error cannot step past the last
            // topic, and a draw of exactly 0.0 (or a NaN sum) falls back to
            // topic 0 instead of leaving the index at -1.
            int newTopic = -1;
            while (sample > 0.0 && newTopic < this.numTopics - 1) {
                sample -= topicTermScores[++newTopic];
            }
            if (newTopic < 0) {
                newTopic = 0;
            }

            // Record the new assignment everywhere the old one was removed from.
            oneDocTopics[position] = newTopic;
            localTopicCounts[newTopic]++;
            this.tokensPerTopic[newTopic]++;
            currentTypeTopicCounts[newTopic]++;
            for (int otherType : connectedTypes) {
                double factor = typeFactors.get(otherType);
                this.typeTopicWeights[otherType][newTopic] += factor;
                this.totalTopicWeights[newTopic] += factor;
            }
        }
    }

    /**
     * Formats one line per topic: topic index, token count, formatted total
     * topic weight, then the leading words, separated by tabs (words are
     * separated by spaces).
     *
     * <p>Words are ordered by {@link IDSorter}'s natural ordering over the
     * token counts in {@link #typeTopicCounts}, not over the weights.
     *
     * @param numWords maximum number of words per topic; capped at the number
     *        of types so a small vocabulary cannot overrun the array
     * @return the report, one newline-terminated line per topic
     */
    public String topWords(int numWords) {
        StringBuilder output = new StringBuilder();
        IDSorter[] sortedWords = new IDSorter[this.numTypes];
        int wordsToShow = Math.min(numWords, this.numTypes);
        for (int topic = 0; topic < this.numTopics; topic++) {
            for (int type = 0; type < this.numTypes; type++) {
                sortedWords[type] = new IDSorter(type, this.typeTopicCounts[type][topic]);
            }
            Arrays.sort(sortedWords);
            output.append(topic).append('\t')
                  .append(this.tokensPerTopic[topic]).append('\t')
                  .append(this.formatter.format(this.totalTopicWeights[topic]))
                  // Separator so the weight is not fused with the first word.
                  .append('\t');
            for (int i = 0; i < wordsToShow; i++) {
                output.append(this.alphabet.lookupObject(sortedWords[i].getID())).append(' ');
            }
            output.append("\n");
        }
        return output.toString();
    }

    /**
     * Builds a {@link MarginalProbEstimator} from the current token counts.
     *
     * <p>For every type, the non-zero topic counts are packed into a single
     * int each, {@code (count << topicBits) + topic}, and stored in descending
     * order. {@code topicBits} is the number of bits in the smallest all-ones
     * mask that can hold every topic index.
     *
     * <p>The estimator is handed {@link #tokensPerTopic} itself, not a copy.
     *
     * @return an estimator over the current sampling state
     */
    public MarginalProbEstimator getEstimator() {
        int topicMask = Integer.bitCount(this.numTopics) == 1
                ? this.numTopics - 1
                : Integer.highestOneBit(this.numTopics) * 2 - 1;
        int topicBits = Integer.bitCount(topicMask);

        int[][] sparseTypeTopicCounts = new int[this.numTypes][];
        for (int type = 0; type < this.numTypes; type++) {
            int[] currentTypeTopicCounts = this.typeTopicCounts[type];

            int numNonZeros = 0;
            for (int topic = 0; topic < this.numTopics; topic++) {
                if (currentTypeTopicCounts[topic] > 0) {
                    ++numNonZeros;
                }
            }

            int[] sparseCounts = new int[numNonZeros];
            for (int topic = 0; topic < this.numTopics; topic++) {
                if (currentTypeTopicCounts[topic] <= 0) {
                    continue;
                }
                int value = (currentTypeTopicCounts[topic] << topicBits) + topic;
                // Insertion into a descending array: skip the larger entries,
                // then shift the smaller ones down by swapping them through.
                int i = 0;
                while (i < sparseCounts.length && sparseCounts[i] > value) {
                    ++i;
                }
                while (i < sparseCounts.length && value > sparseCounts[i]) {
                    int displaced = sparseCounts[i];
                    sparseCounts[i] = value;
                    value = displaced;
                    ++i;
                }
            }
            sparseTypeTopicCounts[type] = sparseCounts;
        }

        double[] alphas = new double[this.numTopics];
        Arrays.fill(alphas, this.alpha);
        return new MarginalProbEstimator(this.numTopics, alphas, this.alphaSum, this.beta, sparseTypeTopicCounts, this.tokensPerTopic);
    }

    /**
     * Writes the sampling state, gzip-compressed, to a file.
     *
     * @param f destination file
     * @throws IOException if the file cannot be opened, or if the underlying
     *         stream reported an error while writing or closing
     */
    public void printState(File f) throws IOException {
        PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
        try {
            this.printState(out);
        }
        finally {
            out.close();
        }
        // PrintStream swallows IOExceptions; surface them instead of leaving
        // a silently truncated state file behind.
        if (out.checkError()) {
            throw new IOException("Error writing sampling state to " + f);
        }
    }

    /**
     * Writes one line per token:
     * {@code doc source pos typeindex type topic}, preceded by a header line.
     * The source column is always {@code NA}.
     *
     * @param stream destination; not closed by this method
     */
    public void printState(PrintStream stream) {
        stream.println("#doc source pos typeindex type topic");
        String source = "NA";
        for (int doc = 0; doc < this.data.size(); doc++) {
            TopicAssignment assignment = this.data.get(doc);
            FeatureSequence tokenSequence = (FeatureSequence)assignment.instance.getData();
            LabelSequence topicSequence = assignment.topicSequence;
            StringBuilder out = new StringBuilder();
            for (int position = 0; position < topicSequence.getLength(); position++) {
                int type = tokenSequence.getIndexAtPosition(position);
                int topic = topicSequence.getIndexAtPosition(position);
                out.append(doc).append(' ')
                   .append(source).append(' ')
                   .append(position).append(' ')
                   .append(type).append(' ')
                   .append(this.alphabet.lookupObject(type)).append(' ')
                   .append(topic).append("\n");
            }
            stream.print(out.toString());
        }
    }

    /**
     * Command-line entry point: loads the instances and weights, then trains
     * for the requested number of epochs. After each epoch the sampling state
     * and/or the evaluator are written to {@code <filename>.<epoch>} if the
     * corresponding option was given. Only the first epoch initialises the
     * topic assignments.
     *
     * @param args command-line arguments, parsed by {@link CommandOption}
     * @throws IllegalArgumentException if the input or weights file option is missing
     * @throws Exception if loading, sampling or writing the state file fails
     */
    public static void main(String[] args) throws Exception {
        CommandOption.setSummary(WeightedTopicModel.class, "Train topics with weights between word types encoded in the prior");
        CommandOption.process(WeightedTopicModel.class, args);

        // Both options default to null; fail with a clear message rather than
        // a NullPointerException from new File(null).
        if (WeightedTopicModel.inputFile.value == null) {
            throw new IllegalArgumentException("The 'input' option is required.");
        }
        if (WeightedTopicModel.weightsFile.value == null) {
            throw new IllegalArgumentException("The 'weights-filename' option is required.");
        }

        InstanceList training = InstanceList.load(new File(WeightedTopicModel.inputFile.value));
        Randoms random = WeightedTopicModel.randomSeedOption.value != 0
                ? new Randoms(WeightedTopicModel.randomSeedOption.value)
                : new Randoms();
        WeightedTopicModel lda = new WeightedTopicModel(WeightedTopicModel.numTopicsOption.value, WeightedTopicModel.alphaOption.value, WeightedTopicModel.betaOption.value, random);
        lda.addInstances(training);
        lda.readTypeTypeWeights(new File(WeightedTopicModel.weightsFile.value));

        int docCycleCount = 1;
        for (int epoch = 1; epoch <= WeightedTopicModel.numEpochsOption.value; epoch++) {
            lda.sample(WeightedTopicModel.numIterationsOption.value, epoch == 1, docCycleCount);
            if (stateFile.wasInvoked()) {
                lda.printState(new File(WeightedTopicModel.stateFile.value + "." + epoch));
            }
            if (evaluatorFilename.wasInvoked()) {
                // Best effort: a failed evaluator dump must not abort training.
                try {
                    ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(WeightedTopicModel.evaluatorFilename.value + "." + epoch));
                    try {
                        oos.writeObject(lda.getEstimator());
                    }
                    finally {
                        oos.close();
                    }
                }
                catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }
}

