package org.deeplearning4j.spark.models.sequencevectors.functions;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.sequencevectors.learning.SparkElementsLearningAlgorithm;
import org.deeplearning4j.spark.models.sequencevectors.learning.SparkSequenceLearningAlgorithm;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.parameterserver.distributed.VoidParameterServer;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.logic.sequence.BasicSequenceProvider;
import org.nd4j.parameterserver.distributed.messages.Frame;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.transport.RoutedTransport;

/* loaded from: input_file:org/deeplearning4j/spark/models/sequencevectors/functions/PartitionTrainingFunction.class */
/**
 * Spark {@link VoidFunction} that trains on the {@link Sequence}s of a single RDD partition.
 *
 * <p>Each incoming sequence is converted into a sequence of {@link ShallowSequenceElement}s by looking up every
 * element's storage id in the broadcast vocabulary; elements for which the vocabulary returns no token are dropped.
 * Converted sequences are collected into batches of up to eight, framed by the elements learning algorithm, and the
 * stacked messages are handed to the {@link VoidParameterServer} via {@code execDistributed}.
 *
 * <p>All {@code transient} fields are null after deserialization and are lazily rebuilt from the broadcast values
 * the first time {@link #call(Iterator)} runs on an instance.
 *
 * @param <T> element type of the incoming sequences
 */
public class PartitionTrainingFunction<T extends SequenceElement> implements VoidFunction<Iterator<Sequence<T>>> {
    /** Maximum number of converted sequences passed to one {@link #trainAllAtOnce(List)} call. */
    private static final int BATCH_SIZE = 8;

    // NOTE(review): passed as the AtomicLong argument of frameSequence; presumably the initial
    // "next random" state — confirm against SparkElementsLearningAlgorithm#frameSequence.
    private static final long FRAME_SEQUENCE_SEED = 119L;

    // NOTE(review): passed as the double argument of frameSequence, presumably the learning rate.
    // The value equals (double) 0.025f, as in the previously compiled class. It is hard-coded rather
    // than read from vectorsConfiguration — confirm that is intended.
    private static final double FRAME_SEQUENCE_ALPHA = 0.02500000037252903d;

    protected Broadcast<VocabCache<ShallowSequenceElement>> vocabCacheBroadcast;
    protected Broadcast<VectorsConfiguration> configurationBroadcast;
    protected Broadcast<VoidConfiguration> paramServerConfigurationBroadcast;

    // Lazily initialised per deserialized instance (see initialize()).
    protected transient VoidParameterServer paramServer;
    protected transient VectorsConfiguration vectorsConfiguration;
    protected transient SparkElementsLearningAlgorithm elementsLearningAlgorithm;
    protected transient SparkSequenceLearningAlgorithm sequenceLearningAlgorithm;
    protected transient VocabCache<ShallowSequenceElement> shallowVocabCache;
    protected transient TrainingDriver<? extends TrainingMessage> driver;

    /**
     * @param vocabCacheBroadcast broadcast shallow vocabulary used to map element storage ids to tokens
     * @param vectorsConfigurationBroadcast broadcast model configuration (learning algorithm class names etc.)
     * @param paramServerConfigurationBroadcast broadcast configuration passed to {@code VoidParameterServer.init}
     * @throws NullPointerException if any argument is {@code null}
     */
    public PartitionTrainingFunction(@NonNull Broadcast<VocabCache<ShallowSequenceElement>> vocabCacheBroadcast,
                    @NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast,
                    @NonNull Broadcast<VoidConfiguration> paramServerConfigurationBroadcast) {
        if (vocabCacheBroadcast == null) {
            throw new NullPointerException("vocabCacheBroadcast is marked non-null but is null");
        }
        if (vectorsConfigurationBroadcast == null) {
            throw new NullPointerException("vectorsConfigurationBroadcast is marked non-null but is null");
        }
        if (paramServerConfigurationBroadcast == null) {
            throw new NullPointerException("paramServerConfigurationBroadcast is marked non-null but is null");
        }
        this.vocabCacheBroadcast = vocabCacheBroadcast;
        this.configurationBroadcast = vectorsConfigurationBroadcast;
        this.paramServerConfigurationBroadcast = paramServerConfigurationBroadcast;
    }

    /**
     * Trains on every sequence of one partition, in batches of up to eight converted sequences.
     *
     * @param sequences the partition's sequences; fully consumed
     * @throws ND4JIllegalStateException if the configuration names no elements learning algorithm
     * @throws RuntimeException if a configured learning algorithm class cannot be instantiated
     */
    @Override
    public void call(Iterator<Sequence<T>> sequences) throws Exception {
        initialize();

        List<Sequence<ShallowSequenceElement>> batch = new ArrayList<>(BATCH_SIZE);
        while (sequences.hasNext()) {
            batch.add(toShallowSequence(sequences.next()));
            if (batch.size() >= BATCH_SIZE) {
                trainAllAtOnce(batch);
                batch.clear();
            }
        }
        // flush the final, partially filled batch
        if (!batch.isEmpty()) {
            trainAllAtOnce(batch);
        }
    }

    /**
     * Builds any missing transient state from the broadcast values, initialises the parameter server once, and
     * (re)configures the learning algorithms with the shallow vocabulary.
     */
    private void initialize() {
        if (vectorsConfiguration == null) {
            vectorsConfiguration = configurationBroadcast.getValue();
        }

        if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
            elementsLearningAlgorithm = instantiate(vectorsConfiguration.getElementsLearningAlgorithm(),
                            SparkElementsLearningAlgorithm.class);
        }
        // The elements algorithm is mandatory: it supplies the parameter server's TrainingDriver and
        // trainAllAtOnce frames every sequence through it. Fail with a clear message instead of an NPE.
        if (elementsLearningAlgorithm == null) {
            throw new ND4JIllegalStateException(vectorsConfiguration.getSequenceLearningAlgorithm() == null
                            ? "No LearningAlgorithms specified!"
                            : "No ElementsLearningAlgorithm specified: it is required to obtain the TrainingDriver"
                                            + " and to frame sequences");
        }

        if (paramServer == null) {
            driver = elementsLearningAlgorithm.getTrainingDriver();
            VoidParameterServer server = VoidParameterServer.getInstance();
            server.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
            // Publish only after init succeeded, so a failed init is attempted again on the next call.
            paramServer = server;
        }

        if (shallowVocabCache == null) {
            shallowVocabCache = vocabCacheBroadcast.getValue();
        }
        elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

        if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
            sequenceLearningAlgorithm = instantiate(vectorsConfiguration.getSequenceLearningAlgorithm(),
                            SparkSequenceLearningAlgorithm.class);
        }
        // configured exactly once per call (previously twice on first creation)
        if (sequenceLearningAlgorithm != null) {
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        }
    }

    /**
     * Reflectively creates an instance of {@code className} through its no-argument constructor.
     *
     * @param className fully qualified class name
     * @param type expected supertype of the instance
     * @return the new instance
     * @throws RuntimeException wrapping any class lookup, instantiation or type mismatch failure
     */
    private static <A> A instantiate(String className, Class<A> type) {
        try {
            return type.cast(Class.forName(className).getDeclaredConstructor().newInstance());
        } catch (Exception e) {
            throw new RuntimeException(
                            "Unable to instantiate " + type.getSimpleName() + " from class [" + className + "]", e);
        }
    }

    /**
     * Maps a sequence onto the shallow vocabulary by storage id. Elements, and labels when a sequence learning
     * algorithm is present and sequence vectors are being trained, for which the vocabulary returns no token are
     * skipped.
     */
    private Sequence<ShallowSequenceElement> toShallowSequence(Sequence<T> source) {
        Sequence<ShallowSequenceElement> shallow = new Sequence<>();
        for (T element : source.getElements()) {
            ShallowSequenceElement token = shallowVocabCache.tokenFor(element.getStorageId().longValue());
            if (token != null) {
                shallow.addElement(token);
            }
        }
        if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
            for (T label : source.getSequenceLabels()) {
                ShallowSequenceElement token = shallowVocabCache.tokenFor(label.getStorageId().longValue());
                if (token != null) {
                    shallow.addSequenceLabel(token);
                }
            }
        }
        return shallow;
    }

    /**
     * Frames every sequence with the elements learning algorithm and stacks the resulting messages into a single
     * {@link Frame}. If that frame is non-empty, it is passed to {@code paramServer.execDistributed}.
     *
     * @param list converted sequences to train on
     */
    protected void trainAllAtOnce(List<Sequence<ShallowSequenceElement>> list) {
        // NOTE(review): only elementsLearningAlgorithm frames sequences here; sequenceLearningAlgorithm is
        // configured but never invoked — confirm whether label/sequence training is handled elsewhere.
        Frame frame = new Frame(BasicSequenceProvider.getInstance().getNextValue().longValue());
        for (Sequence<ShallowSequenceElement> sequence : list) {
            frame.stackMessages(elementsLearningAlgorithm
                            .frameSequence(sequence, new AtomicLong(FRAME_SEQUENCE_SEED), FRAME_SEQUENCE_ALPHA)
                            .getMessages());
        }
        if (frame.size() > 0) {
            paramServer.execDistributed(frame);
        }
    }
}
