/*
 * Decompiled with CFR 0.152.
 */
package sklearn.compose;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.python.CastFunction;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.HasArray;
import org.jpmml.python.TupleUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Drop;
import sklearn.Initializer;
import sklearn.PassThrough;
import sklearn.Transformer;

public class ColumnTransformer
extends Initializer {
    public ColumnTransformer(String module, String name) {
        super(module, name);
    }

    @Override
    public List<Feature> initializeFeatures(SkLearnEncoder encoder) {
        return this.encodeFeatures(Collections.emptyList(), encoder);
    }

    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        List<Object[]> fittedTransformers = this.getFittedTransformers();
        ArrayList<Feature> result = new ArrayList<Feature>();
        for (Object[] fittedTransformer : fittedTransformers) {
            Transformer transformer = ColumnTransformer.getTransformer(fittedTransformer);
            List<Feature> rowFeatures = ColumnTransformer.getFeatures(fittedTransformer, features, encoder);
            rowFeatures = transformer.encode(rowFeatures, encoder);
            result.addAll(rowFeatures);
        }
        return result;
    }

    public List<Object[]> getFittedTransformers() {
        return this.getTupleList("transformers_");
    }

    private static Transformer getTransformer(Object[] fittedTransformer) {
        Object transformer = TupleUtil.extractElement((Object[])fittedTransformer, (int)1);
        CastFunction<Transformer> castFunction = new CastFunction<Transformer>(Transformer.class){

            public Transformer apply(Object object) {
                if ("drop".equals(object)) {
                    return Drop.INSTANCE;
                }
                if ("passthrough".equals(object)) {
                    return PassThrough.INSTANCE;
                }
                return (Transformer)super.apply(object);
            }

            protected String formatMessage(Object object) {
                return "The estimator object (" + ClassDictUtil.formatClass((Object)object) + ") is not a supported Transformer";
            }
        };
        return (Transformer)castFunction.apply(transformer);
    }

    private static List<Feature> getFeatures(Object[] fittedTransformer, final List<Feature> features, final SkLearnEncoder encoder) {
        List columns = TupleUtil.extractElement((Object[])fittedTransformer, (int)2);
        if (columns instanceof String || columns instanceof Integer) {
            columns = Collections.singletonList(columns);
        } else if (columns instanceof HasArray) {
            HasArray hasArray = (HasArray)columns;
            columns = hasArray.getArrayContent();
        }
        Function<Object, Feature> castFunction = new Function<Object, Feature>(){

            public Feature apply(Object object) {
                if (object instanceof String) {
                    String column = (String)object;
                    if (features.size() > 0) {
                        for (Feature feature : features) {
                            FieldName name = feature.getName();
                            if (!column.equals(name.getValue())) continue;
                            return feature;
                        }
                        throw new IllegalArgumentException("Column '" + column + "' is undefined");
                    }
                    return this.createWildcardFeature(FieldName.create((String)column));
                }
                if (object instanceof Integer) {
                    Integer index = (Integer)object;
                    if (features.size() > 0) {
                        Feature feature = (Feature)features.get(index);
                        return feature;
                    }
                    return this.createWildcardFeature(FieldName.create((String)("x" + (index + 1))));
                }
                throw new IllegalArgumentException("The column object (" + ClassDictUtil.formatClass((Object)object) + ") is not a string or integer");
            }

            private Feature createWildcardFeature(FieldName name) {
                DataField dataField = encoder.getDataField(name);
                if (dataField == null) {
                    dataField = encoder.createDataField(name);
                }
                return new WildcardFeature((PMMLEncoder)encoder, dataField);
            }
        };
        return Lists.transform((List)columns, (Function)castFunction);
    }
}

