/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.ensemble;

import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.Label;
import org.jpmml.converter.MultiLabel;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.DataFrameScope;
import org.jpmml.python.Scope;
import org.jpmml.python.TupleUtil;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasClasses;
import sklearn.HasEstimatorEnsemble;
import sklearn2pmml.ensemble.Link;
import sklearn2pmml.util.EvaluatableUtil;

/**
 * Ensemble of estimators that is encoded as a PMML model chain.
 *
 * <p>The {@code steps} attribute holds tuples of (name, estimator, predicate expression).
 * Each step becomes one {@link Segment} of a {@link Segmentation}. The segmentation uses
 * {@code MULTI_MODEL_CHAIN} when the {@code multioutput} attribute is true, and
 * {@code MODEL_CHAIN} otherwise. An estimator that implements {@link Link} may augment
 * the schema that subsequent steps are encoded against.
 */
public class EstimatorChain extends Estimator implements HasClasses, HasEstimatorEnsemble<Estimator> {

    public EstimatorChain(String module, String name) {
        super(module, name);
    }

    /**
     * Derives the mining function from the chained estimators.
     */
    @Override
    public MiningFunction getMiningFunction() {
        List<? extends Estimator> estimators = this.getEstimators();
        return EstimatorUtil.getMiningFunction(estimators);
    }

    /**
     * Returns the number of steps when {@code multioutput} is true, otherwise 1.
     */
    @Override
    public int getNumberOfOutputs() {
        Boolean multioutput = this.getMultioutput();
        if (multioutput.booleanValue()) {
            List<? extends Estimator> estimators = this.getEstimators();
            return estimators.size();
        }
        return 1;
    }

    @Override
    public boolean isSupervised() {
        return true;
    }

    /**
     * Returns the class labels of the chain.
     *
     * <p>With a single estimator, its classes are returned. With several estimators, the
     * shared class list is returned if all estimators agree. Otherwise a list holding one
     * class list per estimator (in step order) is returned.
     *
     * @throws IllegalArgumentException if the chain has no estimators
     */
    @Override
    public List<?> getClasses() {
        List<? extends Estimator> estimators = this.getEstimators();
        if (estimators.size() == 1) {
            return EstimatorUtil.getClasses(estimators.get(0));
        }
        if (estimators.size() >= 2) {
            List<List<?>> result = new ArrayList<>(estimators.size());
            for (Estimator estimator : estimators) {
                result.add(EstimatorUtil.getClasses(estimator));
            }
            List<List<?>> uniqueResults = result.stream()
                .distinct()
                .collect(Collectors.toList());
            if (uniqueResults.size() == 1) {
                return uniqueResults.get(0);
            }
            return result;
        }
        throw new IllegalArgumentException("Expected one or more estimators, got none");
    }

    /**
     * Encodes the chain as a {@link MiningModel} with one segment per step.
     *
     * <p>Step {@code i} is encoded against the current schema, relabeled to the
     * {@code i}-th target. A {@link ScalarLabel} is shared by every step; a
     * {@link MultiLabel} must provide exactly one label per step. Step predicates are
     * translated against the chain's original input features. The scope is built once,
     * before any {@link Link} augments the schema.
     *
     * @param schema the input schema
     * @return the mining model wrapping the chained segments
     * @throws IllegalArgumentException if there are no steps, the label type is
     *     unsupported, or the number of labels does not match the number of steps
     */
    @Override
    public Model encodeModel(Schema schema) {
        Boolean multioutput = this.getMultioutput();
        List<Object[]> steps = this.getSteps();
        if (steps.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more steps, got none");
        }

        Label label = schema.getLabel();
        MultiLabel multiLabel;
        if (label instanceof ScalarLabel) {
            // Every step predicts the same single target.
            multiLabel = new MultiLabel(Collections.nCopies(steps.size(), (ScalarLabel)label));
        } else if (label instanceof MultiLabel) {
            multiLabel = (MultiLabel)label;
            int numberOfLabels = multiLabel.getLabels().size();
            if (numberOfLabels != steps.size()) {
                throw new IllegalArgumentException("Expected " + steps.size() + " labels (one per step), got " + numberOfLabels);
            }
        } else {
            throw new IllegalArgumentException("Expected a scalar or multi-label, got " + label);
        }

        List<Estimator> estimators = new ArrayList<>(steps.size());
        List<Model> models = new ArrayList<>(steps.size());

        Segmentation.MultipleModelMethod multipleModelMethod = multioutput
            ? Segmentation.MultipleModelMethod.MULTI_MODEL_CHAIN
            : Segmentation.MultipleModelMethod.MODEL_CHAIN;
        Segmentation segmentation = new Segmentation(multipleModelMethod, null);

        // Built from the original schema, before the loop may replace it with an augmented one.
        Scope scope = new DataFrameScope("X", schema.getFeatures());

        for (int i = 0; i < steps.size(); i++) {
            Object[] step = steps.get(i);
            String name = TupleUtil.extractElement(step, 0, String.class);
            Estimator estimator = TupleUtil.extractElement(step, 1, Estimator.class);
            Object expr = TupleUtil.extractElement(step, 2, Object.class);

            estimators.add(estimator);

            Schema segmentSchema = schema.toRelabeledSchema(multiLabel.getLabel(i));
            Predicate predicate = EvaluatableUtil.translatePredicate(expr, scope);

            Model model = estimator.encode(segmentSchema);
            models.add(model);

            // A Link step exposes its output to the steps that follow it.
            if (estimator instanceof Link) {
                Link link = (Link)estimator;
                schema = link.augmentSchema(model, segmentSchema);
            }

            Segment segment = new Segment(predicate, model)
                .setId(name);
            segmentation.addSegments(segment);
        }

        MiningFunction miningFunction = EstimatorUtil.getMiningFunction(estimators);
        return new MiningModel(miningFunction, MiningModelUtil.createMiningSchema(models))
            .setSegmentation(segmentation);
    }

    /**
     * Returns the estimator (second element) of every step, in step order.
     *
     * @throws IllegalArgumentException if there are no steps
     */
    @Override
    public List<? extends Estimator> getEstimators() {
        List<Object[]> steps = this.getSteps();
        if (steps.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more steps, got none");
        }
        return TupleUtil.extractElementList(steps, 1, Estimator.class);
    }

    /**
     * Returns the {@code multioutput} attribute.
     */
    public Boolean getMultioutput() {
        return this.getBoolean("multioutput");
    }

    /**
     * Returns the {@code steps} attribute as a list of tuples.
     */
    public List<Object[]> getSteps() {
        return this.getTupleList("steps");
    }
}

