/*
 * Decompiled with CFR 0.152.
 */
package sklearn.neighbors;

import java.util.ArrayList;
import java.util.List;
import javax.xml.parsers.DocumentBuilder;
import org.dmg.pmml.CityBlock;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Euclidean;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Minkowski;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Row;
import org.dmg.pmml.nearest_neighbor.InstanceField;
import org.dmg.pmml.nearest_neighbor.InstanceFields;
import org.dmg.pmml.nearest_neighbor.KNNInput;
import org.dmg.pmml.nearest_neighbor.KNNInputs;
import org.dmg.pmml.nearest_neighbor.NearestNeighborModel;
import org.dmg.pmml.nearest_neighbor.TrainingInstances;
import org.jpmml.converter.DOMUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.sklearn.MatrixUtil;
import sklearn.Estimator;
import sklearn.neighbors.HasNeighbors;
import sklearn.neighbors.HasTrainingData;

public class KNeighborsUtil {

    private KNeighborsUtil() {
    }

    /**
     * Encodes a fitted k-nearest-neighbors estimator as a PMML {@link NearestNeighborModel}.
     *
     * <p>The training data is embedded as an inline table. Column {@code "y"} holds the target value
     * and is present only when the schema declares a target field. Columns {@code "x1".."xN"} hold
     * the feature values, one per schema feature. The model also gets one output field per neighbor
     * rank ({@code "neighbor_1".."neighbor_k"}), each declared with result feature
     * {@link ResultFeature#ENTITY_ID}.
     *
     * @param estimator the fitted estimator; it is cast to {@link HasNeighbors}, so it must also
     *     implement that interface
     * @param miningFunction the mining function to declare on the resulting model
     * @param numberOfInstances the number of training instances (rows of the fit matrix)
     * @param numberOfFeatures the number of features (columns of the fit matrix)
     * @param schema the schema providing the target field and the feature list
     * @return the encoded nearest neighbor model
     * @throws IllegalArgumentException in any of these cases:
     *     <ul>
     *       <li>the schema feature count differs from {@code numberOfFeatures};</li>
     *       <li>the training data sizes disagree with {@code numberOfInstances} /
     *           {@code numberOfFeatures};</li>
     *       <li>the weighting scheme is not {@code "uniform"};</li>
     *       <li>the metric is not {@code "minkowski"}.</li>
     *     </ul>
     */
    public static <E extends Estimator & HasTrainingData> NearestNeighborModel encodeNeighbors(E estimator, MiningFunction miningFunction, int numberOfInstances, int numberOfFeatures, Schema schema) {
        HasNeighbors hasNeighbors = (HasNeighbors)estimator;

        FieldName targetField = schema.getTargetField();
        boolean hasTarget = (targetField != null);

        List<? extends Feature> features = schema.getFeatures();
        // Every row carries exactly one value per feature column, so the schema must agree with the fit matrix.
        if (features.size() != numberOfFeatures) {
            throw new IllegalArgumentException("Expected " + numberOfFeatures + " features, got " + features.size());
        }

        List<String> keys = new ArrayList<>(1 + numberOfFeatures);
        InstanceFields instanceFields = new InstanceFields();
        KNNInputs knnInputs = new KNNInputs();

        if (hasTarget) {
            InstanceField instanceField = new InstanceField(targetField).setColumn("y");
            instanceFields.addInstanceFields(instanceField);
            keys.add(instanceField.getColumn());
        }

        for (int i = 0; i < numberOfFeatures; i++) {
            FieldName name = features.get(i).getName();

            InstanceField instanceField = new InstanceField(name).setColumn("x" + (i + 1));
            instanceFields.addInstanceFields(instanceField);
            keys.add(instanceField.getColumn());

            knnInputs.addKNNInputs(new KNNInput(name));
        }

        // Target values are only needed (and only have a column) when the schema declares a target.
        List<?> y = null;
        if (hasTarget) {
            y = estimator.getY();
            if (y.size() != numberOfInstances) {
                throw new IllegalArgumentException("Expected " + numberOfInstances + " target values, got " + y.size());
            }
        }

        List<? extends Number> fitX = estimator.getFitX();
        // fitX is consumed by MatrixUtil.getRow as a numberOfInstances x numberOfFeatures matrix;
        // long arithmetic avoids int overflow in the expected size.
        long expectedSize = (long)numberOfInstances * numberOfFeatures;
        if (fitX.size() != expectedSize) {
            throw new IllegalArgumentException("Expected " + expectedSize + " training data values, got " + fitX.size());
        }

        DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();
        InlineTable inlineTable = new InlineTable();
        for (int i = 0; i < numberOfInstances; i++) {
            // Values must line up one-to-one with keys: optional "y" first, then x1..xN.
            List<Object> values = new ArrayList<>(keys.size());
            if (hasTarget) {
                values.add(y.get(i));
            }
            values.addAll(MatrixUtil.getRow(fitX, numberOfInstances, numberOfFeatures, i));

            Row row = DOMUtil.createRow(documentBuilder, keys, values);
            inlineTable.addRows(row);
        }

        TrainingInstances trainingInstances = new TrainingInstances(instanceFields)
            .setInlineTable(inlineTable)
            .setTransformed(Boolean.TRUE);

        ComparisonMeasure comparisonMeasure = encodeComparisonMeasure(hasNeighbors.getMetric(), hasNeighbors.getP());

        // Only uniform weighting is encoded; distance-based weighting is rejected.
        String weights = hasNeighbors.getWeights();
        if (!"uniform".equals(weights)) {
            throw new IllegalArgumentException("Unsupported weights: " + weights);
        }

        int numberOfNeighbors = hasNeighbors.getNumberOfNeighbors();
        Output output = encodeOutput(numberOfNeighbors);

        return new NearestNeighborModel(miningFunction, numberOfNeighbors, ModelUtil.createMiningSchema(schema), trainingInstances, comparisonMeasure, knnInputs)
            .setOutput(output);
    }

    /**
     * Builds one categorical string output field per neighbor rank.
     *
     * <p>Field {@code "neighbor_<rank>"} is declared with result feature
     * {@link ResultFeature#ENTITY_ID} and rank {@code <rank>}, for ranks {@code 1..numberOfNeighbors}.
     */
    private static Output encodeOutput(int numberOfNeighbors) {
        List<OutputField> outputFields = new ArrayList<>(numberOfNeighbors);
        for (int rank = 1; rank <= numberOfNeighbors; rank++) {
            OutputField outputField = new OutputField(FieldName.create("neighbor_" + rank), DataType.STRING)
                .setOpType(OpType.CATEGORICAL)
                .setResultFeature(ResultFeature.ENTITY_ID)
                .setRank(rank);
            outputFields.add(outputField);
        }
        return new Output(outputFields);
    }

    /**
     * Translates a scikit-learn distance metric into a PMML distance {@link ComparisonMeasure}.
     *
     * <p>Only {@code "minkowski"} is supported. {@code p == 1} maps to {@link CityBlock},
     * {@code p == 2} maps to {@link Euclidean}, and any other {@code p} maps to a generic
     * {@link Minkowski} measure.
     *
     * @throws IllegalArgumentException if {@code metric} is null or not {@code "minkowski"}
     */
    private static ComparisonMeasure encodeComparisonMeasure(String metric, int p) {
        if (!"minkowski".equals(metric)) {
            throw new IllegalArgumentException("Unsupported metric: " + metric);
        }

        Measure measure;
        switch (p) {
            case 1:
                measure = new CityBlock();
                break;
            case 2:
                measure = new Euclidean();
                break;
            default:
                measure = new Minkowski((double)p);
                break;
        }

        return new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE)
            .setCompareFunction(CompareFunction.ABS_DIFF)
            .setMeasure(measure);
    }
}

