/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.general_regression.CovariateList;
import org.dmg.pmml.general_regression.FactorList;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.dmg.pmml.general_regression.PCell;
import org.dmg.pmml.general_regression.PPCell;
import org.dmg.pmml.general_regression.PPMatrix;
import org.dmg.pmml.general_regression.ParamMatrix;
import org.dmg.pmml.general_regression.Parameter;
import org.dmg.pmml.general_regression.ParameterList;
import org.dmg.pmml.general_regression.Predictor;
import org.dmg.pmml.general_regression.PredictorList;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.InteractionFeature;
import org.jpmml.sparkml.RegressionModelConverter;

public class GeneralizedLinearRegressionModelConverter
extends RegressionModelConverter<GeneralizedLinearRegressionModel> {
    public GeneralizedLinearRegressionModelConverter(GeneralizedLinearRegressionModel model) {
        super(model);
    }

    @Override
    public MiningFunction getMiningFunction() {
        String family;
        GeneralizedLinearRegressionModel model = (GeneralizedLinearRegressionModel)this.getTransformer();
        switch (family = model.getFamily()) {
            case "binomial": {
                return MiningFunction.CLASSIFICATION;
            }
        }
        return MiningFunction.REGRESSION;
    }

    public GeneralRegressionModel encodeModel(Schema schema) {
        GeneralizedLinearRegressionModel model = (GeneralizedLinearRegressionModel)this.getTransformer();
        double intercept = model.intercept();
        Vector coefficients = model.coefficients();
        List features = schema.getFeatures();
        if (features.size() != coefficients.size()) {
            throw new IllegalArgumentException();
        }
        String targetCategory = null;
        List targetCategories = schema.getTargetCategories();
        if (targetCategories != null && targetCategories.size() > 0) {
            if (targetCategories.size() != 2) {
                throw new IllegalArgumentException();
            }
            targetCategory = (String)targetCategories.get(1);
        }
        ParameterList parameterList = new ParameterList();
        PPMatrix ppMatrix = new PPMatrix();
        ParamMatrix paramMatrix = new ParamMatrix();
        if (!ValueUtil.isZero((Number)intercept)) {
            Parameter parameter = new Parameter("p0").setLabel("(intercept)");
            parameterList.addParameters(new Parameter[]{parameter});
            PCell pCell = new PCell(parameter.getName(), intercept).setTargetCategory(targetCategory);
            paramMatrix.addPCells(new PCell[]{pCell});
        }
        LinkedHashSet<FieldName> covariates = new LinkedHashSet<FieldName>();
        LinkedHashSet<FieldName> factors = new LinkedHashSet<FieldName>();
        for (int i = 0; i < features.size(); ++i) {
            Feature feature = (Feature)features.get(i);
            Parameter parameter = new Parameter("p" + String.valueOf(i + 1));
            parameterList.addParameters(new Parameter[]{parameter});
            List<PPCell> ppCells = GeneralizedLinearRegressionModelConverter.createPPCells(parameter, feature, covariates, factors);
            ppMatrix.addPPCells(ppCells.toArray(new PPCell[ppCells.size()]));
            PCell pCell = new PCell(parameter.getName(), coefficients.apply(i)).setTargetCategory(targetCategory);
            paramMatrix.addPCells(new PCell[]{pCell});
        }
        MiningFunction miningFunction = targetCategory != null ? MiningFunction.CLASSIFICATION : MiningFunction.REGRESSION;
        GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema((Schema)schema), parameterList, ppMatrix, paramMatrix).setDistribution(GeneralizedLinearRegressionModelConverter.parseFamily(model.getFamily())).setLinkFunction(GeneralizedLinearRegressionModelConverter.parseLink(model.getLink())).setCovariateList(GeneralizedLinearRegressionModelConverter.createPredictorList(new CovariateList(), covariates)).setFactorList(GeneralizedLinearRegressionModelConverter.createPredictorList(new FactorList(), factors));
        switch (miningFunction) {
            case CLASSIFICATION: {
                generalRegressionModel.setOutput(ModelUtil.createProbabilityOutput((Schema)schema));
                break;
            }
        }
        return generalRegressionModel;
    }

    private static List<PPCell> createPPCells(Parameter parameter, Feature feature, Set<FieldName> covariates, Set<FieldName> factors) {
        if (feature instanceof InteractionFeature) {
            InteractionFeature interactionFeature = (InteractionFeature)feature;
            ArrayList<PPCell> ppCells = new ArrayList<PPCell>();
            List<Feature> inputFeatures = interactionFeature.getFeatures();
            for (Feature inputFeature : inputFeatures) {
                ppCells.addAll(GeneralizedLinearRegressionModelConverter.createPPCells(parameter, inputFeature, covariates, factors));
            }
            return ppCells;
        }
        if (feature instanceof ContinuousFeature) {
            ContinuousFeature continuousFeature = (ContinuousFeature)feature;
            covariates.add(continuousFeature.getName());
            PPCell ppCell = new PPCell("1", continuousFeature.getName(), parameter.getName());
            return Collections.singletonList(ppCell);
        }
        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature)feature;
            factors.add(binaryFeature.getName());
            PPCell ppCell = new PPCell(binaryFeature.getValue(), binaryFeature.getName(), parameter.getName());
            return Collections.singletonList(ppCell);
        }
        throw new IllegalArgumentException();
    }

    private static GeneralRegressionModel.Distribution parseFamily(String family) {
        switch (family) {
            case "binomial": {
                return GeneralRegressionModel.Distribution.BINOMIAL;
            }
            case "gamma": {
                return GeneralRegressionModel.Distribution.GAMMA;
            }
            case "gaussian": {
                return GeneralRegressionModel.Distribution.NORMAL;
            }
            case "poisson": {
                return GeneralRegressionModel.Distribution.POISSON;
            }
        }
        throw new IllegalArgumentException(family);
    }

    private static GeneralRegressionModel.LinkFunction parseLink(String link) {
        switch (link) {
            case "cloglog": {
                return GeneralRegressionModel.LinkFunction.CLOGLOG;
            }
            case "identity": {
                return GeneralRegressionModel.LinkFunction.IDENTITY;
            }
            case "log": {
                return GeneralRegressionModel.LinkFunction.LOG;
            }
            case "logit": {
                return GeneralRegressionModel.LinkFunction.LOGIT;
            }
            case "probit": {
                return GeneralRegressionModel.LinkFunction.PROBIT;
            }
        }
        throw new IllegalArgumentException(link);
    }

    private static <L extends PredictorList> L createPredictorList(L predictorList, Set<FieldName> names) {
        if (names.isEmpty()) {
            return null;
        }
        List predictors = predictorList.getPredictors();
        for (FieldName name : names) {
            Predictor predictor = new Predictor(name);
            predictors.add(predictor);
        }
        return predictorList;
    }
}

