/*
 * Decompiled with CFR 0.152.
 */
package sklearn.feature_extraction.text;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.io.CharStreams;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import numpy.core.ScalarUtil;
import org.dmg.pmml.Apply;
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
import org.dmg.pmml.ParameterField;
import org.dmg.pmml.TextIndex;
import org.dmg.pmml.TextIndexNormalization;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ObjectFeature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.StringFeature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.TypeInfo;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.HasSparseOutput;
import sklearn.SkLearnTransformer;
import sklearn.feature_extraction.text.Tokenizer;
import sklearn2pmml.feature_extraction.text.Matcher;

/**
 * Converter for a fitted count vectorizer object.
 *
 * <p>Takes exactly one input feature (the document text) and produces one output feature per
 * entry of the fitted {@code vocabulary_} attribute. Every output feature is backed by an
 * {@link Apply} of a shared {@link DefineFunction} whose body is a PMML {@link TextIndex}.
 */
public class CountVectorizer extends SkLearnTransformer implements HasSparseOutput {

    /**
     * A word-matching regular expression.
     * NOTE(review): presumably mirrors the default {@code token_pattern} of the Python class;
     * it is not referenced anywhere inside this class - confirm against callers.
     */
    public static final String TOKEN_PATTERN = "(?u)\\b\\w\\w+\\b";

    public CountVectorizer(String module, String name) {
        super(module, name);
    }

    /** This transformer consumes a single input column. */
    @Override
    public int getNumberOfFeatures() {
        return 1;
    }

    /** The input column is declared as categorical. */
    @Override
    public OpType getOpType() {
        return OpType.CATEGORICAL;
    }

    /** The input column is declared as a string. */
    @Override
    public DataType getDataType() {
        return DataType.STRING;
    }

    /**
     * Expands the single document feature into one term feature per vocabulary entry.
     *
     * @param features the input features; must contain exactly one element
     * @param encoder the encoder that receives the derived field and the define function
     * @return the term features, ordered by their vocabulary index (0, 1, 2, ...)
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        Boolean lowercase = this.getLowercase();
        Map<String, ?> vocabulary = this.getVocabulary();

        ClassDictUtil.checkSize(1, features);

        Feature feature = features.get(0);

        // term -> column index, as stored in the fitted vocabulary
        BiMap<String, Integer> termIndexMap = HashBiMap.create(vocabulary.size());
        for (Map.Entry<String, ?> entry : vocabulary.entrySet()) {
            Number index = (Number)ScalarUtil.decode(entry.getValue());

            termIndexMap.put(entry.getKey(), ValueUtil.asInteger(index));
        }

        // column index -> term, so that the output follows the column order
        BiMap<Integer, String> indexTermMap = termIndexMap.inverse();

        // Fall back to double when no dtype information is available
        TypeInfo typeInfo = this.getDType();
        DataType dataType = (typeInfo != null) ? typeInfo.getDataType() : DataType.DOUBLE;

        if (lowercase) {
            // Lowercasing is materialized as a separate derived field; all term features refer to it
            final Apply lowercaseApply = ExpressionUtil.createApply("lowercase", feature.ref());

            DerivedField derivedField = encoder.ensureDerivedField(FieldNameUtil.create("lowercase", feature), OpType.CATEGORICAL, DataType.STRING, () -> lowercaseApply);

            feature = new StringFeature(encoder, derivedField);
        }

        DefineFunction defineFunction = this.encodeDefineFunction(feature, encoder);
        encoder.addDefineFunction(defineFunction);

        List<Feature> result = new ArrayList<>();

        for (int index = 0, size = indexTermMap.size(); index < size; index++) {
            String term = indexTermMap.get(index);

            final Apply termApply = this.encodeApply(defineFunction, feature, index, term);

            // The continuous form of the term feature is only created when it is asked for
            Feature termFeature = new ObjectFeature(encoder, FieldNameUtil.create(this.functionName(), feature, term), dataType) {

                public ContinuousFeature toContinuousFeature() {
                    return this.toContinuousFeature(this.getName(), this.getDataType(), () -> termApply);
                }
            };

            result.add(termFeature);
        }

        return result;
    }

    /**
     * Builds the two-parameter ({@code document}, {@code term}) function that wraps the
     * {@link TextIndex} expression.
     *
     * @param feature the (possibly lowercased) document feature; used for naming the function
     * @param encoder the encoder; not used by this implementation
     * @return a function of operational type continuous and data type integer
     * @throws IllegalArgumentException if the analyzer is not {@code "word"}, or if a
     *         preprocessor or an accent stripping mode has been set
     */
    public DefineFunction encodeDefineFunction(Feature feature, SkLearnEncoder encoder) {
        String analyzer = this.getAnalyzer();
        List<String> stopWords = this.getStopWords();
        Object[] nGramRange = this.getNGramRange();
        Boolean binary = this.getBinary();
        Object preprocessor = this.getPreprocessor();
        String stripAccents = this.getStripAccents();
        Tokenizer tokenizer = this.getTokenizer();

        switch (analyzer) {
            case "word":
                break;
            default:
                throw new IllegalArgumentException(analyzer);
        }

        if (preprocessor != null) {
            throw new IllegalArgumentException("Attribute 'preprocessor' is not supported");
        }

        if (stripAccents != null) {
            throw new IllegalArgumentException(stripAccents);
        }

        // Without an explicit tokenizer, tokens are matched using the token pattern
        if (tokenizer == null) {
            String tokenPattern = this.getTokenPattern();

            tokenizer = new Matcher()
                .setWordRE(tokenPattern);
        }

        ParameterField documentField = new ParameterField("document");
        ParameterField termField = new ParameterField("term");

        TextIndex textIndex = new TextIndex(documentField, new FieldRef(termField))
            .setLocalTermWeights(binary ? TextIndex.LocalTermWeights.BINARY : null);

        textIndex = tokenizer.configure(textIndex);

        // Stop words are only encoded when the n-gram range differs from (1, 1).
        // NOTE(review): presumably because unigram vocabularies never contain stop words
        // in the first place - confirm.
        boolean hasStopWords = (stopWords != null) && !stopWords.isEmpty();
        if (hasStopWords && !Arrays.equals(nGramRange, new Integer[]{1, 1})) {
            String stopWordsRE = tokenizer.formatStopWordsRE(stopWords);

            if (stopWordsRE != null) {
                // A single-row normalization table that rewrites every regex match to a space
                Map<String, List<String>> data = new LinkedHashMap<>();
                data.put("string", Collections.singletonList(stopWordsRE));
                data.put("stem", Collections.singletonList(" "));
                data.put("regex", Collections.singletonList("true"));

                TextIndexNormalization textIndexNormalization = new TextIndexNormalization(PMMLUtil.createInlineTable(data))
                    .setRecursive(Boolean.TRUE);

                textIndex.addTextIndexNormalizations(textIndexNormalization);
            }
        }

        String name = this.createFieldName(this.functionName(), feature);

        return new DefineFunction(name, OpType.CONTINUOUS, DataType.INTEGER, null, textIndex)
            .addParameterFields(documentField, termField);
    }

    /**
     * Creates the invocation of the define function for one term.
     *
     * @param defineFunction the function created by {@link #encodeDefineFunction(Feature, SkLearnEncoder)}
     * @param feature the document feature
     * @param index the vocabulary index of the term; not used by this implementation
     * @param term the term
     * @return an apply element with the document reference and the term constant as arguments
     */
    public Apply encodeApply(DefineFunction defineFunction, Feature feature, int index, String term) {
        Constant constant = ExpressionUtil.createConstant(DataType.STRING, term);

        return ExpressionUtil.createApply(defineFunction, feature.ref(), constant);
    }

    /** The name prefix that is used for the define function and for the term features. */
    public String functionName() {
        return "tf";
    }

    /** Returns the {@code analyzer} attribute. */
    public String getAnalyzer() {
        return this.getString("analyzer");
    }

    /** Returns the {@code binary} attribute. */
    public Boolean getBinary() {
        return this.getBoolean("binary");
    }

    /** Returns the {@code dtype} attribute; {@link #encodeFeatures} treats {@code null} as double. */
    public TypeInfo getDType() {
        return this.getDType("dtype", false);
    }

    /** Returns the {@code lowercase} attribute. */
    public Boolean getLowercase() {
        return this.getBoolean("lowercase");
    }

    /** Returns the {@code ngram_range} attribute as a tuple. */
    public Object[] getNGramRange() {
        return this.getTuple("ngram_range");
    }

    /** Returns the optional {@code preprocessor} attribute. */
    public Object getPreprocessor() {
        return this.getOptionalObject("preprocessor");
    }

    /** Always reports sparse output. */
    @Override
    public Boolean getSparseOutput() {
        return Boolean.TRUE;
    }

    /**
     * Returns the optional {@code stop_words} attribute.
     *
     * <p>A string value is treated as the name of a bundled stop word list, and is resolved
     * to the contents of that list. Any other value is returned as a list, as is.
     *
     * @return the stop words, or {@code null} if the attribute is not set
     */
    public List<String> getStopWords() {
        Object stopWords = this.getOptionalObject("stop_words");

        if (stopWords instanceof String) {
            return CountVectorizer.loadStopWords((String)stopWords);
        }

        // Any non-string value is trusted to be a list of strings; element types are not checked here
        @SuppressWarnings("unchecked")
        List<String> result = (List<String>)stopWords;

        return result;
    }

    /** Returns the optional {@code strip_accents} attribute. */
    public String getStripAccents() {
        return this.getOptionalString("strip_accents");
    }

    /** Returns the optional {@code tokenizer} attribute. */
    public Tokenizer getTokenizer() {
        return this.getOptional("tokenizer", Tokenizer.class);
    }

    /** Returns the {@code token_pattern} attribute. */
    public String getTokenPattern() {
        return this.getString("token_pattern");
    }

    /** Returns the fitted {@code vocabulary_} attribute; values are decoded to integers by {@link #encodeFeatures}. */
    public Map<String, ?> getVocabulary() {
        return this.getDict("vocabulary_");
    }

    /**
     * Reads the classpath resource {@code /stop_words/<stopWords>.txt} as UTF-8 text, one
     * list element per line.
     *
     * @param stopWords the name of the stop word list
     * @return the lines of the resource
     * @throws IllegalArgumentException if the resource does not exist or cannot be read
     */
    private static List<String> loadStopWords(String stopWords) {
        InputStream is = CountVectorizer.class.getResourceAsStream("/stop_words/" + stopWords + ".txt");

        if (is == null) {
            throw new IllegalArgumentException(stopWords);
        }

        // Closing the reader also closes the underlying stream
        try (InputStreamReader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
            return CharStreams.readLines(reader);
        } catch (IOException ioe) {
            throw new IllegalArgumentException(stopWords, ioe);
        }
    }
}

