/*
 * Decompiled with CFR 0.152.
 */
package sklearn.preprocessing;

import com.google.common.collect.ContiguousSet;
import com.google.common.collect.DiscreteDomain;
import com.google.common.collect.Range;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.OpType;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Transformer;
import sklearn.TypeUtil;

/**
 * Converter-side representation of a one-hot encoding transformer.
 *
 * <p>Takes exactly one input feature and expands it into one
 * {@link BinaryFeature} per category value. The category values are obtained
 * from the underlying object's {@code n_values}, {@code n_values_} and
 * {@code active_features_} attributes (see {@link #getValues()}).
 */
public class OneHotEncoder extends Transformer {

    /**
     * Creates the transformer wrapper.
     *
     * @param module passed through to the {@link Transformer} constructor
     * @param name passed through to the {@link Transformer} constructor
     */
    public OneHotEncoder(String module, String name) {
        super(module, name);
    }

    /**
     * @return always {@link OpType#CATEGORICAL}
     */
    @Override
    public OpType getOpType() {
        return OpType.CATEGORICAL;
    }

    /**
     * Determines the data type from the category values.
     *
     * @return the result of {@code TypeUtil.getDataType(values, DataType.INTEGER)};
     *         {@code DataType.INTEGER} is presumably the fallback type (confirm against
     *         {@code TypeUtil})
     */
    @Override
    public DataType getDataType() {
        List<? extends Number> values = this.getValues();
        return TypeUtil.getDataType(values, DataType.INTEGER);
    }

    /**
     * Expands the single input feature into one binary indicator feature per category.
     *
     * <ul>
     *   <li>For a {@link CategoricalFeature}, the feature's own category values are
     *       used; their count is checked against {@link #getValues()}.</li>
     *   <li>For a {@link WildcardFeature}, each value from {@link #getValues()} is
     *       converted to an int and formatted as the category string; afterwards the
     *       wildcard feature is converted via {@code toCategoricalFeature(categories)}.</li>
     * </ul>
     *
     * @param features the input features; checked with {@code ClassDictUtil.checkSize(1, features)}
     * @param encoder the encoder handed to every created {@link BinaryFeature}
     * @return the list of binary features, in category order
     * @throws IllegalArgumentException if the input feature is neither a
     *         {@link CategoricalFeature} nor a {@link WildcardFeature}
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        List<? extends Number> values = this.getValues();
        ClassDictUtil.checkSize(1, features);
        Feature feature = features.get(0);
        List<Feature> result = new ArrayList<Feature>();
        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
            ClassDictUtil.checkSize(values, categoricalFeature.getValues());
            for (int i = 0; i < values.size(); ++i) {
                result.add(new BinaryFeature(encoder, categoricalFeature, categoricalFeature.getValue(i)));
            }
        } else if (feature instanceof WildcardFeature) {
            WildcardFeature wildcardFeature = (WildcardFeature) feature;
            List<String> categories = new ArrayList<String>();
            for (int i = 0; i < values.size(); ++i) {
                // Normalize to an int first so that the category label is formatted
                // from an integral value regardless of the stored number type.
                Number value = ValueUtil.asInt(values.get(i));
                String category = ValueUtil.formatValue(value);
                categories.add(category);
                result.add(new BinaryFeature(encoder, wildcardFeature, category));
            }
            // Done after the loop, once the complete category list is known.
            wildcardFeature.toCategoricalFeature(categories);
        } else {
            String actualType = (feature != null) ? feature.getClass().getName() : "null";
            throw new IllegalArgumentException(
                "Expected a CategoricalFeature or a WildcardFeature, got " + actualType);
        }
        return result;
    }

    /**
     * Returns the category values of the single encoded feature.
     *
     * <p>The {@code n_values_} array is checked with
     * {@code ClassDictUtil.checkSize(1, ...)}. If the {@code n_values} attribute
     * equals {@code "auto"}, the {@code active_features_} array is returned as-is.
     * Otherwise the result is the integers from {@code 0} (inclusive) up to the
     * first entry of {@code n_values_} (exclusive), in ascending order.
     *
     * @return the category values
     */
    public List<? extends Number> getValues() {
        List<Integer> featureSizes = this.getFeatureSizes();
        ClassDictUtil.checkSize(1, featureSizes);
        Object numberOfValues = this.get("n_values");
        if ("auto".equals(numberOfValues)) {
            return this.getActiveFeatures();
        }
        Integer featureSize = featureSizes.get(0);
        // Materialize [0, featureSize) into a mutable, typed list.
        ContiguousSet<Integer> range =
            ContiguousSet.create(Range.closedOpen(Integer.valueOf(0), featureSize), DiscreteDomain.integers());
        return new ArrayList<Integer>(range);
    }

    /**
     * @return the contents of the {@code active_features_} array attribute
     */
    public List<? extends Number> getActiveFeatures() {
        return this.getArray("active_features_", Number.class);
    }

    /**
     * @return the contents of the {@code n_values_} array attribute, converted
     *         with {@code ValueUtil.asIntegers}
     */
    public List<Integer> getFeatureSizes() {
        return ValueUtil.asIntegers(this.getArray("n_values_", Number.class));
    }
}

