/*
 * Decompiled with CFR 0.152.
 */
package sklearn.neighbors;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import numpy.core.ScalarUtil;
import org.dmg.pmml.CityBlock;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Euclidean;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Minkowski;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.nearest_neighbor.InstanceField;
import org.dmg.pmml.nearest_neighbor.InstanceFields;
import org.dmg.pmml.nearest_neighbor.KNNInput;
import org.dmg.pmml.nearest_neighbor.KNNInputs;
import org.dmg.pmml.nearest_neighbor.NearestNeighborModel;
import org.dmg.pmml.nearest_neighbor.TrainingInstances;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import sklearn.Estimator;
import sklearn.neighbors.HasNeighbors;
import sklearn.neighbors.HasTrainingData;

/**
 * Static helpers for converting scikit-learn nearest-neighbor estimators into PMML
 * {@link NearestNeighborModel} elements.
 */
public class KNeighborsUtil {

    private KNeighborsUtil() {
    }

    /**
     * Reads the {@code n_neighbors} attribute of an estimator and returns it as an {@code int}.
     *
     * @param estimator the estimator whose {@code n_neighbors} attribute is read
     * @return the decoded neighbor count
     * @throws ClassCastException if the decoded attribute value is not a {@link Number}
     */
    public static <E extends Estimator> int getNumberOfNeighbors(E estimator) {
        Number nNeighbors = (Number) ScalarUtil.decode(estimator.get("n_neighbors"));

        return ValueUtil.asInt(nNeighbors);
    }

    /**
     * Encodes a fitted nearest-neighbor estimator as a PMML {@link NearestNeighborModel}.
     *
     * <p>The training data is embedded as an inline table. The target column is named
     * {@code "data:y"} (only present when the schema has a label). The feature columns are
     * named {@code "data:x1"}, {@code "data:x2"}, ... in schema-feature order. One
     * {@code ENTITY_ID} output field, named {@code "neighbor(<rank>)"}, is emitted per
     * neighbor rank.
     *
     * @param estimator         the fitted estimator; it must also implement {@link HasNeighbors}
     * @param miningFunction    the PMML mining function of the resulting model
     *                          (e.g. classification or regression)
     * @param numberOfInstances number of training rows; {@code y} must have exactly this size
     * @param numberOfFeatures  number of columns in the flattened {@code fitX} matrix
     * @param schema            the label and features to encode
     * @return the encoded model
     * @throws IllegalArgumentException if the weighting scheme is not {@code "uniform"},
     *                                  or if the metric is not supported
     */
    public static <E extends Estimator & HasTrainingData> NearestNeighborModel encodeNeighbors(E estimator, MiningFunction miningFunction, int numberOfInstances, int numberOfFeatures, Schema schema) {
        HasNeighbors hasNeighbors = (HasNeighbors) estimator;

        // Only uniform weighting is encoded. Comparing constant-first also rejects a null value
        // with IllegalArgumentException instead of a NullPointerException.
        String weights = hasNeighbors.getWeights();
        if (!"uniform".equals(weights)) {
            throw new IllegalArgumentException(weights);
        }

        List<?> y = estimator.getY();
        // NOTE(review): presumably a flattened numberOfInstances x numberOfFeatures matrix in
        // C (row-major) order, as read by CMatrixUtil.getColumn below -- confirm.
        List<? extends Number> fitX = estimator.getFitX();

        ClassDictUtil.checkSize(numberOfInstances, y);

        // Insertion order determines the column order of the inline table.
        Map<String, List<?>> data = new LinkedHashMap<>();

        InstanceFields instanceFields = new InstanceFields();

        Label label = schema.getLabel();
        if (label != null) {
            InstanceField instanceField = new InstanceField(label.getName())
                .setColumn("data:y");

            instanceFields.addInstanceFields(instanceField);

            data.put(instanceField.getColumn(), y);
        }

        KNNInputs knnInputs = new KNNInputs();

        List<? extends Feature> features = schema.getFeatures();
        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);

            ContinuousFeature continuousFeature = feature.toContinuousFeature(estimator.getDataType());
            FieldName name = continuousFeature.getName();

            // Column names are 1-based: data:x1, data:x2, ...
            InstanceField instanceField = new InstanceField(name)
                .setColumn("data:x" + (i + 1));

            instanceFields.addInstanceFields(instanceField);

            knnInputs.addKNNInputs(new KNNInput(name));

            data.put(instanceField.getColumn(), CMatrixUtil.getColumn(fitX, numberOfInstances, numberOfFeatures, i));
        }

        TrainingInstances trainingInstances = new TrainingInstances(instanceFields)
            .setInlineTable(PMMLUtil.createInlineTable(data))
            .setTransformed(Boolean.TRUE);

        ComparisonMeasure comparisonMeasure = encodeComparisonMeasure(hasNeighbors.getMetric(), hasNeighbors.getP());

        int numberOfNeighbors = hasNeighbors.getNumberOfNeighbors();

        Output output = new Output();
        for (int rank = 1; rank <= numberOfNeighbors; rank++) {
            OutputField outputField = new OutputField(FieldName.create("neighbor(" + rank + ")"), DataType.STRING)
                .setOpType(OpType.CATEGORICAL)
                .setResultFeature(ResultFeature.ENTITY_ID)
                .setRank(rank);

            output.addOutputFields(outputField);
        }

        // Use the caller-supplied mining function (previously hard-coded to REGRESSION,
        // which mis-encoded classifiers).
        return new NearestNeighborModel(miningFunction, numberOfNeighbors, ModelUtil.createMiningSchema(label), trainingInstances, comparisonMeasure, knnInputs)
            .setOutput(output);
    }

    /**
     * Maps a scikit-learn distance metric to a PMML distance {@link ComparisonMeasure}.
     *
     * <p>Only {@code "minkowski"} is supported. {@code p == 1} maps to {@link CityBlock},
     * {@code p == 2} maps to {@link Euclidean}, and any other {@code p} maps to a
     * {@link Minkowski} measure with that exponent. The compare function is always
     * {@code ABS_DIFF}.
     *
     * @param metric the metric name
     * @param p      the Minkowski power parameter
     * @return the encoded comparison measure
     * @throws IllegalArgumentException if {@code metric} is not {@code "minkowski"} (including null)
     */
    private static ComparisonMeasure encodeComparisonMeasure(String metric, int p) {
        if (!"minkowski".equals(metric)) {
            throw new IllegalArgumentException(metric);
        }

        // Declared as the common supertype so that all three concrete measures can be assigned.
        Measure measure;
        switch (p) {
            case 1:
                measure = new CityBlock();
                break;
            case 2:
                measure = new Euclidean();
                break;
            default:
                measure = new Minkowski((double) p);
                break;
        }

        return new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE)
            .setCompareFunction(CompareFunction.ABS_DIFF)
            .setMeasure(measure);
    }
}

