/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.general_regression.GeneralRegressionModelUtil;
import org.jpmml.sparkml.RegressionModelConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.VectorUtil;
import org.jpmml.sparkml.model.HasRegressionOptions;
import org.jpmml.sparkml.model.RegressionTableUtil;

/**
 * Converts a Spark ML {@link GeneralizedLinearRegressionModel} into a PMML
 * {@link GeneralRegressionModel} of type {@code GENERALIZED_LINEAR}.
 *
 * <p>A model of the {@code binomial} family is encoded as a binary classification
 * model (with probability output fields); every other supported family
 * ({@code gamma}, {@code gaussian}, {@code poisson}) is encoded as a regression model.
 */
public class GeneralizedLinearRegressionModelConverter
extends RegressionModelConverter<GeneralizedLinearRegressionModel>
implements HasRegressionOptions {
    public GeneralizedLinearRegressionModelConverter(GeneralizedLinearRegressionModel model) {
        super(model);
    }

    /**
     * Selects the PMML mining function from the model's distribution family.
     *
     * @return {@link MiningFunction#CLASSIFICATION} for the {@code binomial} family,
     *         {@link MiningFunction#REGRESSION} for any other family
     */
    @Override
    public MiningFunction getMiningFunction() {
        GeneralizedLinearRegressionModel model = getTransformer();
        if ("binomial".equals(resolveFamily(model))) {
            return MiningFunction.CLASSIFICATION;
        }
        return MiningFunction.REGRESSION;
    }

    /**
     * Registers the superclass output fields, plus one probability field per target
     * category when the model is a classifier.
     *
     * @param label the target label; must be a {@link CategoricalLabel} for classification
     * @param encoder the encoder collecting PMML fields
     * @return the output fields for this model
     */
    @Override
    public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
        List<OutputField> result = super.registerOutputFields(label, encoder);
        if (getMiningFunction() == MiningFunction.CLASSIFICATION) {
            CategoricalLabel categoricalLabel = (CategoricalLabel) label;
            // Copy before appending so the superclass's list is never mutated in place
            // (NOTE(review): presumably it may be unmodifiable — confirm).
            result = new ArrayList<>(result);
            result.addAll(ModelUtil.createProbabilityFields(DataType.DOUBLE, categoricalLabel.getValues()));
        }
        return result;
    }

    /**
     * Encodes the fitted GLR model as a PMML {@link GeneralRegressionModel}.
     *
     * @param schema the label and the features matching the model's coefficient vector
     * @return the encoded model
     * @throws IllegalArgumentException if a binomial model's label does not have exactly
     *         two categories, or if the family or link function is not supported
     */
    public GeneralRegressionModel encodeModel(Schema schema) {
        GeneralizedLinearRegressionModel model = getTransformer();
        MiningFunction miningFunction = getMiningFunction();

        // For binomial models the second label category is the modelled ("positive") one.
        String targetCategory = null;
        if (miningFunction == MiningFunction.CLASSIFICATION) {
            CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
            if (categoricalLabel.size() != 2) {
                throw new IllegalArgumentException(
                    "Expected a binary label for the binomial family, got " + categoricalLabel.size() + " categories");
            }
            targetCategory = categoricalLabel.getValue(1);
        }

        // Mutable copies: RegressionTableUtil.simplify receives these lists alongside each other.
        List<Feature> features = new ArrayList<>(schema.getFeatures());
        List<Double> coefficients = new ArrayList<>(VectorUtil.toList(model.coefficients()));
        RegressionTableUtil.simplify(this, targetCategory, features, coefficients);

        String family = resolveFamily(model);
        String link = resolveLink(model);

        GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(
                GeneralRegressionModel.ModelType.GENERALIZED_LINEAR,
                miningFunction,
                ModelUtil.createMiningSchema(schema.getLabel()),
                null, null, null)
            .setDistribution(parseFamily(family))
            .setLinkFunction(parseLinkFunction(link))
            .setLinkParameter(parseLinkParameter(link));

        Double intercept = model.intercept();
        GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, coefficients, intercept, targetCategory);
        return generalRegressionModel;
    }

    /**
     * Returns the model's family name, lower-cased.
     *
     * <p>Depending on the Spark version, the {@code family} param may be accepted in any
     * case, so it is normalised before matching.
     */
    private static String resolveFamily(GeneralizedLinearRegressionModel model) {
        return model.getFamily().toLowerCase(Locale.ROOT);
    }

    /**
     * Returns the effective link function name, lower-cased.
     *
     * <p>The {@code link} param is only read when it is defined. Otherwise the family's
     * canonical link is used, because calling {@code getLink()} on an undefined param fails.
     */
    private static String resolveLink(GeneralizedLinearRegressionModel model) {
        if (model.isDefined(model.link())) {
            return model.getLink().toLowerCase(Locale.ROOT);
        }
        return defaultLink(resolveFamily(model));
    }

    /**
     * Returns the canonical link function name for a (lower-cased) family name.
     *
     * @throws IllegalArgumentException if the family is not supported
     */
    private static String defaultLink(String family) {
        switch (family) {
            case "binomial":
                return "logit";
            case "gamma":
                return "inverse";
            case "gaussian":
                return "identity";
            case "poisson":
                return "log";
            default:
                throw new IllegalArgumentException(family);
        }
    }

    /**
     * Maps a Spark family name to a PMML distribution.
     *
     * @throws IllegalArgumentException if the family is not supported
     */
    private static GeneralRegressionModel.Distribution parseFamily(String family) {
        switch (family) {
            case "binomial":
                return GeneralRegressionModel.Distribution.BINOMIAL;
            case "gamma":
                return GeneralRegressionModel.Distribution.GAMMA;
            case "gaussian":
                return GeneralRegressionModel.Distribution.NORMAL;
            case "poisson":
                return GeneralRegressionModel.Distribution.POISSON;
            default:
                throw new IllegalArgumentException(family);
        }
    }

    /**
     * Maps a Spark link name to a PMML link function.
     *
     * <p>{@code inverse} and {@code sqrt} have no dedicated PMML link. Both become
     * {@code POWER}, with the exponent supplied by {@link #parseLinkParameter(String)}.
     *
     * @throws IllegalArgumentException if the link is not supported
     */
    private static GeneralRegressionModel.LinkFunction parseLinkFunction(String link) {
        switch (link) {
            case "cloglog":
                return GeneralRegressionModel.LinkFunction.CLOGLOG;
            case "identity":
                return GeneralRegressionModel.LinkFunction.IDENTITY;
            case "inverse":
                return GeneralRegressionModel.LinkFunction.POWER;
            case "log":
                return GeneralRegressionModel.LinkFunction.LOG;
            case "logit":
                return GeneralRegressionModel.LinkFunction.LOGIT;
            case "probit":
                return GeneralRegressionModel.LinkFunction.PROBIT;
            case "sqrt":
                return GeneralRegressionModel.LinkFunction.POWER;
            default:
                throw new IllegalArgumentException(link);
        }
    }

    /**
     * Returns the PMML link parameter (power exponent) for the {@code POWER}-encoded links.
     *
     * @return {@code -1.0} for {@code inverse}, {@code 0.5} for {@code sqrt},
     *         {@code null} for every other link
     */
    private static Double parseLinkParameter(String link) {
        switch (link) {
            case "inverse":
                return -1.0;
            case "sqrt":
                return 0.5;
            default:
                return null;
        }
    }
}

