/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.feature;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.spark.ml.feature.ImputerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Field;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.jpmml.converter.Decorator;
import org.jpmml.converter.Feature;
import org.jpmml.converter.MissingValueDecorator;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/**
 * Converts a fitted Spark ML {@link ImputerModel} into PMML missing value treatments.
 *
 * <p>Every input column must resolve to exactly one feature that is backed by a
 * {@link DataField}. The imputation is expressed by attaching a {@link MissingValueDecorator}
 * to that field. The decorator uses the per-column surrogate value computed by Spark as its
 * replacement value.
 */
public class ImputerModelConverter
extends FeatureConverter<ImputerModel> {

    public ImputerModelConverter(ImputerModel transformer) {
        super(transformer);
    }

    /**
     * Attaches a missing value treatment to the data field behind every input column.
     *
     * @param encoder the encoder holding the already-registered input features
     * @return the (unchanged) input features, one per input column, in input column order
     * @throws IllegalArgumentException if the input and output column counts differ, the
     *     imputation strategy is not supported, the surrogate DataFrame does not contain
     *     exactly one row, or an input column is not backed by a {@link DataField}
     */
    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
        ImputerModel transformer = getTransformer();

        String[] inputCols = transformer.getInputCols();
        String[] outputCols = transformer.getOutputCols();
        if (inputCols.length != outputCols.length) {
            throw new IllegalArgumentException("Expected " + inputCols.length
                + " output column(s), got " + outputCols.length);
        }

        // Validate the strategy before collecting the surrogates, which triggers a Spark action
        MissingValueTreatmentMethod missingValueTreatmentMethod = parseStrategy(transformer.getStrategy());

        Row surrogateRow = getSurrogateRow(transformer);

        // A NaN placeholder needs no explicit declaration; any other placeholder is
        // registered as an additional missing value on every decorated field.
        Double missingValue = transformer.getMissingValue();
        String formattedMissingValue = (missingValue != null && !missingValue.isNaN())
            ? ValueUtil.formatValue((Number)missingValue)
            : null;

        List<Feature> result = new ArrayList<>(inputCols.length);
        for (String inputCol : inputCols) {
            Feature feature = encoder.getOnlyFeature(inputCol);

            Field field = encoder.getField(feature.getName());
            if (!(field instanceof DataField)) {
                throw new IllegalArgumentException("Input column " + inputCol
                    + " is not backed by a data field");
            }

            // The surrogate row is expected to have one column per input column
            Object surrogate = surrogateRow.getAs(inputCol);

            MissingValueDecorator missingValueDecorator = new MissingValueDecorator()
                .setMissingValueReplacement(ValueUtil.formatValue(surrogate))
                .setMissingValueTreatment(missingValueTreatmentMethod);
            if (formattedMissingValue != null) {
                missingValueDecorator.addValues(new String[]{formattedMissingValue});
            }

            encoder.addDecorator(feature.getName(), (Decorator)missingValueDecorator);
            result.add(feature);
        }
        return result;
    }

    /**
     * Encodes the features and registers each one under its corresponding output column.
     *
     * @param encoder the encoder to register the output columns with
     * @throws IllegalArgumentException if the number of encoded features does not match the
     *     number of output columns
     */
    @Override
    public void registerFeatures(SparkMLEncoder encoder) {
        ImputerModel transformer = getTransformer();

        List<Feature> features = encodeFeatures(encoder);

        String[] outputCols = transformer.getOutputCols();
        if (outputCols.length != features.size()) {
            throw new IllegalArgumentException("Expected " + outputCols.length
                + " feature(s), got " + features.size());
        }

        for (int i = 0; i < features.size(); i++) {
            encoder.putFeatures(outputCols[i], Collections.singletonList(features.get(i)));
        }
    }

    /**
     * Maps a Spark ML imputation strategy name to the equivalent PMML treatment method.
     *
     * @param strategy the Spark strategy name ({@code "mean"}, {@code "median"} or {@code "mode"})
     * @return the matching PMML missing value treatment method
     * @throws IllegalArgumentException if the strategy is not supported
     */
    public static MissingValueTreatmentMethod parseStrategy(String strategy) {
        switch (strategy) {
            case "mean":
                return MissingValueTreatmentMethod.AS_MEAN;
            case "median":
                return MissingValueTreatmentMethod.AS_MEDIAN;
            case "mode":
                // Spark 3.1+ supports imputation by the most frequent value
                return MissingValueTreatmentMethod.AS_MODE;
            default:
                throw new IllegalArgumentException(strategy);
        }
    }

    /**
     * Collects the single row of fitted surrogate values from the model.
     *
     * @throws IllegalArgumentException if the surrogate DataFrame does not have exactly one row
     */
    private static Row getSurrogateRow(ImputerModel transformer) {
        Dataset<Row> surrogateDF = transformer.surrogateDF();

        List<Row> surrogateRows = surrogateDF.collectAsList();
        if (surrogateRows.size() != 1) {
            throw new IllegalArgumentException("Expected 1 surrogate row, got " + surrogateRows.size());
        }
        return surrogateRows.get(0);
    }
}

