/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.base.Function;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.Collections2;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.LinkedHashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.google.common.collect.Table;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import org.dmg.pmml.CategoricalScoringMethodType;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.ContinuousScoringMethodType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.InstanceField;
import org.dmg.pmml.InstanceFields;
import org.dmg.pmml.KNNInput;
import org.dmg.pmml.KNNInputs;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NearestNeighborModel;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.TableLocator;
import org.dmg.pmml.TrainingInstances;
import org.jpmml.evaluator.ArgumentUtil;
import org.jpmml.evaluator.ClassificationMap;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.InlineTableUtil;
import org.jpmml.evaluator.InstanceClassificationMap;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.MeasureUtil;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.RegressionAggregator;
import org.jpmml.evaluator.VoteAggregator;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * Evaluator for PMML {@link NearestNeighborModel} (k-nearest neighbors) models.
 *
 * <p>Training instances are parsed from the model's {@link InlineTable} on first use and memoized
 * in static Guava caches whose keys are {@link NearestNeighborModel} objects held weakly
 * ({@code weakKeys()}).</p>
 */
public class NearestNeighborModelEvaluator
extends ModelEvaluator<NearestNeighborModel> {
    /** Parsed training instances: row key -> field name -> value. */
    private static final Cache<NearestNeighborModel, Table<Integer, FieldName, FieldValue>> trainingInstanceCache = CacheBuilder.newBuilder().weakKeys().build();
    /** Per-row bit sets built from the KNN input values; consumed by similarity measures. */
    private static final Cache<NearestNeighborModel, Map<Integer, BitSet>> instanceFlagCache = CacheBuilder.newBuilder().weakKeys().build();
    /** Per-row KNN input values, in {@code KNNInputs} order; consumed by distance measures. */
    private static final Cache<NearestNeighborModel, Map<Integer, List<FieldValue>>> instanceValueCache = CacheBuilder.newBuilder().weakKeys().build();

    /**
     * Creates an evaluator for the first {@link NearestNeighborModel} found in the given PMML document.
     *
     * @param pmml the PMML document
     */
    public NearestNeighborModelEvaluator(PMML pmml) {
        this(pmml, (NearestNeighborModel)NearestNeighborModelEvaluator.find((List)pmml.getModels(), NearestNeighborModel.class));
    }

    /**
     * Creates an evaluator for the given model.
     *
     * @param pmml the enclosing PMML document
     * @param nearestNeighborModel the model to evaluate
     */
    public NearestNeighborModelEvaluator(PMML pmml, NearestNeighborModel nearestNeighborModel) {
        super(pmml, nearestNeighborModel);
    }

    public String getSummary() {
        return "k-Nearest neighbors model";
    }

    /**
     * Always returns {@code null}: this evaluator does not expose a single target {@link DataField}
     * through this hook. NOTE(review): presumably because the model may declare several targets; confirm.
     */
    @Override
    protected DataField getDataField() {
        return null;
    }

    /**
     * Scores the model against the argument values held by {@code context}.
     *
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException if the mining function is not regression, classification,
     *         mixed or clustering
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        if (!nearestNeighborModel.isScorable()) {
            throw new InvalidResultException(nearestNeighborModel);
        }

        Map<FieldName, InstanceClassificationMap> predictions;

        MiningFunctionType miningFunction = nearestNeighborModel.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
            case CLASSIFICATION:
            case MIXED:
                predictions = this.evaluateMixed(context);
                break;
            case CLUSTERING:
                predictions = this.evaluateClustering(context);
                break;
            default:
                throw new UnsupportedFeatureException(nearestNeighborModel, miningFunction);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Predicts every target field from the {@code numberOfNeighbors} best-ranked training instances.
     *
     * <p>The returned maps list the measure of every training instance (not only the k nearest ones),
     * keyed by the instance identifier, or by the row key when no instance id variable is declared.</p>
     *
     * @throws InvalidFeatureException if {@code numberOfNeighbors} is negative or larger than the
     *         number of training instances
     * @throws MissingFieldException if a target field has no {@link DataField}
     */
    private Map<FieldName, InstanceClassificationMap> evaluateMixed(ModelEvaluationContext context) {
        NearestNeighborModel nearestNeighborModel = this.getModel();

        Table<Integer, FieldName, FieldValue> table = this.getTrainingInstances();

        List<InstanceResult> instanceResults = this.evaluateInstanceRows(context);

        int numberOfNeighbors = nearestNeighborModel.getNumberOfNeighbors();
        // subList(..) would otherwise fail with a bare IndexOutOfBoundsException
        if (numberOfNeighbors < 0 || numberOfNeighbors > instanceResults.size()) {
            throw new InvalidFeatureException(nearestNeighborModel);
        }

        // Stable sort in reverse natural order. The intent is "nearest first"; this relies on the
        // ClassificationMap.Type comparators (used by InstanceResult#compareTo) ranking closer matches higher
        Ordering<InstanceResult> ordering = Ordering.<InstanceResult>natural().reverse();
        List<InstanceResult> nearestInstanceResults = ordering.sortedCopy(instanceResults).subList(0, numberOfNeighbors);

        Function<Integer, String> function;

        String idField = nearestNeighborModel.getInstanceIdVariable();
        if (idField != null) {
            function = this.createIdentifierResolver(FieldName.create(idField), table);
        } else {
            function = new Function<Integer, String>(){

                @Override
                public String apply(Integer row) {
                    return row.toString();
                }
            };
        }

        Map<FieldName, InstanceClassificationMap> result = Maps.newLinkedHashMap();

        List<FieldName> targetFields = this.getTargetFields();
        for (FieldName targetField : targetFields) {
            DataField dataField = this.getDataField(targetField);
            if (dataField == null) {
                throw new MissingFieldException(targetField);
            }

            Object value;

            OpType opType = dataField.getOpType();
            switch (opType) {
                case CONTINUOUS:
                    value = this.calculateContinuousTarget(targetField, nearestInstanceResults, table);
                    break;
                case CATEGORICAL:
                    value = this.calculateCategoricalTarget(targetField, nearestInstanceResults, table);
                    break;
                default:
                    throw new UnsupportedFeatureException(dataField, opType);
            }

            result.put(targetField, this.createMeasureMap(value, instanceResults, function));
        }

        return result;
    }

    /**
     * Clustering mode: reports the measure of every training instance, keyed by instance identifier.
     *
     * @throws InvalidFeatureException if the model does not declare an instance id variable
     */
    private Map<FieldName, InstanceClassificationMap> evaluateClustering(ModelEvaluationContext context) {
        NearestNeighborModel nearestNeighborModel = this.getModel();

        Table<Integer, FieldName, FieldValue> table = this.getTrainingInstances();

        List<InstanceResult> instanceResults = this.evaluateInstanceRows(context);

        String idField = nearestNeighborModel.getInstanceIdVariable();
        if (idField == null) {
            throw new InvalidFeatureException(nearestNeighborModel);
        }

        Function<Integer, String> function = this.createIdentifierResolver(FieldName.create(idField), table);

        return Collections.singletonMap(this.getTargetField(), this.createMeasureMap(null, instanceResults, function));
    }

    /**
     * Evaluates the KNN inputs against {@code context} and compares them with every training instance
     * using the model's comparison measure.
     *
     * @throws UnsupportedFeatureException if the measure is neither a similarity nor a distance
     */
    private List<InstanceResult> evaluateInstanceRows(ModelEvaluationContext context) {
        NearestNeighborModel nearestNeighborModel = this.getModel();

        KNNInputs knnInputs = nearestNeighborModel.getKNNInputs();

        List<FieldValue> values = Lists.newArrayList();
        for (KNNInput knnInput : knnInputs) {
            values.add(ExpressionUtil.evaluate(knnInput.getField(), (EvaluationContext)context));
        }

        ComparisonMeasure comparisonMeasure = nearestNeighborModel.getComparisonMeasure();

        Measure measure = comparisonMeasure.getMeasure();
        if (MeasureUtil.isSimilarity(measure)) {
            return this.evaluateSimilarity(comparisonMeasure, knnInputs.getKNNInputs(), values);
        }
        if (MeasureUtil.isDistance(measure)) {
            return this.evaluateDistance(comparisonMeasure, knnInputs.getKNNInputs(), values);
        }

        throw new UnsupportedFeatureException(measure);
    }

    /** Computes a similarity result for every cached training instance, in cache iteration order. */
    private List<InstanceResult> evaluateSimilarity(ComparisonMeasure comparisonMeasure, List<KNNInput> knnInputs, List<FieldValue> values) {
        List<InstanceResult> result = Lists.newArrayList();

        BitSet flags = MeasureUtil.toBitSet(values);

        Map<Integer, BitSet> flagMap = this.getInstanceFlags();
        for (Map.Entry<Integer, BitSet> entry : flagMap.entrySet()) {
            Double similarity = MeasureUtil.evaluateSimilarity(comparisonMeasure, knnInputs, flags, entry.getValue());

            result.add(new InstanceResult.Similarity(entry.getKey(), similarity));
        }

        return result;
    }

    /** Computes a distance result for every cached training instance, in cache iteration order. */
    private List<InstanceResult> evaluateDistance(ComparisonMeasure comparisonMeasure, List<KNNInput> knnInputs, List<FieldValue> values) {
        List<InstanceResult> result = Lists.newArrayList();

        Double adjustment = MeasureUtil.calculateAdjustment(values);

        Map<Integer, List<FieldValue>> valueMap = this.getInstanceValues();
        for (Map.Entry<Integer, List<FieldValue>> entry : valueMap.entrySet()) {
            Double distance = MeasureUtil.evaluateDistance(comparisonMeasure, knnInputs, values, entry.getValue(), adjustment);

            result.add(new InstanceResult.Distance(entry.getKey(), distance));
        }

        return result;
    }

    /**
     * Aggregates the continuous target values of the given instances.
     *
     * <p>Supported methods: {@code MEDIAN}, {@code AVERAGE} (divided by the instance count) and
     * {@code WEIGHTED_AVERAGE} (weighted by {@link InstanceResult#getWeight(double)} with the model
     * threshold and divided by the sum of weights).</p>
     *
     * @throws MissingFieldException if an instance has no value for {@code name}
     * @throws UnsupportedFeatureException for any other scoring method
     */
    private Double calculateContinuousTarget(FieldName name, List<InstanceResult> instanceResults, Table<Integer, FieldName, FieldValue> table) {
        NearestNeighborModel nearestNeighborModel = this.getModel();

        RegressionAggregator aggregator = new RegressionAggregator();

        double denominator = 0.0;

        ContinuousScoringMethodType continuousScoringMethod = nearestNeighborModel.getContinuousScoringMethod();
        for (InstanceResult instanceResult : instanceResults) {
            FieldValue value = table.get(instanceResult.getId(), name);
            if (value == null) {
                throw new MissingFieldException(name);
            }

            Number number = value.asNumber();

            switch (continuousScoringMethod) {
                case MEDIAN:
                    aggregator.add(number.doubleValue());
                    break;
                case AVERAGE:
                    aggregator.add(number.doubleValue());
                    denominator += 1.0;
                    break;
                case WEIGHTED_AVERAGE:
                    double weight = instanceResult.getWeight(nearestNeighborModel.getThreshold());

                    aggregator.add(number.doubleValue() * weight);
                    denominator += weight;
                    break;
                default:
                    throw new UnsupportedFeatureException(nearestNeighborModel, continuousScoringMethod);
            }
        }

        switch (continuousScoringMethod) {
            case MEDIAN:
                return aggregator.median();
            case AVERAGE:
            case WEIGHTED_AVERAGE:
                return aggregator.average(denominator);
            default:
                throw new UnsupportedFeatureException(nearestNeighborModel, continuousScoringMethod);
        }
    }

    /**
     * Votes on the categorical target value of the given instances.
     *
     * <p>Ties are broken by the number of occurrences of each tied value in the full training column.
     * A remaining tie is broken by the smallest value in natural order.</p>
     *
     * @throws MissingFieldException if an instance has no value for {@code name}
     * @throws UnsupportedFeatureException for a scoring method other than (weighted) majority vote
     */
    private Object calculateCategoricalTarget(FieldName name, List<InstanceResult> instanceResults, Table<Integer, FieldName, FieldValue> table) {
        NearestNeighborModel nearestNeighborModel = this.getModel();

        VoteAggregator<Object> aggregator = new VoteAggregator<Object>();

        CategoricalScoringMethodType categoricalScoringMethod = nearestNeighborModel.getCategoricalScoringMethod();
        for (InstanceResult instanceResult : instanceResults) {
            FieldValue value = table.get(instanceResult.getId(), name);
            if (value == null) {
                throw new MissingFieldException(name);
            }

            Object object = value.getValue();

            switch (categoricalScoringMethod) {
                case MAJORITY_VOTE:
                    aggregator.add(object, 1.0);
                    break;
                case WEIGHTED_MAJORITY_VOTE:
                    aggregator.add(object, instanceResult.getWeight(nearestNeighborModel.getThreshold()));
                    break;
                default:
                    throw new UnsupportedFeatureException(nearestNeighborModel, categoricalScoringMethod);
            }
        }

        Set<?> winners = aggregator.getWinners();
        if (winners.size() > 1) {
            // Count how often each value occurs across all training instances
            LinkedHashMultiset<Object> counts = LinkedHashMultiset.create();

            Map<Integer, FieldValue> column = table.column(name);
            for (FieldValue trainingValue : column.values()) {
                counts.add(trainingValue.getValue());
            }

            aggregator.clear();

            for (Object winner : winners) {
                aggregator.add(winner, Double.valueOf(counts.count(winner)));
            }

            winners = aggregator.getWinners();
            if (winners.size() > 1) {
                return smallest(winners);
            }
        }

        return Iterables.getFirst(winners, null);
    }

    /**
     * Returns the smallest element by natural order.
     * The elements must be mutually {@link Comparable}, otherwise a {@link ClassCastException} is thrown.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Object smallest(Collection<?> values) {
        return Collections.min((Collection)values);
    }

    /**
     * Creates a function that maps a training row key to the string value of the {@code name} field.
     * The function throws {@link MissingFieldException} when the row has no such value.
     */
    private Function<Integer, String> createIdentifierResolver(final FieldName name, final Table<Integer, FieldName, FieldValue> table) {
        Function<Integer, String> function = new Function<Integer, String>(){

            @Override
            public String apply(Integer row) {
                FieldValue value = table.get(row, name);
                if (value == null) {
                    throw new MissingFieldException(name);
                }

                return value.asString();
            }
        };

        return function;
    }

    /**
     * Builds the result map.
     *
     * <p>The map type follows the comparison measure (SIMILARITY or DISTANCE). It carries
     * {@code value} as the prediction and one entry per instance result, keyed via {@code function}.</p>
     *
     * @throws UnsupportedFeatureException if the measure is neither a similarity nor a distance
     */
    private InstanceClassificationMap createMeasureMap(Object value, List<InstanceResult> instanceResults, Function<Integer, String> function) {
        NearestNeighborModel nearestNeighborModel = this.getModel();

        ComparisonMeasure comparisonMeasure = nearestNeighborModel.getComparisonMeasure();

        InstanceClassificationMap result;

        Measure measure = comparisonMeasure.getMeasure();
        if (MeasureUtil.isSimilarity(measure)) {
            result = new InstanceClassificationMap(ClassificationMap.Type.SIMILARITY, value);
        } else if (MeasureUtil.isDistance(measure)) {
            result = new InstanceClassificationMap(ClassificationMap.Type.DISTANCE, value);
        } else {
            throw new UnsupportedFeatureException(measure);
        }

        for (InstanceResult instanceResult : instanceResults) {
            result.put(function.apply(instanceResult.getId()), instanceResult.getValue());
        }

        return result;
    }

    /** Returns the (cached) parsed training instances of this evaluator's model. */
    private Table<Integer, FieldName, FieldValue> getTrainingInstances() {
        return this.getValue(NearestNeighborModelEvaluator.createTrainingInstanceLoader(this), trainingInstanceCache);
    }

    private static Callable<Table<Integer, FieldName, FieldValue>> createTrainingInstanceLoader(final NearestNeighborModelEvaluator modelEvaluator) {
        return new Callable<Table<Integer, FieldName, FieldValue>>(){

            @Override
            public Table<Integer, FieldName, FieldValue> call() {
                return NearestNeighborModelEvaluator.parseTrainingInstances(modelEvaluator);
            }
        };
    }

    /**
     * Parses the model's inline training table into a row key -> field name -> value table.
     *
     * <p>Each instance field is loaded as one of the following:</p>
     * <ul>
     *   <li>an identifier (string, categorical), when it is the instance id variable;</li>
     *   <li>a data field, when both a DataField and a MiningField exist;</li>
     *   <li>a derived field, when one can be resolved.</li>
     * </ul>
     *
     * <p>KNN inputs that resolve to derived fields and are not present in a row are then computed
     * from that row's values.</p>
     *
     * @throws UnsupportedFeatureException if the training instances use a {@link TableLocator}
     * @throws InvalidFeatureException if an instance field cannot be classified
     */
    private static Table<Integer, FieldName, FieldValue> parseTrainingInstances(NearestNeighborModelEvaluator modelEvaluator) {
        NearestNeighborModel nearestNeighborModel = modelEvaluator.getModel();

        TrainingInstances trainingInstances = nearestNeighborModel.getTrainingInstances();

        TableLocator tableLocator = trainingInstances.getTableLocator();
        if (tableLocator != null) {
            throw new UnsupportedFeatureException(tableLocator);
        }

        String idField = nearestNeighborModel.getInstanceIdVariable();

        List<FieldLoader> fieldLoaders = Lists.newArrayList();

        InstanceFields instanceFields = trainingInstances.getInstanceFields();
        for (InstanceField instanceField : instanceFields) {
            String field = instanceField.getField();
            String column = instanceField.getColumn();

            FieldName name = FieldName.create(field);

            if (idField != null && idField.equals(field)) {
                fieldLoaders.add(new IdentifierLoader(name, column));
                continue;
            }

            DataField dataField = modelEvaluator.getDataField(name);
            MiningField miningField = modelEvaluator.getMiningField(name);
            if (dataField != null && miningField != null) {
                fieldLoaders.add(new DataFieldLoader(name, column, dataField, miningField));
                continue;
            }

            DerivedField derivedField = modelEvaluator.resolveDerivedField(name);
            if (derivedField != null) {
                fieldLoaders.add(new DerivedFieldLoader(name, column, derivedField));
                continue;
            }

            throw new InvalidFeatureException(instanceField);
        }

        Table<Integer, FieldName, FieldValue> result = HashBasedTable.create();

        InlineTable inlineTable = trainingInstances.getInlineTable();
        if (inlineTable != null) {
            Table<Integer, String, String> table = InlineTableUtil.getContent(inlineTable);

            for (Integer rowKey : table.rowKeySet()) {
                Map<String, String> rowValues = table.row(rowKey);

                for (FieldLoader fieldLoader : fieldLoaders) {
                    FieldValue value = fieldLoader.load(rowValues);

                    // HashBasedTable rejects null values. Leave missing cells absent so that the
                    // value lookups elsewhere in this class raise MissingFieldException instead of an NPE here
                    if (value != null) {
                        result.put(rowKey, fieldLoader.getName(), value);
                    }
                }
            }
        }

        KNNInputs knnInputs = nearestNeighborModel.getKNNInputs();
        for (KNNInput knnInput : knnInputs) {
            FieldName name = knnInput.getField();

            DerivedField derivedField = modelEvaluator.resolveDerivedField(name);
            if (derivedField == null) {
                continue;
            }

            // Only existing rows are updated, so the row key set is not structurally modified while iterating
            for (Integer rowKey : result.rowKeySet()) {
                Map<FieldName, FieldValue> rowValues = result.row(rowKey);
                if (rowValues.containsKey(name)) {
                    continue;
                }

                ModelEvaluationContext context = new ModelEvaluationContext(null, modelEvaluator);
                context.declareAll(rowValues);

                FieldValue value = ExpressionUtil.evaluate(derivedField, (EvaluationContext)context);
                if (value != null) {
                    result.put(rowKey, name, value);
                }
            }
        }

        return result;
    }

    /** Returns the (cached) per-row bit sets derived from the KNN input values. */
    private Map<Integer, BitSet> getInstanceFlags() {
        return this.getValue(NearestNeighborModelEvaluator.createInstanceFlagLoader(this), instanceFlagCache);
    }

    private static Callable<Map<Integer, BitSet>> createInstanceFlagLoader(final NearestNeighborModelEvaluator modelEvaluator) {
        return new Callable<Map<Integer, BitSet>>(){

            @Override
            public Map<Integer, BitSet> call() {
                return NearestNeighborModelEvaluator.loadInstanceFlags(modelEvaluator);
            }
        };
    }

    /** Converts each row's KNN input values to a bit set, preserving row order. */
    private static Map<Integer, BitSet> loadInstanceFlags(NearestNeighborModelEvaluator modelEvaluator) {
        Map<Integer, List<FieldValue>> valueMap = modelEvaluator.getInstanceValues();

        Map<Integer, BitSet> result = Maps.newLinkedHashMap();
        for (Map.Entry<Integer, List<FieldValue>> entry : valueMap.entrySet()) {
            result.put(entry.getKey(), MeasureUtil.toBitSet(entry.getValue()));
        }

        return result;
    }

    /** Returns the (cached) per-row KNN input values. */
    private Map<Integer, List<FieldValue>> getInstanceValues() {
        return this.getValue(NearestNeighborModelEvaluator.createInstanceValueLoader(this), instanceValueCache);
    }

    private static Callable<Map<Integer, List<FieldValue>>> createInstanceValueLoader(final NearestNeighborModelEvaluator modelEvaluator) {
        return new Callable<Map<Integer, List<FieldValue>>>(){

            @Override
            public Map<Integer, List<FieldValue>> call() {
                return NearestNeighborModelEvaluator.loadInstanceValues(modelEvaluator);
            }
        };
    }

    /**
     * Collects, for each training row in ascending row-key order, the values of the KNN input fields.
     * The values are in {@code KNNInputs} order; an absent cell contributes a {@code null} element.
     */
    private static Map<Integer, List<FieldValue>> loadInstanceValues(NearestNeighborModelEvaluator modelEvaluator) {
        NearestNeighborModel nearestNeighborModel = modelEvaluator.getModel();

        Table<Integer, FieldName, FieldValue> table = modelEvaluator.getTrainingInstances();

        KNNInputs knnInputs = nearestNeighborModel.getKNNInputs();

        Map<Integer, List<FieldValue>> result = Maps.newLinkedHashMap();

        ImmutableSortedSet<Integer> rowKeys = ImmutableSortedSet.copyOf(table.rowKeySet());
        for (Integer rowKey : rowKeys) {
            Map<FieldName, FieldValue> rowValues = table.row(rowKey);

            List<FieldValue> values = Lists.newArrayList();
            for (KNNInput knnInput : knnInputs) {
                values.add(rowValues.get(knnInput.getField()));
            }

            result.put(rowKey, values);
        }

        return result;
    }

    /**
     * The comparison outcome (similarity or distance) between the input record and one training row.
     */
    private static abstract class InstanceResult
    implements Comparable<InstanceResult> {
        private final Integer id;
        private final Double value;

        private InstanceResult(Integer id, Double value) {
            this.id = id;
            this.value = value;
        }

        /**
         * Returns the voting/averaging weight of this instance.
         *
         * @param threshold the model's threshold parameter
         */
        public abstract double getWeight(double threshold);

        /** Returns the training row key. */
        public Integer getId() {
            return this.id;
        }

        /** Returns the measured similarity or distance. */
        public Double getValue() {
            return this.value;
        }

        private static class Distance
        extends InstanceResult {
            private Distance(Integer id, Double value) {
                super(id, value);
            }

            @Override
            public int compareTo(InstanceResult that) {
                if (that instanceof Distance) {
                    return ClassificationMap.Type.DISTANCE.compare(this.getValue(), that.getValue());
                }
                throw new ClassCastException();
            }

            /** Weight is {@code 1 / (distance + threshold)}. */
            @Override
            public double getWeight(double threshold) {
                return 1.0 / (this.getValue() + threshold);
            }
        }

        private static class Similarity
        extends InstanceResult {
            private Similarity(Integer id, Double value) {
                super(id, value);
            }

            @Override
            public int compareTo(InstanceResult that) {
                if (that instanceof Similarity) {
                    return ClassificationMap.Type.SIMILARITY.compare(this.getValue(), that.getValue());
                }
                throw new ClassCastException();
            }

            /** Weighting is not supported for similarity measures. */
            @Override
            public double getWeight(double threshold) {
                throw new EvaluationException();
            }
        }
    }

    /** Loads a training column as the value of a derived field. */
    private static class DerivedFieldLoader
    extends FieldLoader {
        private final DerivedField derivedField;

        private DerivedFieldLoader(FieldName name, String column, DerivedField derivedField) {
            super(name, column);
            this.derivedField = derivedField;
        }

        @Override
        public FieldValue prepare(String value) {
            return FieldValueUtil.create((Field)this.getDerivedField(), value);
        }

        public DerivedField getDerivedField() {
            return this.derivedField;
        }
    }

    /** Loads a training column as a data field argument, via {@link ArgumentUtil#prepare}. */
    private static class DataFieldLoader
    extends FieldLoader {
        private final DataField dataField;
        private final MiningField miningField;

        private DataFieldLoader(FieldName name, String column, DataField dataField, MiningField miningField) {
            super(name, column);
            this.dataField = dataField;
            this.miningField = miningField;
        }

        @Override
        public FieldValue prepare(String value) {
            return ArgumentUtil.prepare(this.getDataField(), this.getMiningField(), value);
        }

        public DataField getDataField() {
            return this.dataField;
        }

        public MiningField getMiningField() {
            return this.miningField;
        }
    }

    /** Loads a training column as a categorical string instance identifier. */
    private static class IdentifierLoader
    extends FieldLoader {
        private IdentifierLoader(FieldName name, String column) {
            super(name, column);
        }

        @Override
        public FieldValue prepare(String value) {
            return FieldValueUtil.create(DataType.STRING, OpType.CATEGORICAL, (Object)value);
        }
    }

    /** Reads one named column of a training table row and converts it to a {@link FieldValue}. */
    private static abstract class FieldLoader {
        private final FieldName name;
        private final String column;

        private FieldLoader(FieldName name, String column) {
            this.name = name;
            this.column = column;
        }

        /**
         * Converts a raw cell value to a field value.
         *
         * @param value the cell text, or {@code null} when the row has no such column
         */
        public abstract FieldValue prepare(String value);

        /** Looks up this loader's column in {@code values} and prepares it. */
        public FieldValue load(Map<String, String> values) {
            String value = values.get(this.getColumn());
            return this.prepare(value);
        }

        public FieldName getName() {
            return this.name;
        }

        public String getColumn() {
            return this.column;
        }
    }
}

