package org.deeplearning4j.spark.models.sequencevectors.learning.elements;

import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.models.embeddings.learning.impl.elements.BatchSequences;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
import org.nd4j.parameterserver.distributed.logic.sequence.BasicSequenceProvider;
import org.nd4j.parameterserver.distributed.messages.Frame;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkSkipGram.class */
/**
 * Skip-gram learning algorithm that, instead of training locally, turns every
 * (element, context element) pair of a sequence into a
 * {@link SkipGramRequestMessage} and stacks it onto a per-thread {@link Frame}.
 * The filled frame is handed back to the caller by {@link #frameSequence}.
 *
 * <p>{@link #learnSequence} is not supported by this implementation.
 */
public class SparkSkipGram extends BaseSparkLearningAlgorithm {
    private static final Logger log = LoggerFactory.getLogger(SparkSkipGram.class);

    /*
     * Multiplier and increment of the linear congruential step applied to the
     * caller-supplied random state. These are the same values java.util.Random
     * documents for its generator (0x5DEECE66D and 0xB).
     */
    private static final long RANDOM_MULTIPLIER = 25214903917L;
    private static final long RANDOM_INCREMENT = 11L;

    /** Not referenced anywhere inside this class. */
    protected transient AtomicLong counter;

    /**
     * Holder of the frame currently being filled by each thread. Created lazily
     * because it is transient; always access it through {@link #frameHolder()}.
     */
    protected transient ThreadLocal<Frame<SkipGramRequestMessage>> frame;

    /** Training driver returned by {@link #getTrainingDriver()}. */
    protected TrainingDriver<SkipGramRequestMessage> driver = new SkipGramTrainer();

    /**
     * @return the fixed code name of this algorithm, {@code "Spark-SkipGram"}
     */
    public String getCodeName() {
        return "Spark-SkipGram";
    }

    /**
     * Not supported by this algorithm.
     *
     * @throws UnsupportedOperationException always
     */
    public double learnSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong atomicLong, double d, BatchSequences<ShallowSequenceElement> batchSequences) {
        throw new UnsupportedOperationException();
    }

    /**
     * Converts one sequence into a frame of skip-gram request messages.
     *
     * <p>For every non-null element a randomly shrunk window around it is walked
     * and {@link #iterateSample} is invoked for each in-range neighbour. The
     * frame filled for the calling thread is then returned, and a fresh frame is
     * installed for that thread's next call.
     *
     * @param sequence   the sequence to process; sub-sampled first when the
     *                   configured sampling value is positive
     * @param atomicLong random state; advanced in place once per element and once
     *                   per emitted pair
     * @param d          value forwarded unchanged to every generated message
     *                   (presumably the learning rate - confirm against callers)
     * @return the frame holding all messages produced for this sequence
     */
    @Override
    public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong atomicLong, double d) {
        Sequence<ShallowSequenceElement> effective = sequence;
        final double sampling = this.vectorsConfiguration.getSampling();
        if (sampling > 0.0d) {
            // NOTE(review): the meaning of the literal 10L (third argument) is not
            // visible from this class - confirm against applySubsampling.
            effective = BaseSparkLearningAlgorithm.applySubsampling(effective, atomicLong, 10L, sampling);
        }

        int window = this.vectorsConfiguration.getWindow();
        final int[] variableWindows = this.vectorsConfiguration.getVariableWindows();
        if (variableWindows != null && variableWindows.length != 0) {
            // Pick one of the configured window sizes at random for this sequence.
            window = variableWindows[RandomUtils.nextInt(0, variableWindows.length)];
        }

        // Make sure the calling thread has a frame before any message is stacked.
        final ThreadLocal<Frame<SkipGramRequestMessage>> holder = frameHolder();
        if (holder.get() == null) {
            holder.set(newFrame());
        }

        final int length = effective.size();
        for (int position = 0; position < length; position++) {
            advanceRandom(atomicLong);
            final ShallowSequenceElement current = (ShallowSequenceElement) effective.getElementByIndex(position);
            if (current == null) {
                continue;
            }

            // Math.abs(Long.MIN_VALUE) is negative, so the random state itself may
            // be negative; taking abs of the remainder (whose magnitude is always
            // below the window size) keeps the offset inside [0, window).
            final int start = (int) Math.abs(atomicLong.get() % window);
            final int end = ((window * 2) + 1) - start;
            for (int offset = start; offset < end; offset++) {
                if (offset == window) {
                    // Offset == window maps back onto the centre element itself.
                    continue;
                }
                final int contextPosition = (position - window) + offset;
                if (contextPosition < 0 || contextPosition >= length) {
                    continue;
                }
                iterateSample(current, (ShallowSequenceElement) effective.getElementByIndex(contextPosition), atomicLong, d);
                advanceRandom(atomicLong);
            }
        }

        // Hand the filled frame to the caller and start a new one for this thread.
        final Frame<SkipGramRequestMessage> completed = holder.get();
        holder.set(newFrame());
        return completed;
    }

    /**
     * Stacks a single skip-gram request for the given pair onto the calling
     * thread's current frame.
     *
     * <p>Nothing is emitted when either element is {@code null}, when the second
     * element has a negative index, or when both elements share the same index.
     * With hierarchic softmax enabled, the first element's codes and points are
     * copied into the message; positions whose point lies outside
     * {@code [0, vocabCache.numWords())} keep the array default of zero.
     * Otherwise empty arrays are sent.
     *
     * <p>Decompiler note: the access modifier of this method was widened from
     * {@code protected}.
     *
     * @param shallowSequenceElement  the centre element
     * @param shallowSequenceElement2 the context element
     * @param atomicLong              random state; its current value is embedded
     *                                in the message
     * @param d                       value forwarded unchanged to the message
     */
    public void iterateSample(ShallowSequenceElement shallowSequenceElement, ShallowSequenceElement shallowSequenceElement2, AtomicLong atomicLong, double d) {
        if (shallowSequenceElement == null || shallowSequenceElement2 == null) {
            return;
        }
        if (shallowSequenceElement2.getIndex() < 0 || shallowSequenceElement.getIndex() == shallowSequenceElement2.getIndex()) {
            return;
        }

        int[] points = new int[0];
        byte[] codes = new byte[0];
        if (this.vectorsConfiguration.isUseHierarchicSoftmax()) {
            final int codeLength = shallowSequenceElement.getCodeLength();
            points = new int[codeLength];
            codes = new byte[codeLength];
            for (int k = 0; k < codeLength; k++) {
                final byte code = ((Byte) shallowSequenceElement.getCodes().get(k)).byteValue();
                final int point = ((Integer) shallowSequenceElement.getPoints().get(k)).intValue();
                if (point >= 0 && point < this.vocabCache.numWords()) {
                    codes[k] = code;
                    points[k] = point;
                }
            }
        }

        currentFrame().stackMessage(new SkipGramRequestMessage(shallowSequenceElement.getIndex(), shallowSequenceElement2.getIndex(), points, codes, (short) this.vectorsConfiguration.getNegative(), d, atomicLong.get()));
    }

    /**
     * @return the training driver associated with this algorithm
     */
    @Override
    public TrainingDriver<? extends TrainingMessage> getTrainingDriver() {
        return this.driver;
    }

    /**
     * Returns the per-thread frame holder, creating it on first use.
     *
     * <p>The field is read into a local exactly once outside the lock, so the
     * null check and the returned reference are guaranteed to be the same value.
     * NOTE(review): the field is not volatile; declaring it volatile would be the
     * textbook form of this idiom but would alter a protected field's modifiers,
     * so it is left as is here.
     */
    private ThreadLocal<Frame<SkipGramRequestMessage>> frameHolder() {
        ThreadLocal<Frame<SkipGramRequestMessage>> holder = this.frame;
        if (holder == null) {
            synchronized (this) {
                holder = this.frame;
                if (holder == null) {
                    holder = new ThreadLocal<>();
                    this.frame = holder;
                }
            }
        }
        return holder;
    }

    /** Returns the calling thread's current frame, creating one if none exists yet. */
    private Frame<SkipGramRequestMessage> currentFrame() {
        final ThreadLocal<Frame<SkipGramRequestMessage>> holder = frameHolder();
        Frame<SkipGramRequestMessage> current = holder.get();
        if (current == null) {
            current = newFrame();
            holder.set(current);
        }
        return current;
    }

    /** Creates an empty frame identified by the next value of the shared sequence provider. */
    private static Frame<SkipGramRequestMessage> newFrame() {
        final Frame<SkipGramRequestMessage> created = new Frame<>(BasicSequenceProvider.getInstance().getNextValue().longValue());
        return created;
    }

    /** Advances the random state by one linear congruential step. */
    private static void advanceRandom(AtomicLong state) {
        state.set(Math.abs((state.get() * RANDOM_MULTIPLIER) + RANDOM_INCREMENT));
    }
}
