package org.deeplearning4j.spark.models.sequencevectors.export.impl;

import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer;
import org.deeplearning4j.spark.models.sequencevectors.export.SparkModelExporter;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link SparkModelExporter} that collects exported {@link VocabWord} vectors onto the driver and
 * assembles them into a {@link VocabCache}, an {@link InMemoryLookupTable} and a {@link Word2Vec}
 * model, all exposed through getters after {@link #export} returns.
 *
 * <p>The whole RDD is pulled to the driver with {@link JavaRDD#collect()}, so this exporter is
 * only suitable when the vocabulary and its vectors fit in driver memory.
 *
 * <p>Instances hold mutable state and perform no synchronization.
 */
public class VocabCacheExporter implements SparkModelExporter<VocabWord> {
    private static final Logger log = LoggerFactory.getLogger(VocabCacheExporter.class);

    /** Vocabulary filled by {@link #export}; an {@link AbstractCache} is created if none is set. */
    protected VocabCache<VocabWord> vocabCache;
    /** Lookup table whose syn0 is replaced by {@link #export}; built there if none is set. */
    protected InMemoryLookupTable<VocabWord> lookupTable;
    /** Model built from {@link #lookupTable} and {@link #vocabCache} by the last {@link #export}. */
    protected Word2Vec word2Vec;

    /**
     * Collects every exported element to the driver and builds the vocabulary, the syn0 weight
     * matrix (one row per element, at the element's index), the lookup table and the Word2Vec model.
     *
     * @param javaRDD element/vector pairs to export
     * @throws IllegalArgumentException if the RDD is empty (so the vector length is unknown), or an
     *     element's index does not address a row of the syn0 matrix
     */
    @Override // org.deeplearning4j.spark.models.sequencevectors.export.SparkModelExporter
    public void export(JavaRDD<ExportContainer<VocabWord>> javaRDD) {
        // Brings the entire RDD onto the driver; acceptable only for models that fit in memory.
        List<ExportContainer<VocabWord>> containers = javaRDD.collect();
        if (containers.isEmpty()) {
            throw new IllegalArgumentException(
                    "Nothing to export: the RDD is empty, so the vector length cannot be determined");
        }

        if (this.vocabCache == null) {
            this.vocabCache = new AbstractCache<>();
        }

        int rows = containers.size();
        INDArray syn0 = null;
        for (ExportContainer<VocabWord> container : containers) {
            VocabWord word = container.getElement();
            INDArray weights = container.getArray();

            // The vector length is only known once the first container has been seen.
            if (syn0 == null) {
                syn0 = Nd4j.create(new long[] {rows, weights.length()});
            }

            int index = word.getIndex();
            if (index < 0 || index >= rows) {
                throw new IllegalArgumentException("Word '" + word.getLabel() + "' has index " + index
                        + ", outside the syn0 row range [0, " + rows + ")");
            }

            this.vocabCache.addToken(word);
            this.vocabCache.addWordToIndex(index, word.getLabel());
            syn0.getRow(index).assign(weights);
        }

        if (this.lookupTable == null) {
            this.lookupTable = new InMemoryLookupTable.Builder<VocabWord>()
                    .cache(this.vocabCache)
                    .vectorLength(syn0.columns())
                    .build();
        }
        this.lookupTable.setSyn0(syn0);

        // fromPair takes a raw Pair<InMemoryLookupTable, VocabCache>, hence the explicit witness.
        this.word2Vec = WordVectorSerializer.fromPair(
                Pair.<InMemoryLookupTable, VocabCache>makePair(this.lookupTable, this.vocabCache));
    }

    /**
     * Returns the vocabulary populated by the last {@link #export} call.
     *
     * @return the vocabulary, or {@code null} if none was set and {@link #export} has not run
     */
    public VocabCache<VocabWord> getVocabCache() {
        return this.vocabCache;
    }

    /**
     * Returns the lookup table whose syn0 was set by the last {@link #export} call.
     *
     * @return the lookup table, or {@code null} if none was set and {@link #export} has not run
     */
    public InMemoryLookupTable<VocabWord> getLookupTable() {
        return this.lookupTable;
    }

    /**
     * Returns the Word2Vec model built by the last {@link #export} call.
     *
     * @return the model, or {@code null} if {@link #export} has not completed
     */
    public Word2Vec getWord2Vec() {
        return this.word2Vec;
    }
}
