/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.expression;

import com.google.common.collect.Iterables;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.DataFrameScope;
import org.jpmml.python.Scope;
import sklearn.Classifier;
import sklearn2pmml.expression.ExpressionUtil;
import sklearn2pmml.util.EvaluatableUtil;
import sklearn2pmml.util.Expression;

public class ExpressionClassifier
extends Classifier {
    public ExpressionClassifier(String module, String name) {
        super(module, name);
    }

    /*
     * WARNING - void declaration
     */
    public RegressionModel encodeModel(Schema schema) {
        void var11_14;
        Map<?, Expression> classExprs = this.getClassExprs();
        RegressionModel.NormalizationMethod normalizationMethod = ExpressionClassifier.parseNormalizationMethod(this.getNormalizationMethod());
        PMMLEncoder encoder = schema.getEncoder();
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        List features = schema.getFeatures();
        DataFrameScope scope = new DataFrameScope("X", features, encoder);
        LinkedHashMap categoryRegressionTables = new LinkedHashMap();
        Set<Map.Entry<?, Expression>> entries = classExprs.entrySet();
        for (Map.Entry entry : entries) {
            Object category2 = entry.getKey();
            Expression expr = (Expression)((Object)entry.getValue());
            org.dmg.pmml.Expression pmmlExpression = EvaluatableUtil.translateExpression((Object)expr, (Scope)scope);
            ContinuousFeature exprFeature = ExpressionUtil.toFeature(FieldNameUtil.create((String)"expression", (Object[])new Object[]{category2}), pmmlExpression, encoder);
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(exprFeature), Collections.singletonList(1.0), (Number)0.0);
            categoryRegressionTables.put(category2, regressionTable);
        }
        List categories = categoricalLabel.getValues();
        switch (normalizationMethod) {
            case LOGIT: {
                Object passiveCategory;
                if (categoryRegressionTables.size() != 1 || categories.size() != 2) {
                    throw new IllegalArgumentException();
                }
                Object activeCategory = Iterables.getOnlyElement(categoryRegressionTables.keySet());
                int index = categories.indexOf(activeCategory);
                if (index == 0) {
                    passiveCategory = categories.get(1);
                } else if (index == 1) {
                    passiveCategory = categories.get(0);
                } else {
                    throw new IllegalArgumentException();
                }
                RegressionTable activeRegressionTable = ((RegressionTable)categoryRegressionTables.get(activeCategory)).setTargetCategory(activeCategory);
                RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), null).setTargetCategory(passiveCategory);
                List<RegressionTable> list = Arrays.asList(activeRegressionTable, passiveRegressionTable);
                break;
            }
            case SIMPLEMAX: 
            case SOFTMAX: {
                if (categoryRegressionTables.size() != categories.size() || !categoryRegressionTables.keySet().containsAll(categories)) {
                    throw new IllegalArgumentException();
                }
                List list = categories.stream().map(category -> ((RegressionTable)categoryRegressionTables.get(category)).setTargetCategory(category)).collect(Collectors.toList());
                break;
            }
            default: {
                throw new IllegalArgumentException();
            }
        }
        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)categoricalLabel), (List)var11_14).setNormalizationMethod(normalizationMethod);
        this.encodePredictProbaOutput((Model)regressionModel, DataType.DOUBLE, (DiscreteLabel)categoricalLabel);
        return regressionModel;
    }

    public Map<?, Expression> getClassExprs() {
        return this.getDict("class_exprs");
    }

    public String getNormalizationMethod() {
        return this.getString("normalization_method");
    }

    private static RegressionModel.NormalizationMethod parseNormalizationMethod(String normalizationMethod) {
        switch (normalizationMethod) {
            case "logit": {
                return RegressionModel.NormalizationMethod.LOGIT;
            }
            case "simplemax": {
                return RegressionModel.NormalizationMethod.SIMPLEMAX;
            }
            case "softmax": {
                return RegressionModel.NormalizationMethod.SOFTMAX;
            }
        }
        throw new IllegalArgumentException(normalizationMethod);
    }
}

