/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.ensemble;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.MultiLabel;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.ScalarLabelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.DataFrameScope;
import org.jpmml.python.Scope;
import org.jpmml.python.TupleUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasClasses;
import sklearn.Transformer;
import sklearn2pmml.ensemble.HasEstimatorSteps;
import sklearn2pmml.ensemble.Link;
import sklearn2pmml.util.EvaluatableUtil;

/**
 * An estimator that chains its steps and encodes them as a single PMML {@link MiningModel}.
 *
 * <p>Each element of {@code steps} is a {@code (name, estimator, expr)} tuple. The {@code expr}
 * element is translated into the PMML predicate of the corresponding segment, evaluated against
 * the input features bound as {@code X} (after the optional {@code controller} transformation).
 *
 * <p>When {@code multioutput} is {@code true}, step {@code i} is encoded against the {@code i}-th
 * scalar label and the segmentation uses {@code MULTI_MODEL_CHAIN}. Otherwise every step is
 * encoded against the first scalar label and the segmentation uses {@code MODEL_CHAIN}.
 */
public class EstimatorChain extends Estimator implements HasClasses, HasEstimatorSteps {

    public EstimatorChain(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the mining function derived from all chained estimators.
     */
    @Override
    public MiningFunction getMiningFunction() {
        List<? extends Estimator> estimators = getEstimators();
        return EstimatorUtil.getMiningFunction(estimators);
    }

    @Override
    public boolean isSupervised() {
        return true;
    }

    /**
     * Returns the number of chained estimators in multi-output mode, otherwise {@code 1}.
     */
    @Override
    public int getNumberOfOutputs() {
        if (getMultioutput()) {
            return getEstimators().size();
        }
        return 1;
    }

    /**
     * Returns the class labels of the chain.
     *
     * <p>With a single estimator, its classes are returned directly. With several estimators,
     * their class lists are collected; if they are all equal, that one list is returned,
     * otherwise the list of per-estimator class lists is returned.
     *
     * @throws IllegalArgumentException if the chain has no estimators
     */
    @Override
    public List<?> getClasses() {
        List<? extends Estimator> estimators = getEstimators();
        if (estimators.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more estimators");
        }

        if (estimators.size() == 1) {
            return EstimatorUtil.getClasses(estimators.get(0));
        }

        List<List<?>> result = new ArrayList<>(estimators.size());
        for (Estimator estimator : estimators) {
            result.add(EstimatorUtil.getClasses(estimator));
        }

        // Collapse to a single class list when every estimator reports the same classes
        List<List<?>> uniqueResults = result.stream().distinct().collect(Collectors.toList());
        if (uniqueResults.size() == 1) {
            return uniqueResults.get(0);
        }
        return result;
    }

    /**
     * Returns {@code true} only if every chained estimator has a probability distribution.
     */
    @Override
    public boolean hasProbabilityDistribution() {
        boolean result = true;
        for (Estimator estimator : getEstimators()) {
            // Non-short-circuiting '&=': every estimator is queried, as before
            result &= EstimatorUtil.hasProbabilityDistribution(estimator);
        }
        return result;
    }

    /**
     * Encodes the target label(s) of the chain.
     *
     * <p>A single name is encoded by the first estimator. Two or more names are encoded
     * pairwise (name {@code i} by estimator {@code i}) and combined into a {@link MultiLabel}.
     *
     * @param names target column names
     * @param encoder the encoder
     * @return the encoded label
     * @throws IllegalArgumentException if {@code names} is empty
     */
    @Override
    public Label encodeLabel(List<String> names, SkLearnEncoder encoder) {
        boolean multioutput = getMultioutput();
        List<? extends Estimator> estimators = getEstimators();

        if (multioutput) {
            // In multi-output mode there must be exactly one estimator per target
            ClassDictUtil.checkSize(names, estimators);
        }

        if (names.size() == 1) {
            Estimator estimator = estimators.get(0);
            return estimator.encodeLabel(Collections.singletonList(names.get(0)), encoder);
        }

        if (names.size() >= 2) {
            List<ScalarLabel> labels = new ArrayList<>(names.size());
            for (int i = 0; i < names.size(); i++) {
                Estimator estimator = estimators.get(i);
                ScalarLabel label = (ScalarLabel)estimator.encodeLabel(Collections.singletonList(names.get(i)), encoder);
                labels.add(label);
            }
            return new MultiLabel(labels);
        }

        throw new IllegalArgumentException("Expected one or more target names");
    }

    /**
     * Encodes the chain as a {@link MiningModel} with one segment per step.
     *
     * @param schema the input schema
     * @return the encoded mining model
     * @throws IllegalArgumentException if there are no steps
     */
    @Override
    public Model encodeModel(Schema schema) {
        Transformer controller = getController();
        boolean multioutput = getMultioutput();
        List<Object[]> steps = getSteps();
        if (steps.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more steps");
        }

        SkLearnEncoder encoder = (SkLearnEncoder)schema.getEncoder();

        List<? extends ScalarLabel> scalarLabels = ScalarLabelUtil.toScalarLabels(schema.getLabel());
        if (multioutput) {
            ClassDictUtil.checkSize(steps, scalarLabels);
        }

        // Step predicates are evaluated against the (optionally controller-transformed) input features
        List<Feature> controlFeatures = new ArrayList<>(schema.getFeatures());
        if (controller != null) {
            controlFeatures = controller.encode(controlFeatures, encoder);
        }
        Scope scope = new DataFrameScope("X", controlFeatures);

        Segmentation.MultipleModelMethod multipleModelMethod = multioutput
            ? Segmentation.MultipleModelMethod.MULTI_MODEL_CHAIN
            : Segmentation.MultipleModelMethod.MODEL_CHAIN;
        Segmentation segmentation = new Segmentation(multipleModelMethod, null);

        List<Estimator> estimators = new ArrayList<>(steps.size());
        List<Model> models = new ArrayList<>(steps.size());

        // The schema that each subsequent step is relabeled from; Link steps may replace it
        Schema currentSchema = schema;

        for (int i = 0; i < steps.size(); i++) {
            Object[] step = steps.get(i);

            String name = TupleUtil.extractElement(step, 0, String.class);
            Estimator estimator = TupleUtil.extractElement(step, 1, Estimator.class);
            Object expr = TupleUtil.extractElement(step, 2, Object.class);

            // In single-output mode every step targets the same (first) label
            ScalarLabel scalarLabel = scalarLabels.get(multioutput ? i : 0);
            Schema segmentSchema = currentSchema.toRelabeledSchema(scalarLabel);

            Predicate predicate = EvaluatableUtil.translatePredicate(expr, scope);

            Model model = multioutput
                ? estimator.encode(scalarLabel.getName(), segmentSchema)
                : estimator.encode(segmentSchema);

            estimators.add(estimator);
            models.add(model);

            // NOTE(review): presumably a Link makes this step's output available to later steps
            if (estimator instanceof Link) {
                Link link = (Link)estimator;
                currentSchema = link.augmentSchema(model, segmentSchema);
            }

            Segment segment = new Segment(predicate, model).setId(name);
            segmentation.addSegments(segment);
        }

        MiningFunction miningFunction = EstimatorUtil.getMiningFunction(estimators);

        return new MiningModel(miningFunction, MiningModelUtil.createMiningSchema(models))
            .setSegmentation(segmentation);
    }

    /**
     * Returns the optional {@code controller} transformer, or {@code null} if absent.
     */
    @Override
    public Transformer getController() {
        return getOptional("controller", Transformer.class);
    }

    public Boolean getMultioutput() {
        return getBoolean("multioutput");
    }

    /**
     * Returns the {@code (name, estimator, expr)} step tuples.
     */
    @Override
    public List<Object[]> getSteps() {
        return getTupleList("steps");
    }
}

