/*
 * Decompiled with CFR 0.152.
 */
package sklearn.pipeline;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Objects;
import net.razorvine.pickle.objects.ClassDict;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.jpmml.converter.Schema;
import org.jpmml.python.CastFunction;
import org.jpmml.python.CastUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.TupleUtil;
import org.jpmml.sklearn.Encodable;
import org.jpmml.sklearn.SkLearnEncoder;
import org.jpmml.sklearn.SkLearnException;
import sklearn.Composite;
import sklearn.Estimator;
import sklearn.HasSteps;
import sklearn.PassThrough;
import sklearn.Step;
import sklearn.StepUtil;
import sklearn.Transformer;

/**
 * Converter-side representation of a scikit-learn {@code Pipeline}.
 *
 * <p>The pickled {@code steps} attribute is read as a list of tuples whose element at index 1 is
 * the step object. A step value of {@code null} or the string {@code "passthrough"} is treated as
 * a pass-through step. When the last step resolves to an estimator, it is the "final estimator"
 * and all preceding steps are the transformers. Otherwise, every step is a transformer.
 */
public class SkLearnPipeline extends Composite implements Encodable, HasSteps {

    /** Step value that scikit-learn uses to mark a step which forwards its input unchanged. */
    private static final String PASSTHROUGH = "passthrough";

    /** Index of the step object inside each {@code steps} tuple. */
    private static final int STEP_INDEX = 1;

    public SkLearnPipeline() {
        this("sklearn.pipeline", "Pipeline");
    }

    public SkLearnPipeline(String module, String name) {
        super(module, name);
    }

    /**
     * Tells whether this pipeline has at least one step that is not the final estimator.
     *
     * @return {@code false} for an empty pipeline, or for a single-step pipeline whose only step
     *     is a final estimator; {@code true} otherwise
     */
    @Override
    public boolean hasTransformers() {
        List<Object[]> steps = this.getSteps();
        switch (steps.size()) {
            case 0:
                return false;
            case 1:
                // A lone step is a transformer unless it is the final estimator.
                return !this.hasFinalEstimator();
            default:
                return true;
        }
    }

    /**
     * Tells whether the last step of this pipeline should be treated as an estimator.
     *
     * <p>The checks run in this order:
     * <ol>
     *   <li>pass-through ({@code null}/"passthrough") steps are never estimators;</li>
     *   <li>nested composites delegate to their own {@code hasFinalEstimator()};</li>
     *   <li>mapped {@link Estimator}/{@link Transformer} instances are classified by type;</li>
     *   <li>unmapped {@link ClassDict} objects are classified by class-name/attribute heuristics;</li>
     *   <li>anything else is resolved with {@link CastUtil#deepCastTo}.</li>
     * </ol>
     *
     * @return {@code true} if the last step is, or resolves to, an estimator
     */
    @Override
    public boolean hasFinalEstimator() {
        List<Object[]> steps = this.getSteps();
        if (steps.isEmpty()) {
            return false;
        }
        Object step = extractStep(steps.get(steps.size() - 1));
        if (isPassThrough(step)) {
            return false;
        }
        if (step instanceof Composite) {
            return ((Composite)step).hasFinalEstimator();
        }
        if (step instanceof Estimator) {
            return true;
        }
        if (step instanceof Transformer) {
            return false;
        }
        if (step instanceof ClassDict) {
            ClassDict dict = (ClassDict)step;
            if (isEstimatorLike(dict)) {
                return true;
            }
            if (isTransformerLike(dict)) {
                return false;
            }
        }
        // Fall back to the library's conversion logic for objects not classified above.
        Object resolved = CastUtil.deepCastTo(step, Estimator.class);
        return resolved instanceof Estimator;
    }

    /**
     * Returns the transformer steps of this pipeline, excluding the final estimator if present.
     *
     * <p>The result is a {@link Lists#transform} view. Each element is cast when it is accessed,
     * so an unsupported step surfaces as an exception at access time, not here. Pass-through
     * steps are mapped to {@link PassThrough#INSTANCE}.
     *
     * @return a lazily-cast list of transformers
     */
    @Override
    public List<? extends Transformer> getTransformers() {
        List<Object[]> steps = this.getSteps();
        if (this.hasFinalEstimator()) {
            steps = steps.subList(0, steps.size() - 1);
        }
        List<?> transformers = TupleUtil.extractElementList(steps, STEP_INDEX);
        CastFunction<Transformer> castFunction = new CastFunction<Transformer>(Transformer.class){

            @Override
            public Transformer apply(Object object) {
                if (SkLearnPipeline.isPassThrough(object)) {
                    return PassThrough.INSTANCE;
                }
                return super.apply(object);
            }

            @Override
            public String formatMessage(Object object) {
                return "The object (" + ClassDictUtil.formatClass(object) + ") is not a supported Transformer";
            }
        };
        return Lists.transform(transformers, castFunction);
    }

    /**
     * Returns the last step cast to {@link Estimator}.
     *
     * @throws SkLearnException if the pipeline is empty, ends with a pass-through step, or the last
     *     step cannot be cast to an estimator
     */
    @Override
    public Estimator getFinalEstimator() {
        return this.getFinalEstimator(Estimator.class);
    }

    /**
     * Returns the last step cast to the requested estimator type.
     *
     * @param clazz the expected estimator type
     * @return the last step as an instance of {@code clazz}
     * @throws SkLearnException if the pipeline is empty or ends with a pass-through step; the cast
     *     failure for an unsupported step is reported through {@link CastFunction}
     */
    @Override
    public <E extends Estimator> E getFinalEstimator(Class<? extends E> clazz) {
        List<Object[]> steps = this.getSteps();
        if (steps.isEmpty()) {
            throw new SkLearnException("Expected one or more steps, got zero steps");
        }
        Object step = extractStep(steps.get(steps.size() - 1));
        if (isPassThrough(step)) {
            throw new SkLearnException("The pipeline ends with a transformer-like object");
        }
        CastFunction<E> castFunction = new CastFunction<E>(clazz){

            @Override
            public String formatMessage(Object object) {
                return "The object (" + ClassDictUtil.formatClass(object) + ") is not a supported Estimator";
            }
        };
        return castFunction.apply(step);
    }

    /**
     * Returns the head of the first step, as computed by {@link StepUtil#getHead}.
     *
     * <p>A pass-through first step is passed to {@code StepUtil.getHead} as {@code null}.
     * NOTE(review): confirm that {@code StepUtil.getHead} tolerates {@code null}.
     *
     * @throws SkLearnException if the pipeline has no steps
     */
    @Override
    public Step getHead() {
        List<Object[]> steps = this.getSteps();
        if (steps.isEmpty()) {
            throw new SkLearnException("Expected one or more steps, got zero steps");
        }
        Object step = extractStep(steps.get(0));
        CastFunction<Step> castFunction = new CastFunction<Step>(Step.class){

            @Override
            public Step apply(Object object) {
                if (SkLearnPipeline.isPassThrough(object)) {
                    return null;
                }
                return super.apply(object);
            }

            @Override
            public String formatMessage(Object object) {
                return "The object (" + ClassDictUtil.formatClass(object) + ") is not a supported Transformer or Estimator";
            }
        };
        return StepUtil.getHead(castFunction.apply(step));
    }

    /**
     * Encodes this pipeline into a standalone PMML document.
     *
     * <p>The label is initialised only when the pipeline has a final estimator. Features are
     * always initialised. Without a final estimator, the encoder is asked to encode a {@code null}
     * model.
     *
     * @return the encoded PMML
     */
    @Override
    public PMML encodePMML() {
        SkLearnEncoder encoder = new SkLearnEncoder();

        Estimator estimator = null;
        if (this.hasFinalEstimator()) {
            estimator = this.getFinalEstimator();
            this.initLabel(null, encoder);
        }
        this.initFeatures(null, encoder);

        if (estimator == null) {
            return encoder.encodePMML(null);
        }

        Schema schema = encoder.createSchema();
        Model model = estimator.encode(schema);
        encoder.setModel(model);
        return encoder.encodePMML(model);
    }

    /** Returns the pickled {@code steps} attribute as a list of tuples. */
    @Override
    public List<Object[]> getSteps() {
        return this.getTupleList("steps");
    }

    /** Replaces the {@code steps} attribute and returns this pipeline for chaining. */
    protected SkLearnPipeline setSteps(List<Object[]> steps) {
        this.setattr("steps", steps);
        return this;
    }

    /** Returns the step object stored at {@link #STEP_INDEX} of a {@code steps} tuple. */
    private static Object extractStep(Object[] tuple) {
        return TupleUtil.extractElement(tuple, STEP_INDEX);
    }

    /** Tells whether a step value is a pass-through marker ({@code null} or "passthrough"). */
    private static boolean isPassThrough(Object step) {
        return step == null || PASSTHROUGH.equals(step);
    }

    /**
     * Heuristically tells whether an unmapped pickled object looks like an estimator.
     *
     * <p>The check uses its class-name suffix ("Estimator", "Classifier", "Regressor") or the
     * presence of the {@code n_outputs_}, {@code n_classes_} or {@code classes_} attributes.
     * NOTE(review): this is a name/attribute guess, not a type check. Some transformers also
     * carry {@code classes_}.
     */
    private static boolean isEstimatorLike(ClassDict dict) {
        String name = dict.getClassName();
        if (name.endsWith("Estimator") || name.endsWith("Classifier") || name.endsWith("Regressor")) {
            return true;
        }
        return dict.containsKey("n_outputs_") || dict.containsKey("n_classes_") || dict.containsKey("classes_");
    }

    /** Heuristically tells whether an unmapped pickled object looks like a transformer (by class-name suffix). */
    private static boolean isTransformerLike(ClassDict dict) {
        return dict.getClassName().endsWith("Transformer");
    }
}

