/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.nearest_neighbor;

import com.google.common.base.Function;
import com.google.common.cache.Cache;
import com.google.common.collect.Collections2;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.ImmutableTable;
import com.google.common.collect.Iterables;
import com.google.common.collect.LinkedHashMultiset;
import com.google.common.collect.Ordering;
import com.google.common.collect.Table;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Distance;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLAttributes;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Similarity;
import org.dmg.pmml.nearest_neighbor.InstanceField;
import org.dmg.pmml.nearest_neighbor.InstanceFields;
import org.dmg.pmml.nearest_neighbor.KNNInput;
import org.dmg.pmml.nearest_neighbor.KNNInputs;
import org.dmg.pmml.nearest_neighbor.NearestNeighborModel;
import org.dmg.pmml.nearest_neighbor.PMMLElements;
import org.dmg.pmml.nearest_neighbor.TrainingInstances;
import org.jpmml.evaluator.AffinityDistribution;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.InlineTableUtil;
import org.jpmml.evaluator.InputFieldUtil;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.InvalidElementException;
import org.jpmml.evaluator.InvisibleFieldException;
import org.jpmml.evaluator.MeasureUtil;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.MissingValueException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TypeInfo;
import org.jpmml.evaluator.TypeInfos;
import org.jpmml.evaluator.TypeUtil;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.UnsupportedElementException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueAggregator;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.VoteAggregator;
import org.jpmml.model.visitors.ActiveFieldFinder;

/**
 * Evaluator for PMML {@code NearestNeighborModel} (k-nearest neighbors) models.
 *
 * <p>The training instances embedded in the model ({@code TrainingInstances}) are parsed into a
 * table keyed by row key and field name. That table, plus two per-row projections of it onto the
 * {@code KNNInputs} fields (the field values used by distance measures and the bit sets used by
 * similarity measures), are obtained through static caches keyed by the model object and then
 * memoized in instance fields of this evaluator.
 *
 * <p>Regression, classification and mixed mining functions share one code path: every training
 * instance is scored against the input record, the {@code numberOfNeighbors} best matches are
 * selected, and their target values are aggregated per target field. Clustering returns the
 * per-instance affinities without a predicted value.
 */
public class NearestNeighborModelEvaluator
extends ModelEvaluator<NearestNeighborModel> {
    // Per-evaluator memoized copies of the shared cache entries below; filled on first access
    private Table<Integer, FieldName, FieldValue> trainingInstances = null;
    private Map<Integer, BitSet> instanceFlags = null;
    private Map<Integer, List<FieldValue>> instanceValues = null;
    // Shared across evaluator instances, keyed by the model object
    private static final Cache<NearestNeighborModel, Table<Integer, FieldName, FieldValue>> trainingInstanceCache = CacheUtil.buildCache();
    private static final Cache<NearestNeighborModel, Map<Integer, BitSet>> instanceFlagCache = CacheUtil.buildCache();
    private static final Cache<NearestNeighborModel, Map<Integer, List<FieldValue>>> instanceValueCache = CacheUtil.buildCache();

    /**
     * No-arg constructor. NOTE(review): presumably retained for serialization/reflection
     * frameworks; it leaves the evaluator without a model - confirm before removing.
     */
    private NearestNeighborModelEvaluator() {
    }

    /**
     * Creates an evaluator for the {@code NearestNeighborModel} located in the given document
     * by {@link PMMLUtil#findModel}.
     *
     * @param pmml the PMML document containing the model
     */
    public NearestNeighborModelEvaluator(PMML pmml) {
        this(pmml, PMMLUtil.findModel(pmml, NearestNeighborModel.class));
    }

    /**
     * Creates an evaluator for the given model and eagerly validates its mandatory structure.
     *
     * @param pmml the PMML document containing the model
     * @param nearestNeighborModel the model to evaluate
     * @throws MissingElementException if {@code ComparisonMeasure}, {@code TrainingInstances},
     *     {@code InstanceFields} (or its children), or {@code KNNInputs} (or its children) is absent
     */
    public NearestNeighborModelEvaluator(PMML pmml, NearestNeighborModel nearestNeighborModel) {
        super(pmml, nearestNeighborModel);
        ComparisonMeasure comparisonMeasure = nearestNeighborModel.getComparisonMeasure();
        if (comparisonMeasure == null) {
            throw new MissingElementException(nearestNeighborModel, PMMLElements.NEARESTNEIGHBORMODEL_COMPARISONMEASURE);
        }
        TrainingInstances trainingInstances = nearestNeighborModel.getTrainingInstances();
        if (trainingInstances == null) {
            throw new MissingElementException(nearestNeighborModel, PMMLElements.NEARESTNEIGHBORMODEL_TRAININGINSTANCES);
        }
        InstanceFields instanceFields = trainingInstances.getInstanceFields();
        if (instanceFields == null) {
            throw new MissingElementException(trainingInstances, PMMLElements.TRAININGINSTANCES_INSTANCEFIELDS);
        }
        if (!instanceFields.hasInstanceFields()) {
            throw new MissingElementException(instanceFields, PMMLElements.INSTANCEFIELDS_INSTANCEFIELDS);
        }
        KNNInputs knnInputs = nearestNeighborModel.getKNNInputs();
        if (knnInputs == null) {
            throw new MissingElementException(nearestNeighborModel, PMMLElements.NEARESTNEIGHBORMODEL_KNNINPUTS);
        }
        if (!knnInputs.hasKNNInputs()) {
            throw new MissingElementException(knnInputs, PMMLElements.KNNINPUTS_KNNINPUTS);
        }
    }

    @Override
    public String getSummary() {
        return "k-Nearest neighbors model";
    }

    /**
     * Returns {@code null} for regression, classification and mixed models (which may have several
     * target fields); otherwise defers to the superclass.
     */
    @Override
    public DataField getDefaultDataField() {
        MiningFunction miningFunction = this.getMiningFunction();
        switch (miningFunction) {
            case REGRESSION:
            case CLASSIFICATION:
            case MIXED:
                return null;
            default:
                break;
        }
        return super.getDefaultDataField();
    }

    /** Regression is handled by the shared mixed-function code path. */
    @Override
    protected <V extends Number> Map<FieldName, AffinityDistribution<V>> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext context) {
        return this.evaluateMixed(valueFactory, context);
    }

    /** Classification is handled by the shared mixed-function code path. */
    @Override
    protected <V extends Number> Map<FieldName, AffinityDistribution<V>> evaluateClassification(ValueFactory<V> valueFactory, EvaluationContext context) {
        return this.evaluateMixed(valueFactory, context);
    }

    /**
     * Scores every training instance against the input record, keeps the {@code numberOfNeighbors}
     * best matches and aggregates their values for each target field.
     *
     * <p>Instances are labelled by the value of {@code instanceIdVariable} when the model declares
     * one, otherwise by their row key rendered as a string.
     *
     * @return one {@link AffinityDistribution} per target field, in target-field order; each holds
     *     the aggregated prediction plus the affinity of every training instance
     * @throws MissingAttributeException if {@code numberOfNeighbors} is not set
     * @throws InvalidElementException if a target field is neither continuous nor categorical
     */
    @Override
    protected <V extends Number> Map<FieldName, AffinityDistribution<V>> evaluateMixed(ValueFactory<V> valueFactory, EvaluationContext context) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        Table<Integer, FieldName, FieldValue> table = this.getTrainingInstances();
        List<InstanceResult<V>> instanceResults = this.evaluateInstanceRows(valueFactory, context);

        // InstanceResult ordering delegates to Classification.Type#compareValues; the reversed order
        // is intended to put the best-matching instances first (the sort is stable, so ties keep
        // row order)
        Ordering<InstanceResult<V>> ordering = Ordering.<InstanceResult<V>>natural().reverse();
        List<InstanceResult<V>> nearestInstanceResults = ordering.sortedCopy(instanceResults);

        Integer numberOfNeighbors = nearestNeighborModel.getNumberOfNeighbors();
        if (numberOfNeighbors == null) {
            throw new MissingAttributeException(nearestNeighborModel, org.dmg.pmml.nearest_neighbor.PMMLAttributes.NEARESTNEIGHBORMODEL_NUMBEROFNEIGHBORS);
        }
        // 0 <= numberOfNeighbors <= row count was validated when the training table was parsed
        nearestInstanceResults = nearestInstanceResults.subList(0, numberOfNeighbors);

        Function<Integer, String> function = row -> row.toString();
        FieldName instanceIdVariable = nearestNeighborModel.getInstanceIdVariable();
        if (instanceIdVariable != null) {
            function = this.createIdentifierResolver(instanceIdVariable, table);
        }

        Map<FieldName, AffinityDistribution<V>> results = new LinkedHashMap<>();
        for (TargetField targetField : this.getTargetFields()) {
            FieldName name = targetField.getFieldName();
            Object value;
            OpType opType = targetField.getOpType();
            switch (opType) {
                case CONTINUOUS:
                    value = this.calculateContinuousTarget(valueFactory, name, nearestInstanceResults, table);
                    break;
                case CATEGORICAL:
                    value = this.calculateCategoricalTarget(valueFactory, name, nearestInstanceResults, table);
                    break;
                default:
                    throw new InvalidElementException(nearestNeighborModel);
            }
            value = TypeUtil.parseOrCast(targetField.getDataType(), value);
            // The affinities cover all training instances, not only the selected neighbors
            AffinityDistribution<V> result = this.createAffinityDistribution(instanceResults, function, value);
            results.put(name, result);
        }
        return results;
    }

    /**
     * Scores every training instance against the input record and returns the affinities only,
     * without a predicted value, under the model's target name.
     *
     * @throws MissingAttributeException if {@code instanceIdVariable} is not set
     */
    @Override
    protected <V extends Number> Map<FieldName, AffinityDistribution<V>> evaluateClustering(ValueFactory<V> valueFactory, EvaluationContext context) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        Table<Integer, FieldName, FieldValue> table = this.getTrainingInstances();
        List<InstanceResult<V>> instanceResults = this.evaluateInstanceRows(valueFactory, context);
        FieldName instanceIdVariable = nearestNeighborModel.getInstanceIdVariable();
        if (instanceIdVariable == null) {
            throw new MissingAttributeException(nearestNeighborModel, org.dmg.pmml.nearest_neighbor.PMMLAttributes.NEARESTNEIGHBORMODEL_INSTANCEIDVARIABLE);
        }
        Function<Integer, String> function = this.createIdentifierResolver(instanceIdVariable, table);
        AffinityDistribution<V> result = this.createAffinityDistribution(instanceResults, function, null);
        return Collections.singletonMap(this.getTargetName(), result);
    }

    /**
     * Evaluates the {@code KNNInput} fields for the current record and scores every training
     * instance with the model's similarity or distance measure.
     *
     * @throws MissingAttributeException if a {@code KNNInput} has no {@code field}
     * @throws UnsupportedElementException if the measure is neither a similarity nor a distance
     */
    private <V extends Number> List<InstanceResult<V>> evaluateInstanceRows(ValueFactory<V> valueFactory, EvaluationContext context) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        ComparisonMeasure comparisonMeasure = nearestNeighborModel.getComparisonMeasure();
        KNNInputs knnInputs = nearestNeighborModel.getKNNInputs();
        List<FieldValue> values = new ArrayList<>();
        for (KNNInput knnInput : knnInputs) {
            FieldName name = knnInput.getField();
            if (name == null) {
                throw new MissingAttributeException(knnInput, org.dmg.pmml.nearest_neighbor.PMMLAttributes.KNNINPUT_FIELD);
            }
            values.add(context.evaluate(name));
        }
        Measure measure = MeasureUtil.ensureMeasure(comparisonMeasure);
        if (measure instanceof Similarity) {
            return this.evaluateSimilarity(valueFactory, comparisonMeasure, knnInputs.getKNNInputs(), values);
        }
        if (measure instanceof Distance) {
            return this.evaluateDistance(valueFactory, comparisonMeasure, knnInputs.getKNNInputs(), values);
        }
        throw new UnsupportedElementException(measure);
    }

    /** Scores each training instance by comparing its bit set with the bit set of {@code values}. */
    private <V extends Number> List<InstanceResult<V>> evaluateSimilarity(ValueFactory<V> valueFactory, ComparisonMeasure comparisonMeasure, List<KNNInput> knnInputs, List<FieldValue> values) {
        BitSet flags = MeasureUtil.toBitSet(values);
        Map<Integer, BitSet> flagMap = this.getInstanceFlags();
        List<InstanceResult<V>> result = new ArrayList<>(flagMap.size());
        for (Map.Entry<Integer, BitSet> entry : flagMap.entrySet()) {
            Value<V> similarity = MeasureUtil.evaluateSimilarity(valueFactory, comparisonMeasure, knnInputs, flags, entry.getValue());
            result.add(new InstanceResult.Similarity<V>(entry.getKey(), similarity));
        }
        return result;
    }

    /**
     * Scores each training instance by its distance to {@code values}. The adjustment computed
     * from {@code values} by {@link MeasureUtil#calculateAdjustment} is shared by all rows.
     */
    private <V extends Number> List<InstanceResult<V>> evaluateDistance(ValueFactory<V> valueFactory, ComparisonMeasure comparisonMeasure, List<KNNInput> knnInputs, List<FieldValue> values) {
        Map<Integer, List<FieldValue>> valueMap = this.getInstanceValues();
        List<InstanceResult<V>> result = new ArrayList<>(valueMap.size());
        Value<V> adjustment = MeasureUtil.calculateAdjustment(valueFactory, values);
        for (Map.Entry<Integer, List<FieldValue>> entry : valueMap.entrySet()) {
            Value<V> distance = MeasureUtil.evaluateDistance(valueFactory, comparisonMeasure, knnInputs, values, entry.getValue(), adjustment);
            result.add(new InstanceResult.Distance<V>(entry.getKey(), distance));
        }
        return result;
    }

    /**
     * Aggregates the numeric target values of the given neighbors using the model's
     * {@code continuousScoringMethod} (average, weighted average or median).
     *
     * @throws MissingValueException if a neighbor has no value for {@code name}
     * @throws UnsupportedAttributeException for any other scoring method
     */
    private <V extends Number> V calculateContinuousTarget(ValueFactory<V> valueFactory, FieldName name, List<InstanceResult<V>> instanceResults, Table<Integer, FieldName, FieldValue> table) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        Number threshold = nearestNeighborModel.getThreshold();
        NearestNeighborModel.ContinuousScoringMethod continuousScoringMethod = nearestNeighborModel.getContinuousScoringMethod();
        ValueAggregator<V> aggregator;
        switch (continuousScoringMethod) {
            case AVERAGE:
                aggregator = new ValueAggregator.UnivariateStatistic<V>(valueFactory);
                break;
            case WEIGHTED_AVERAGE:
                aggregator = new ValueAggregator.WeightedUnivariateStatistic<V>(valueFactory);
                break;
            case MEDIAN:
                aggregator = new ValueAggregator.Median<V>(valueFactory, instanceResults.size());
                break;
            default:
                throw new UnsupportedAttributeException(nearestNeighborModel, continuousScoringMethod);
        }
        for (InstanceResult<V> instanceResult : instanceResults) {
            FieldValue value = table.get(instanceResult.getId(), name);
            if (FieldValueUtil.isMissing(value)) {
                throw new MissingValueException(name);
            }
            Number targetValue = value.asNumber();
            switch (continuousScoringMethod) {
                case AVERAGE:
                case MEDIAN:
                    aggregator.add(targetValue);
                    break;
                case WEIGHTED_AVERAGE:
                    aggregator.add(targetValue, calculateWeight(instanceResult, threshold));
                    break;
                default:
                    throw new UnsupportedAttributeException(nearestNeighborModel, continuousScoringMethod);
            }
        }
        switch (continuousScoringMethod) {
            case AVERAGE:
                return aggregator.average().getValue();
            case WEIGHTED_AVERAGE:
                return aggregator.weightedAverage().getValue();
            case MEDIAN:
                return aggregator.median().getValue();
            default:
                throw new UnsupportedAttributeException(nearestNeighborModel, continuousScoringMethod);
        }
    }

    /**
     * Picks the winning categorical target value among the given neighbors using the model's
     * {@code categoricalScoringMethod} (majority vote or weighted majority vote).
     *
     * <p>A tie between equally voted values is broken by how often each candidate occurs in the
     * whole training column for {@code name}. If that still ties, the smallest candidate by natural
     * ordering wins.
     *
     * @throws MissingValueException if a neighbor has no value for {@code name}
     * @throws UnsupportedAttributeException for any other scoring method
     */
    private <V extends Number> Object calculateCategoricalTarget(ValueFactory<V> valueFactory, FieldName name, List<InstanceResult<V>> instanceResults, Table<Integer, FieldName, FieldValue> table) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        Number threshold = nearestNeighborModel.getThreshold();
        NearestNeighborModel.CategoricalScoringMethod categoricalScoringMethod = nearestNeighborModel.getCategoricalScoringMethod();
        VoteAggregator<Object, V> aggregator = new VoteAggregator<Object, V>(valueFactory);
        for (InstanceResult<V> instanceResult : instanceResults) {
            FieldValue value = table.get(instanceResult.getId(), name);
            if (FieldValueUtil.isMissing(value)) {
                throw new MissingValueException(name);
            }
            Object targetValue = value.getValue();
            switch (categoricalScoringMethod) {
                case MAJORITY_VOTE:
                    aggregator.add(targetValue);
                    break;
                case WEIGHTED_MAJORITY_VOTE:
                    aggregator.add(targetValue, calculateWeight(instanceResult, threshold));
                    break;
                default:
                    throw new UnsupportedAttributeException(nearestNeighborModel, categoricalScoringMethod);
            }
        }
        Set<Object> winners = aggregator.getWinners();
        if (winners.size() > 1) {
            // Re-vote among the tied candidates, weighting each by its frequency in the full column
            LinkedHashMultiset<Object> frequencies = LinkedHashMultiset.create();
            Map<Integer, FieldValue> column = table.column(name);
            frequencies.addAll(Collections2.transform(column.values(), FieldValue::getValue));
            aggregator.clear();
            for (Object winner : winners) {
                aggregator.add(winner, frequencies.count(winner));
            }
            winners = aggregator.getWinners();
            if (winners.size() > 1) {
                // Still tied: candidates are assumed to be mutually Comparable
                @SuppressWarnings({"rawtypes", "unchecked"})
                Object smallest = Collections.min((Collection) winners);
                return smallest;
            }
        }
        return Iterables.getFirst(winners, null);
    }

    /**
     * Returns the voting weight of a neighbor for the weighted scoring methods, as computed by
     * {@link InstanceResult.Distance#getWeight}. Fails (via {@link TypeUtil#cast}) when the
     * instance was scored with a similarity measure rather than a distance measure.
     */
    @SuppressWarnings("rawtypes")
    private static Number calculateWeight(InstanceResult<?> instanceResult, Number threshold) {
        InstanceResult.Distance distance = TypeUtil.cast(InstanceResult.Distance.class, instanceResult);
        Value<?> weight = distance.getWeight(threshold);
        return weight.getValue();
    }

    /**
     * Returns a function that maps a training row key to that row's value of {@code name}, as a string.
     * The returned function throws {@link MissingValueException} if the row has no such value.
     */
    private Function<Integer, String> createIdentifierResolver(final FieldName name, final Table<Integer, FieldName, FieldValue> table) {
        return row -> {
            FieldValue value = table.get(row, name);
            if (FieldValueUtil.isMissing(value)) {
                throw new MissingValueException(name);
            }
            return value.asString();
        };
    }

    /**
     * Wraps the per-instance scores into an {@link AffinityDistribution}. The distribution is typed as
     * similarity or distance according to the model's comparison measure.
     *
     * @param instanceResults scored training instances; their order is the map's insertion order
     * @param function maps a row key to the instance identifier used as the map key
     * @param result the predicted value, or {@code null} when there is none (clustering)
     * @throws UnsupportedElementException if the measure is neither a similarity nor a distance
     */
    private <V extends Number> AffinityDistribution<V> createAffinityDistribution(List<InstanceResult<V>> instanceResults, Function<Integer, String> function, Object result) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        ComparisonMeasure comparisonMeasure = nearestNeighborModel.getComparisonMeasure();
        // NOTE(review): the type arguments and the raw AffinityDistribution constructions below are
        // as recovered by the decompiler - confirm against the ValueMap/AffinityDistribution declarations.
        // The 2x capacity presumably avoids rehashing at the default load factor.
        ValueMap<Object, Value<V>> values = new ValueMap<Object, Value<V>>(2 * instanceResults.size());
        for (InstanceResult<V> instanceResult : instanceResults) {
            values.put(function.apply(instanceResult.getId()), instanceResult.getValue());
        }
        Measure measure = MeasureUtil.ensureMeasure(comparisonMeasure);
        if (measure instanceof Similarity) {
            return new AffinityDistribution(Classification.Type.SIMILARITY, values, result);
        }
        if (measure instanceof Distance) {
            return new AffinityDistribution(Classification.Type.DISTANCE, values, result);
        }
        throw new UnsupportedElementException(measure);
    }

    /** Returns the parsed training-instance table, loading it through the shared cache on first use. */
    private Table<Integer, FieldName, FieldValue> getTrainingInstances() {
        if (this.trainingInstances == null) {
            this.trainingInstances = this.getValue(trainingInstanceCache, NearestNeighborModelEvaluator.createTrainingInstanceLoader(this));
        }
        return this.trainingInstances;
    }

    /** Creates a cache loader that parses the training instances into an immutable table. */
    private static Callable<Table<Integer, FieldName, FieldValue>> createTrainingInstanceLoader(final NearestNeighborModelEvaluator modelEvaluator) {
        return () -> ImmutableTable.copyOf(NearestNeighborModelEvaluator.parseTrainingInstances(modelEvaluator));
    }

    /**
     * Parses the model's {@code TrainingInstances} into a table of row key -> field name -> value.
     *
     * <p>Only the instance-id column, the model's active fields and its target fields are loaded.
     * {@code KNNInput} fields that resolve to derived fields and are not stored in the table are
     * then computed per row from that row's loaded values.
     *
     * @throws MissingAttributeException if an {@code InstanceField} has no {@code field}, or
     *     {@code numberOfNeighbors} is not set
     * @throws MissingFieldException if an instance field cannot be resolved
     * @throws InvisibleFieldException if a data field (or an inherited derived field) has no mining field
     * @throws InvalidAttributeException if an instance field is neither a data nor a derived field, or
     *     if {@code numberOfNeighbors} is negative or exceeds the number of training instances
     */
    private static Table<Integer, FieldName, FieldValue> parseTrainingInstances(NearestNeighborModelEvaluator modelEvaluator) {
        NearestNeighborModel nearestNeighborModel = modelEvaluator.getModel();
        FieldName instanceIdVariable = nearestNeighborModel.getInstanceIdVariable();

        // Columns worth loading: everything the model reads plus every target
        Set<FieldName> names = new HashSet<>();
        names.addAll(ActiveFieldFinder.getFieldNames(new PMMLObject[]{nearestNeighborModel}));
        for (TargetField targetField : modelEvaluator.getTargetFields()) {
            names.add(targetField.getFieldName());
        }

        TrainingInstances trainingInstances = nearestNeighborModel.getTrainingInstances();
        List<FieldLoader> fieldLoaders = new ArrayList<>();
        InstanceFields instanceFields = trainingInstances.getInstanceFields();
        for (InstanceField instanceField : instanceFields) {
            FieldName name = instanceField.getField();
            if (name == null) {
                throw new MissingAttributeException(instanceField, org.dmg.pmml.nearest_neighbor.PMMLAttributes.INSTANCEFIELD_FIELD);
            }
            String column = instanceField.getColumn();
            if (instanceIdVariable != null && instanceIdVariable.equals(name)) {
                fieldLoaders.add(new IdentifierLoader(name, column));
                continue;
            }
            if (!names.contains(name)) {
                continue;
            }
            Field<?> field = modelEvaluator.resolveField(name);
            if (field == null) {
                throw new MissingFieldException(name, instanceField);
            }
            if (field instanceof DataField) {
                DataField dataField = (DataField) field;
                MiningField miningField = modelEvaluator.getMiningField(name);
                if (miningField == null) {
                    throw new InvisibleFieldException(name, instanceField);
                }
                fieldLoaders.add(new DataFieldLoader(name, column, dataField, miningField));
                continue;
            }
            if (field instanceof DerivedField) {
                DerivedField derivedField = (DerivedField) field;
                // Not declared by this model itself, so it must be exposed through a mining field
                boolean inherited = modelEvaluator.getDerivedField(name) == null && modelEvaluator.getLocalDerivedField(name) == null;
                MiningField miningField = modelEvaluator.getMiningField(name);
                if (miningField == null && inherited) {
                    throw new InvisibleFieldException(name, instanceField);
                }
                fieldLoaders.add(new DerivedFieldLoader(name, column, derivedField, miningField));
                continue;
            }
            throw new InvalidAttributeException(instanceField, org.dmg.pmml.nearest_neighbor.PMMLAttributes.INSTANCEFIELD_FIELD, name);
        }

        Table<Integer, FieldName, FieldValue> result = HashBasedTable.create();
        InlineTable inlineTable = InlineTableUtil.getInlineTable(trainingInstances);
        if (inlineTable != null) {
            Table<Integer, String, Object> table = InlineTableUtil.getContent(inlineTable);
            for (Integer rowKey : table.rowKeySet()) {
                Map<String, Object> rowValues = table.row(rowKey);
                for (FieldLoader fieldLoader : fieldLoaders) {
                    result.put(rowKey, fieldLoader.getName(), fieldLoader.load(rowValues));
                }
            }
        }

        // Compute derived KNN inputs that were not supplied as training columns
        KNNInputs knnInputs = nearestNeighborModel.getKNNInputs();
        for (KNNInput knnInput : knnInputs) {
            FieldName name = knnInput.getField();
            Field<?> field = modelEvaluator.resolveField(name);
            if (!(field instanceof DerivedField)) {
                continue;
            }
            DerivedField derivedField = (DerivedField) field;
            for (Integer rowKey : result.rowKeySet()) {
                Map<FieldName, FieldValue> rowValues = result.row(rowKey);
                if (rowValues.containsKey(name)) {
                    continue;
                }
                ModelEvaluationContext context = modelEvaluator.createEvaluationContext();
                context.declareAll(rowValues);
                FieldValue value = ExpressionUtil.evaluate(derivedField, context);
                result.put(rowKey, name, value);
            }
        }

        Integer numberOfNeighbors = nearestNeighborModel.getNumberOfNeighbors();
        if (numberOfNeighbors == null) {
            throw new MissingAttributeException(nearestNeighborModel, org.dmg.pmml.nearest_neighbor.PMMLAttributes.NEARESTNEIGHBORMODEL_NUMBEROFNEIGHBORS);
        }
        // Table#size() counts cells (row x column mappings), not rows; the neighbor count must be
        // bounded by the number of training instances, or subList(0, k) fails at evaluation time
        int instanceCount = result.rowKeySet().size();
        if (numberOfNeighbors < 0 || numberOfNeighbors > instanceCount) {
            throw new InvalidAttributeException(nearestNeighborModel, org.dmg.pmml.nearest_neighbor.PMMLAttributes.NEARESTNEIGHBORMODEL_NUMBEROFNEIGHBORS, numberOfNeighbors);
        }
        return result;
    }

    /** Returns the per-row KNN-input bit sets, loading them through the shared cache on first use. */
    private Map<Integer, BitSet> getInstanceFlags() {
        if (this.instanceFlags == null) {
            this.instanceFlags = this.getValue(instanceFlagCache, NearestNeighborModelEvaluator.createInstanceFlagLoader(this));
        }
        return this.instanceFlags;
    }

    /** Creates a cache loader that builds an immutable map of per-row bit sets. */
    private static Callable<Map<Integer, BitSet>> createInstanceFlagLoader(final NearestNeighborModelEvaluator modelEvaluator) {
        return () -> ImmutableMap.copyOf(NearestNeighborModelEvaluator.loadInstanceFlags(modelEvaluator));
    }

    /**
     * Converts each row's KNN-input values into a bit set via {@link MeasureUtil#toBitSet}. The
     * result keeps the row order of the instance-value map.
     */
    private static Map<Integer, BitSet> loadInstanceFlags(NearestNeighborModelEvaluator modelEvaluator) {
        Map<Integer, List<FieldValue>> instanceValues = modelEvaluator.getValue(instanceValueCache, NearestNeighborModelEvaluator.createInstanceValueLoader(modelEvaluator));
        Map<Integer, BitSet> instanceFlags = new LinkedHashMap<>();
        for (Map.Entry<Integer, List<FieldValue>> entry : instanceValues.entrySet()) {
            instanceFlags.put(entry.getKey(), MeasureUtil.toBitSet(entry.getValue()));
        }
        return instanceFlags;
    }

    /** Returns the per-row KNN-input values, loading them through the shared cache on first use. */
    private Map<Integer, List<FieldValue>> getInstanceValues() {
        if (this.instanceValues == null) {
            this.instanceValues = this.getValue(instanceValueCache, NearestNeighborModelEvaluator.createInstanceValueLoader(this));
        }
        return this.instanceValues;
    }

    /**
     * Creates a cache loader that builds an immutable map of immutable per-row value lists. The
     * ascending row order produced by {@link #loadInstanceValues} is preserved.
     */
    private static Callable<Map<Integer, List<FieldValue>>> createInstanceValueLoader(final NearestNeighborModelEvaluator modelEvaluator) {
        return () -> {
            Map<Integer, List<FieldValue>> instanceValues = NearestNeighborModelEvaluator.loadInstanceValues(modelEvaluator);
            ImmutableMap.Builder<Integer, List<FieldValue>> builder = ImmutableMap.builder();
            for (Map.Entry<Integer, List<FieldValue>> entry : instanceValues.entrySet()) {
                builder.put(entry.getKey(), ImmutableList.copyOf(entry.getValue()));
            }
            return builder.build();
        };
    }

    /**
     * Projects each training row onto the {@code KNNInput} fields, in {@code KNNInputs} order.
     * Rows are visited in ascending row-key order; a field absent from a row yields {@code null}.
     */
    private static Map<Integer, List<FieldValue>> loadInstanceValues(NearestNeighborModelEvaluator modelEvaluator) {
        NearestNeighborModel nearestNeighborModel = modelEvaluator.getModel();
        Table<Integer, FieldName, FieldValue> table = modelEvaluator.getValue(trainingInstanceCache, NearestNeighborModelEvaluator.createTrainingInstanceLoader(modelEvaluator));
        KNNInputs knnInputs = nearestNeighborModel.getKNNInputs();
        Map<Integer, List<FieldValue>> result = new LinkedHashMap<>();
        for (Integer rowKey : ImmutableSortedSet.copyOf(table.rowKeySet())) {
            Map<FieldName, FieldValue> rowValues = table.row(rowKey);
            List<FieldValue> values = new ArrayList<>();
            for (KNNInput knnInput : knnInputs) {
                values.add(rowValues.get(knnInput.getField()));
            }
            result.put(rowKey, values);
        }
        return result;
    }

    /**
     * Score of one training instance (identified by its row key) against the current input record.
     * Subclasses define the comparison through {@link Classification.Type#compareValues}; comparing
     * a distance result with a similarity result throws {@link ClassCastException}.
     */
    private static abstract class InstanceResult<V extends Number>
    implements Comparable<InstanceResult<V>> {
        private Integer id = null;
        private Value<V> value = null;

        private InstanceResult(Integer id, Value<V> value) {
            this.setId(id);
            this.setValue(value);
        }

        public Integer getId() {
            return this.id;
        }

        private void setId(Integer id) {
            this.id = id;
        }

        public Value<V> getValue() {
            return this.value;
        }

        private void setValue(Value<V> value) {
            this.value = value;
        }

        /** Result scored with a distance measure. */
        private static class Distance<V extends Number>
        extends InstanceResult<V> {
            private Distance(Integer id, Value<V> value) {
                super(id, value);
            }

            @Override
            public int compareTo(InstanceResult<V> that) {
                if (that instanceof Distance) {
                    return Classification.Type.DISTANCE.compareValues(this.getValue(), that.getValue());
                }
                throw new ClassCastException();
            }

            /**
             * Returns the weight {@code 1 / (distance + threshold)}. The stored distance is copied first,
             * because {@code add} and {@code reciprocal} are applied to the value in place.
             */
            public Value<V> getWeight(Number threshold) {
                Value<V> weight = this.getValue().copy();
                weight.add(threshold).reciprocal();
                return weight;
            }
        }

        /** Result scored with a similarity measure. */
        private static class Similarity<V extends Number>
        extends InstanceResult<V> {
            private Similarity(Integer id, Value<V> value) {
                super(id, value);
            }

            @Override
            public int compareTo(InstanceResult<V> that) {
                if (that instanceof Similarity) {
                    return Classification.Type.SIMILARITY.compareValues(this.getValue(), that.getValue());
                }
                throw new ClassCastException();
            }
        }
    }

    /**
     * Loads a training column that maps to a derived field. When a mining field is available, the
     * standard input preparation is used; otherwise the value is converted using the derived field's
     * own data type, op type and valid values.
     */
    private static class DerivedFieldLoader
    extends FieldLoader {
        private DerivedField derivedField = null;
        private MiningField miningField = null;

        private DerivedFieldLoader(FieldName name, String column, DerivedField derivedField, MiningField miningField) {
            super(name, column);
            this.setDerivedField(derivedField);
            this.setMiningField(miningField);
        }

        @Override
        public FieldValue prepare(Object value) {
            final DerivedField derivedField = this.getDerivedField();
            MiningField miningField = this.getMiningField();
            if (miningField != null) {
                return InputFieldUtil.prepareInputValue(derivedField, miningField, value);
            }
            TypeInfo typeInfo = new TypeInfo(){

                @Override
                public DataType getDataType() {
                    DataType dataType = derivedField.getDataType();
                    if (dataType == null) {
                        throw new MissingAttributeException(derivedField, PMMLAttributes.DERIVEDFIELD_DATATYPE);
                    }
                    return dataType;
                }

                @Override
                public OpType getOpType() {
                    OpType opType = derivedField.getOpType();
                    if (opType == null) {
                        throw new MissingAttributeException(derivedField, PMMLAttributes.DERIVEDFIELD_OPTYPE);
                    }
                    return opType;
                }

                @Override
                public List<?> getOrdering() {
                    return FieldUtil.getValidValues(derivedField);
                }
            };
            return FieldValueUtil.create(typeInfo, value);
        }

        public DerivedField getDerivedField() {
            return this.derivedField;
        }

        private void setDerivedField(DerivedField derivedField) {
            this.derivedField = derivedField;
        }

        public MiningField getMiningField() {
            return this.miningField;
        }

        private void setMiningField(MiningField miningField) {
            this.miningField = miningField;
        }
    }

    /** Loads a training column that maps to a data field, applying the mining field's input preparation. */
    private static class DataFieldLoader
    extends FieldLoader {
        private DataField dataField = null;
        private MiningField miningField = null;

        private DataFieldLoader(FieldName name, String column, DataField dataField, MiningField miningField) {
            super(name, column);
            this.setDataField(dataField);
            this.setMiningField(miningField);
        }

        @Override
        public FieldValue prepare(Object value) {
            return InputFieldUtil.prepareInputValue(this.getDataField(), this.getMiningField(), value);
        }

        public DataField getDataField() {
            return this.dataField;
        }

        private void setDataField(DataField dataField) {
            this.dataField = dataField;
        }

        public MiningField getMiningField() {
            return this.miningField;
        }

        private void setMiningField(MiningField miningField) {
            this.miningField = miningField;
        }
    }

    /** Loads the instance-id column as a categorical string value. */
    private static class IdentifierLoader
    extends FieldLoader {
        private IdentifierLoader(FieldName name, String column) {
            super(name, column);
        }

        @Override
        public FieldValue prepare(Object value) {
            return FieldValueUtil.create(TypeInfos.CATEGORICAL_STRING, value);
        }
    }

    /**
     * Reads one column of an inline-table row and converts it into the {@link FieldValue} stored
     * under {@link #getName()} in the training-instance table.
     */
    private static abstract class FieldLoader {
        private FieldName name = null;
        private String column = null;

        private FieldLoader(FieldName name, String column) {
            this.setName(name);
            this.setColumn(column);
        }

        /** Converts a raw cell value (possibly {@code null}) into a field value. */
        public abstract FieldValue prepare(Object value);

        /** Looks up this loader's column in {@code values} and prepares the cell found there. */
        public FieldValue load(Map<String, Object> values) {
            Object value = values.get(this.getColumn());
            return this.prepare(value);
        }

        public FieldName getName() {
            return this.name;
        }

        private void setName(FieldName name) {
            this.name = name;
        }

        public String getColumn() {
            return this.column;
        }

        private void setColumn(String column) {
            this.column = column;
        }
    }
}

