/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.WordEmbeddings;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import java.util.Random;

/**
 * Worker that trains the shared {@link WordEmbeddings} weights on a contiguous slice of
 * {@link #instances}, using a skip-gram objective with negative sampling.
 *
 * <p>Weight layout, as indexed below: the row for word type {@code t} starts at
 * {@code t * stride}. Entries {@code [0, numColumns)} of that row hold the input vector, and
 * entries {@code [numColumns, 2 * numColumns)} hold the output (context) vector. Column 0 of
 * each vector is an additive bias, not a dot-product dimension.
 *
 * <p>Writes to {@code model.weights} are not synchronized. Presumably several workers share one
 * model and update it concurrently — confirm against the caller.
 */
public class WordEmbeddingRunnable implements Runnable {
    /** Documents with fewer tokens than this after subsampling are skipped. */
    private static final int MIN_DOCUMENT_LENGTH = 10;
    /** Initial capacity of the per-document token buffer; it grows for longer documents. */
    private static final int INITIAL_TOKEN_BUFFER_SIZE = 100000;
    /** Corpus-fraction threshold used when randomly discarding frequent words. */
    private static final double SUBSAMPLING_THRESHOLD = 1.0E-4;
    /** Learning rate at the start of training; it decays linearly with {@link #wordsSoFar}. */
    private static final double INITIAL_LEARNING_RATE = 0.025;
    /** Lower bound for the decayed learning rate. */
    private static final double MIN_LEARNING_RATE = 1.0E-4;

    /** Model whose {@code weights} array is updated in place. */
    public WordEmbeddings model;
    /** Training documents; each instance's data is read as a {@link FeatureSequence}. */
    public InstanceList instances;
    /** Number of negative samples drawn for every (input, context) pair. */
    public int numSamples;
    /**
     * Training continues while this is {@code true}; it is checked once per document.
     * Volatile so that a write from another thread is guaranteed to become visible to
     * {@link #run()}.
     */
    public volatile boolean shouldRun = true;
    /**
     * Running sum of (positive prediction - mean negative prediction) over trained pairs since
     * the last {@link #getMeanError()} call.
     */
    double residual = 0.0;
    /** Number of pairs accumulated into {@link #residual}. */
    int numUpdates = 0;
    /** Total number of workers; used to partition the documents and scale progress. */
    int numThreads;
    /** This worker's index in {@code [0, numThreads)}. */
    int threadID;
    /** Distance between the starts of consecutive word rows in {@code model.weights}. */
    int stride;
    /** Index of the next document this worker will read. */
    public int docID;
    /** Per-worker random source for subsampling and negative sampling. */
    public Random random;
    /** Vector length, including the bias column; copied from {@code model.numColumns}. */
    int numColumns;
    /**
     * Tokens read so far, counted before subsampling; drives learning-rate decay.
     * NOTE(review): as an {@code int} this can overflow on very large corpora or long runs.
     */
    public int wordsSoFar = 0;

    /**
     * Creates a worker bound to a shared model.
     *
     * @param model      model whose weights are trained; {@code stride} and {@code numColumns}
     *                   are copied from it
     * @param instances  training documents
     * @param numSamples negative samples per (input, context) pair
     * @param numThreads total number of workers sharing the documents
     * @param threadID   this worker's index in {@code [0, numThreads)}
     */
    public WordEmbeddingRunnable(WordEmbeddings model, InstanceList instances, int numSamples, int numThreads, int threadID) {
        this.model = model;
        this.stride = model.stride;
        this.instances = instances;
        this.numSamples = numSamples;
        this.numThreads = numThreads;
        this.threadID = threadID;
        this.random = new Random();
        this.numColumns = model.numColumns;
    }

    /**
     * Returns the mean of the values accumulated in {@link #residual} since the previous call,
     * then resets the accumulator.
     *
     * <p>This method is not synchronized with {@link #run()}. If it is called from another thread
     * during training, the values read may be stale.
     *
     * @return the mean (positive prediction - mean negative prediction) per trained pair, or
     *         {@code 0.0} if no pair was trained since the previous call
     */
    public double getMeanError() {
        if (numUpdates == 0) {
            return 0.0;
        }
        double result = residual / numUpdates;
        residual = 0.0;
        numUpdates = 0;
        return result;
    }

    /**
     * Cycles over this worker's slice of documents until {@link #shouldRun} becomes
     * {@code false}.
     *
     * <p>Each document is subsampled, skipped if too short, and then every (input, context)
     * pair within {@code model.windowSize} positions is trained. Returns immediately if this
     * worker has no documents assigned.
     */
    @Override
    public void run() {
        int numDocuments = instances.size();
        int docsPerThread = numDocuments / numThreads;
        int firstDocID = threadID * docsPerThread;
        // The last worker also takes the numDocuments % numThreads leftover documents,
        // which integer division would otherwise leave untrained.
        int maxDocID = (threadID == numThreads - 1) ? numDocuments : firstDocID + docsPerThread;
        docID = firstDocID;
        if (firstDocID >= maxDocID) {
            return; // empty slice: nothing to train
        }

        // Guard numSamples == 0: 0 * Infinity would otherwise turn residual into NaN.
        double sampleNormalizer = numSamples > 0 ? 1.0 / numSamples : 0.0;
        double cacheScale = 1.0 / (model.maxExpValue - model.minExpValue);
        double[] gradient = new double[numColumns];
        int[] tokenBuffer = new int[INITIAL_TOKEN_BUFFER_SIZE];

        while (shouldRun) {
            Instance instance = (Instance) instances.get(docID);
            docID++;
            if (docID >= maxDocID) {
                docID = firstDocID; // wrap around and keep cycling until stopped
            }

            // Decay uses estimated global progress: every worker is assumed to have read
            // about as many tokens as this one.
            double learningRate = Math.max(MIN_LEARNING_RATE,
                    INITIAL_LEARNING_RATE * (1.0 - (double) numThreads * (double) wordsSoFar / (double) model.totalWords));

            FeatureSequence tokens = (FeatureSequence) instance.getData();
            if (tokens.getLength() > tokenBuffer.length) {
                tokenBuffer = new int[tokens.getLength()];
            }
            int length = subsample(tokens, tokenBuffer);
            if (length < MIN_DOCUMENT_LENGTH) {
                continue;
            }

            for (int inputPosition = 0; inputPosition < length; inputPosition++) {
                int inputType = tokenBuffer[inputPosition];
                int window = model.windowSize;
                int start = Math.max(0, inputPosition - window);
                int end = Math.min(length - 1, inputPosition + window);
                for (int outputPosition = start; outputPosition <= end; outputPosition++) {
                    int outputType = tokenBuffer[outputPosition];
                    if (outputPosition != inputPosition && outputType != inputType) {
                        trainPair(inputType, outputType, learningRate, sampleNormalizer, cacheScale, gradient);
                    }
                }
            }
        }
    }

    /**
     * Copies the kept tokens of one document into {@code buffer}, randomly dropping frequent
     * words.
     *
     * @return the number of tokens written to {@code buffer}
     */
    private int subsample(FeatureSequence tokens, int[] buffer) {
        int originalLength = tokens.getLength();
        int kept = 0;
        for (int position = 0; position < originalLength; position++) {
            int type = tokens.getIndexAtPosition(position);
            wordsSoFar++;
            double frequencyScore = (double) model.wordCounts[type] / (SUBSAMPLING_THRESHOLD * (double) model.totalWords);
            // Keep probability (sqrt(f) + 1) / f: values >= 1 (rare words) always keep.
            if (random.nextDouble() < (Math.sqrt(frequencyScore) + 1.0) / frequencyScore) {
                buffer[kept++] = type;
            }
        }
        return kept;
    }

    /**
     * Performs one stochastic-gradient step for an observed (input, output) pair plus
     * {@link #numSamples} negative samples.
     *
     * <p>Output vectors are updated immediately. The input vector is updated once at the end,
     * from {@code gradient}, which is a caller-supplied scratch array of length
     * {@code numColumns}.
     */
    private void trainPair(int inputType, int outputType, double learningRate,
                           double sampleNormalizer, double cacheScale, double[] gradient) {
        double[] weights = model.weights;
        int outputOffset = numColumns;
        int inputRow = inputType * stride;
        int outputRow = outputType * stride + outputOffset;

        // Observed pair: push sigmoid(score) toward 1.
        double prediction = sigmoid(score(inputRow, outputRow), cacheScale);
        double positiveError = 1.0 - prediction;
        gradient[0] = positiveError;
        weights[outputRow] += learningRate * positiveError;
        for (int col = 1; col < numColumns; col++) {
            // Read the output weight before updating it, so the input gradient uses the old value.
            gradient[col] = positiveError * weights[outputRow + col];
            weights[outputRow + col] += learningRate * (positiveError * weights[inputRow + col]);
        }

        // Negative samples: push sigmoid(score) toward 0, each weighted by 1 / numSamples.
        double negativePredictionSum = 0.0;
        for (int sample = 0; sample < numSamples; sample++) {
            int sampledType = model.samplingTable[random.nextInt(model.samplingTableSize)];
            int sampledRow = sampledType * stride + outputOffset;
            double negativePrediction = sigmoid(score(inputRow, sampledRow), cacheScale);
            negativePredictionSum += negativePrediction;
            gradient[0] += sampleNormalizer * -negativePrediction;
            weights[sampledRow] += learningRate * sampleNormalizer * -negativePrediction;
            for (int col = 1; col < numColumns; col++) {
                gradient[col] += sampleNormalizer * (-negativePrediction * weights[sampledRow + col]);
                weights[sampledRow + col] += learningRate * sampleNormalizer * (-negativePrediction * weights[inputRow + col]);
            }
        }

        residual += prediction - negativePredictionSum * sampleNormalizer;
        numUpdates++;

        for (int col = 0; col < numColumns; col++) {
            weights[inputRow + col] += learningRate * gradient[col];
        }
    }

    /**
     * Computes input bias + output bias + the dot product of columns {@code 1..numColumns-1}.
     *
     * <p>Column 0 is excluded from the product so that the same bias-augmented score is used for
     * both the observed pair and the negative samples.
     */
    private double score(int inputRow, int outputRow) {
        double[] weights = model.weights;
        double result = weights[inputRow] + weights[outputRow];
        for (int col = 1; col < numColumns; col++) {
            result += weights[inputRow + col] * weights[outputRow + col];
        }
        return result;
    }

    /**
     * Looks up the logistic function in {@code model.sigmoidCache}. Scores below
     * {@code minExpValue} saturate to 0 and scores above {@code maxExpValue} saturate to 1.
     */
    private double sigmoid(double innerProduct, double cacheScale) {
        if (innerProduct < model.minExpValue) {
            return 0.0;
        }
        if (innerProduct > model.maxExpValue) {
            return 1.0;
        }
        int index = (int) Math.floor((double) model.sigmoidCacheSize * (innerProduct - model.minExpValue) * cacheScale);
        // innerProduct == maxExpValue yields index == sigmoidCacheSize; clamp so this can
        // never read past the end of the cache.
        return model.sigmoidCache[Math.min(index, model.sigmoidCache.length - 1)];
    }
}

