/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.List;
import java.util.Locale;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.general_regression.GeneralRegressionModelUtil;
import org.jpmml.sparkml.RegressionModelConverter;
import org.jpmml.sparkml.VectorUtil;

/**
 * Converts a Spark ML {@link GeneralizedLinearRegressionModel} into a PMML
 * {@link GeneralRegressionModel} of type {@code GENERALIZED_LINEAR}.
 *
 * <p>The Spark {@code family} is mapped to a PMML distribution and the Spark {@code link} to a
 * PMML link function (plus link parameter for power links). Family and link names are compared
 * case-insensitively, and an unset link falls back to the family's canonical link.
 */
public class GeneralizedLinearRegressionModelConverter
extends RegressionModelConverter<GeneralizedLinearRegressionModel> {
    public GeneralizedLinearRegressionModelConverter(GeneralizedLinearRegressionModel model) {
        super(model);
    }

    /**
     * Reports the PMML mining function for the wrapped model.
     *
     * @return {@code CLASSIFICATION} for the binomial family, {@code REGRESSION} otherwise
     */
    @Override
    public MiningFunction getMiningFunction() {
        GeneralizedLinearRegressionModel model = (GeneralizedLinearRegressionModel)this.getTransformer();
        if ("binomial".equals(normalize(model.getFamily()))) {
            return MiningFunction.CLASSIFICATION;
        }
        return MiningFunction.REGRESSION;
    }

    /**
     * Builds the PMML general regression model for the wrapped Spark model.
     *
     * <p>When the schema carries target categories, the model is encoded as a binary
     * classifier: exactly two categories are required, and the second one is used as the
     * regression table's target category. A probability output is attached in that case.
     *
     * @param schema the feature/label schema to encode against
     * @return the PMML general regression model
     * @throws IllegalArgumentException if there are target categories but not exactly two, or if
     *         the family or link is not supported
     */
    public GeneralRegressionModel encodeModel(Schema schema) {
        GeneralizedLinearRegressionModel model = (GeneralizedLinearRegressionModel)this.getTransformer();

        String targetCategory = null;
        List<?> targetCategories = schema.getTargetCategories();
        if (targetCategories != null && !targetCategories.isEmpty()) {
            if (targetCategories.size() != 2) {
                throw new IllegalArgumentException("Expected 2 target categories, got " + targetCategories.size());
            }
            // NOTE(review): presumably the positive (label 1.0) class, whose probability
            // Spark's binomial family models — confirm against the label encoder.
            targetCategory = (String)targetCategories.get(1);
        }

        MiningFunction miningFunction = (targetCategory != null ? MiningFunction.CLASSIFICATION : MiningFunction.REGRESSION);

        GeneralRegressionModel.Distribution distribution = parseFamily(model.getFamily());
        // Spark's link param has no default; resolve the canonical link when it was never set,
        // instead of letting getLink() fail.
        String link = resolveLink(model);

        GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema(schema), null, null, null)
            .setDistribution(distribution)
            .setLinkFunction(parseLinkFunction(link))
            .setLinkParameter(parseLinkParameter(link));

        GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, (List)schema.getFeatures(), (Double)model.intercept(), VectorUtil.toList(model.coefficients()), targetCategory);

        if (miningFunction == MiningFunction.CLASSIFICATION) {
            generalRegressionModel.setOutput(ModelUtil.createProbabilityOutput(schema));
        }

        return generalRegressionModel;
    }

    /**
     * Returns the model's link name, or the family's canonical link when the link param is
     * not defined on the model.
     */
    private static String resolveLink(GeneralizedLinearRegressionModel model) {
        if (model.isDefined(model.link())) {
            return model.getLink();
        }
        return canonicalLink(model.getFamily());
    }

    /**
     * Returns the canonical link name for a Spark GLM family.
     *
     * @throws IllegalArgumentException if the family is not supported
     */
    private static String canonicalLink(String family) {
        switch (normalize(family)) {
            case "binomial":
                return "logit";
            case "gamma":
                return "inverse";
            case "gaussian":
                return "identity";
            case "poisson":
                return "log";
            default:
                throw new IllegalArgumentException(family);
        }
    }

    /**
     * Maps a Spark GLM family name to a PMML distribution.
     *
     * @throws IllegalArgumentException if the family is not supported
     */
    private static GeneralRegressionModel.Distribution parseFamily(String family) {
        switch (normalize(family)) {
            case "binomial":
                return GeneralRegressionModel.Distribution.BINOMIAL;
            case "gamma":
                return GeneralRegressionModel.Distribution.GAMMA;
            case "gaussian":
                return GeneralRegressionModel.Distribution.NORMAL;
            case "poisson":
                return GeneralRegressionModel.Distribution.POISSON;
            default:
                throw new IllegalArgumentException(family);
        }
    }

    /**
     * Maps a Spark GLM link name to a PMML link function. The {@code inverse} and {@code sqrt}
     * links are expressed as PMML power links; see {@link #parseLinkParameter(String)}.
     *
     * @throws IllegalArgumentException if the link is not supported
     */
    private static GeneralRegressionModel.LinkFunction parseLinkFunction(String link) {
        switch (normalize(link)) {
            case "cloglog":
                return GeneralRegressionModel.LinkFunction.CLOGLOG;
            case "identity":
                return GeneralRegressionModel.LinkFunction.IDENTITY;
            case "inverse":
                return GeneralRegressionModel.LinkFunction.POWER;
            case "log":
                return GeneralRegressionModel.LinkFunction.LOG;
            case "logit":
                return GeneralRegressionModel.LinkFunction.LOGIT;
            case "probit":
                return GeneralRegressionModel.LinkFunction.PROBIT;
            case "sqrt":
                return GeneralRegressionModel.LinkFunction.POWER;
            default:
                throw new IllegalArgumentException(link);
        }
    }

    /**
     * Returns the PMML power-link exponent for a Spark link name: -1 for {@code inverse},
     * 0.5 for {@code sqrt}, and {@code null} for links that take no parameter.
     */
    private static Double parseLinkParameter(String link) {
        switch (normalize(link)) {
            case "inverse":
                return -1.0;
            case "sqrt":
                return 0.5;
            default:
                return null;
        }
    }

    /**
     * Lower-cases a Spark param value using the root locale, so that matching does not depend
     * on the JVM's default locale. Returns {@code null} for {@code null} input.
     */
    private static String normalize(String value) {
        return (value != null ? value.toLowerCase(Locale.ROOT) : null);
    }
}

