/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.feature;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.spark.ml.feature.ImputerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Field;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.dmg.pmml.Value;
import org.jpmml.converter.Decorator;
import org.jpmml.converter.Feature;
import org.jpmml.converter.MissingValueDecorator;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.MultiFeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/**
 * Converts a fitted Spark ML {@link ImputerModel} by attaching a missing value
 * treatment to each of its input columns.
 *
 * <p>For every input column, the surrogate value stored in the model's surrogate
 * data frame is registered on the underlying {@link DataField} through a
 * {@link MissingValueDecorator}. The input features themselves are returned unchanged.
 */
public class ImputerModelConverter
extends MultiFeatureConverter<ImputerModel> {
    public ImputerModelConverter(ImputerModel transformer) {
        super(transformer);
    }

    /**
     * Registers the imputation surrogate of every input column as a missing value
     * decorator on the column's data field.
     *
     * @param encoder the encoder that resolves input columns to features and collects decorators
     * @return the features of the input columns, in input column order
     * @throws IllegalArgumentException if the surrogate data frame does not hold exactly one row,
     *         if an input column is not backed by a {@link DataField}, or if the strategy is not supported
     */
    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
        ImputerModel transformer = (ImputerModel)this.getTransformer();
        Double missingValue = transformer.getMissingValue();
        String strategy = transformer.getStrategy();
        Dataset<Row> surrogateDF = transformer.surrogateDF();
        MissingValueTreatmentMethod missingValueTreatmentMethod = ImputerModelConverter.parseStrategy(strategy);
        List<Row> surrogateRows = surrogateDF.collectAsList();
        // The surrogates of all input columns are read from a single row, keyed by column name.
        if (surrogateRows.size() != 1) {
            throw new IllegalArgumentException("Expected exactly one surrogate row, got " + surrogateRows.size());
        }
        Row surrogateRow = surrogateRows.get(0);
        FeatureConverter.InOutMode inputMode = this.getInputMode();
        String[] inputCols = inputMode.getInputCols(transformer);
        List<Feature> result = new ArrayList<>();
        for (String inputCol : inputCols) {
            Feature feature = encoder.getOnlyFeature(inputCol);
            Field field = feature.getField();
            // Decorators and missing value markers are attached to data fields only.
            if (!(field instanceof DataField)) {
                throw new IllegalArgumentException("Input column '" + inputCol + "' is not backed by a data field");
            }
            DataField dataField = (DataField)field;
            Object surrogate = surrogateRow.getAs(inputCol);
            encoder.addDecorator(dataField, (Decorator)new MissingValueDecorator(missingValueTreatmentMethod, surrogate));
            // A non-NaN placeholder is additionally declared as a value with the MISSING property;
            // a NaN (or absent) placeholder is not declared.
            if (missingValue != null && !missingValue.isNaN()) {
                PMMLUtil.addValues((Field)dataField, Collections.singletonList(missingValue), (Value.Property)Value.Property.MISSING);
            }
            result.add(feature);
        }
        return result;
    }

    /**
     * Maps an imputation strategy name to the corresponding PMML missing value treatment method.
     *
     * @param strategy the strategy name; one of {@code "mean"}, {@code "median"} or {@code "mode"}
     * @return the matching treatment method
     * @throws IllegalArgumentException if the strategy name is not one of the supported names
     */
    public static MissingValueTreatmentMethod parseStrategy(String strategy) {
        switch (strategy) {
            case "mean":
                return MissingValueTreatmentMethod.AS_MEAN;
            case "median":
                return MissingValueTreatmentMethod.AS_MEDIAN;
            case "mode":
                return MissingValueTreatmentMethod.AS_MODE;
            default:
                throw new IllegalArgumentException(strategy);
        }
    }
}

