/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.feature;

import com.google.common.base.Joiner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
import org.dmg.pmml.ParameterField;
import org.dmg.pmml.TextIndex;
import org.dmg.pmml.TextIndexNormalization;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.sparkml.DocumentFeature;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.TermFeature;
import org.jpmml.sparkml.TermUtil;

public class CountVectorizerModelConverter
extends FeatureConverter<CountVectorizerModel> {
    private static final Joiner JOINER = Joiner.on((String)"|");
    private static final AtomicInteger SEQUENCE = new AtomicInteger(1);

    public CountVectorizerModelConverter(CountVectorizerModel transformer) {
        super(transformer);
    }

    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
        CountVectorizerModel transformer = (CountVectorizerModel)this.getTransformer();
        DocumentFeature documentFeature = (DocumentFeature)encoder.getOnlyFeature(transformer.getInputCol());
        ParameterField documentField = new ParameterField(FieldName.create((String)"document"));
        ParameterField termField = new ParameterField(FieldName.create((String)"term"));
        TextIndex textIndex = new TextIndex(documentField.getName(), (Expression)new FieldRef(termField.getName())).setTokenize(Boolean.TRUE).setWordSeparatorCharacterRE(documentFeature.getWordSeparatorRE()).setLocalTermWeights((TextIndex.LocalTermWeights)(transformer.getBinary() ? TextIndex.LocalTermWeights.BINARY : null));
        Set<DocumentFeature.StopWordSet> stopWordSets = documentFeature.getStopWordSets();
        for (DocumentFeature.StopWordSet stopWordSet : stopWordSets) {
            String tokenRE;
            String wordSeparatorRE;
            if (stopWordSet.isEmpty()) continue;
            switch (wordSeparatorRE = documentFeature.getWordSeparatorRE()) {
                case "\\s+": {
                    tokenRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join((Iterable)stopWordSet) + ")\\p{Punct}*(\\s+|$)";
                    break;
                }
                case "\\W+": {
                    tokenRE = "(\\W+)(" + JOINER.join((Iterable)stopWordSet) + ")(\\W+)";
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + wordSeparatorRE + "\"");
                }
            }
            LinkedHashMap<String, List<String>> data = new LinkedHashMap<String, List<String>>();
            data.put("string", Collections.singletonList(tokenRE));
            data.put("stem", Collections.singletonList(" "));
            data.put("regex", Collections.singletonList("true"));
            TextIndexNormalization textIndexNormalization = new TextIndexNormalization(null, PMMLUtil.createInlineTable(data)).setCaseSensitive(Boolean.valueOf(stopWordSet.isCaseSensitive())).setRecursive(Boolean.TRUE);
            textIndex.addTextIndexNormalizations(new TextIndexNormalization[]{textIndexNormalization});
        }
        DefineFunction defineFunction = new DefineFunction("tf@" + String.valueOf(SEQUENCE.getAndIncrement()), OpType.CONTINUOUS, DataType.INTEGER, null, (Expression)textIndex).addParameterFields(new ParameterField[]{documentField, termField});
        encoder.addDefineFunction(defineFunction);
        ArrayList<Feature> result = new ArrayList<Feature>();
        String[] vocabulary = transformer.vocabulary();
        for (int i = 0; i < vocabulary.length; ++i) {
            String term = vocabulary[i];
            if (TermUtil.hasPunctuation(term)) {
                throw new IllegalArgumentException("Punctuated vocabulary terms (" + term + ") are not supported");
            }
            result.add(new TermFeature((PMMLEncoder)encoder, defineFunction, (Feature)documentFeature, term));
        }
        return result;
    }
}

