/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.base.Function;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.LinkedHashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Table;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import org.dmg.pmml.CategoricalScoringMethodType;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.ContinuousScoringMethodType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.InstanceField;
import org.dmg.pmml.InstanceFields;
import org.dmg.pmml.KNNInput;
import org.dmg.pmml.KNNInputs;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.Model;
import org.dmg.pmml.NearestNeighborModel;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.TableLocator;
import org.dmg.pmml.TrainingInstances;
import org.jpmml.evaluator.ArgumentUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.InlineTableUtil;
import org.jpmml.evaluator.InstanceClassificationMap;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.MeasureUtil;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelManagerEvaluationContext;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.VoteCounter;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.ModelManager;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * Evaluator for PMML {@code NearestNeighborModel} (k-nearest neighbors).
 *
 * <p>Training instances are read from the model's {@code InlineTable}. They are parsed once per
 * model instance and kept in {@link #trainingInstanceCache}. External tables
 * ({@code TableLocator}) are not supported.
 */
public class NearestNeighborModelEvaluator
extends ModelEvaluator<NearestNeighborModel> {
    /**
     * Maps each model to the PMML document that owns it.
     *
     * <p>Keys and values are held weakly. The entry is (re)registered by
     * {@link #getTrainingInstances()} immediately before {@link #trainingInstanceCache} is
     * consulted, because that cache's loader needs the PMML document.
     */
    private static final Cache<NearestNeighborModel, PMML> pmmlCache = CacheBuilder.newBuilder().weakKeys().weakValues().build();

    /**
     * Parsed training instances per model, keyed weakly on the model.
     *
     * <p>Table layout: row = training instance key, column = field name, cell = prepared value.
     */
    private static final LoadingCache<NearestNeighborModel, Table<Integer, FieldName, FieldValue>> trainingInstanceCache = CacheBuilder.newBuilder()
        .weakKeys()
        .build(new CacheLoader<NearestNeighborModel, Table<Integer, FieldName, FieldValue>>() {

            @Override
            public Table<Integer, FieldName, FieldValue> load(NearestNeighborModel nearestNeighborModel) {
                PMML pmml = pmmlCache.getIfPresent(nearestNeighborModel);
                // The owning document must have been registered via getTrainingInstances().
                if (pmml == null) {
                    throw new EvaluationException();
                }
                return parseTrainingInstances(pmml, nearestNeighborModel);
            }
        });

    /**
     * Creates an evaluator for the first {@code NearestNeighborModel} found in {@code pmml}.
     */
    public NearestNeighborModelEvaluator(PMML pmml) {
        this(pmml, (NearestNeighborModel)NearestNeighborModelEvaluator.find((List)pmml.getModels(), NearestNeighborModel.class));
    }

    /**
     * Creates an evaluator for the given model of {@code pmml}.
     */
    public NearestNeighborModelEvaluator(PMML pmml, NearestNeighborModel nearestNeighborModel) {
        super(pmml, nearestNeighborModel);
    }

    public String getSummary() {
        return "k-Nearest neighbors model";
    }

    /**
     * Scores one record.
     *
     * <p>REGRESSION, CLASSIFICATION and MIXED models predict every predicted field from the k
     * nearest training instances. CLUSTERING models only report distances to all training
     * instances.
     *
     * @param arguments input field values
     * @return the model outputs as produced by {@link OutputUtil#evaluate}
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException for any other mining function
     */
    @Override
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        if (!nearestNeighborModel.isScorable()) {
            throw new InvalidResultException(nearestNeighborModel);
        }
        ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this);
        context.pushFrame(arguments);

        Map<FieldName, InstanceClassificationMap> predictions;
        MiningFunctionType miningFunction = nearestNeighborModel.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
            case CLASSIFICATION:
            case MIXED:
                predictions = this.evaluateMixed(context);
                break;
            case CLUSTERING:
                predictions = this.evaluateClustering(context);
                break;
            default:
                throw new UnsupportedFeatureException(nearestNeighborModel, miningFunction);
        }
        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Predicts each predicted field from the k nearest training instances.
     *
     * <p>Each result map carries the prediction plus the distance to every training instance.
     * The distance map is keyed by the instance id field when one is declared, otherwise by the
     * table row key.
     *
     * @throws InvalidFeatureException if {@code numberOfNeighbors} is less than 1 or exceeds the
     *     number of training instances
     */
    private Map<FieldName, InstanceClassificationMap> evaluateMixed(ModelManagerEvaluationContext context) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        List<InstanceResult> instanceResults = this.evaluate(context);

        // Validate k up front; subList would otherwise fail with an IndexOutOfBoundsException
        // (k too large), and k == 0 would silently yield NaN / null predictions.
        int numberOfNeighbors = nearestNeighborModel.getNumberOfNeighbors();
        if (numberOfNeighbors < 1 || numberOfNeighbors > instanceResults.size()) {
            throw new InvalidFeatureException(nearestNeighborModel);
        }
        List<InstanceResult> nearestInstanceResults = Lists.newArrayList(instanceResults);
        Collections.sort(nearestInstanceResults);
        nearestInstanceResults = nearestInstanceResults.subList(0, numberOfNeighbors);

        Table<Integer, FieldName, FieldValue> table = this.getTrainingInstances();

        // Default identifier: the table row key rendered as a string.
        Function<Integer, String> function = new Function<Integer, String>() {

            @Override
            public String apply(Integer row) {
                return row.toString();
            }
        };
        String idField = nearestNeighborModel.getInstanceIdVariable();
        if (idField != null) {
            function = this.createIdentifierResolver(FieldName.create(idField), table);
        }

        Map<FieldName, InstanceClassificationMap> result = Maps.newLinkedHashMap();
        List<FieldName> predictedFields = this.getPredictedFields();
        for (FieldName predictedField : predictedFields) {
            DataField dataField = this.getDataField(predictedField);
            OpType opType = dataField.getOptype();
            Object value;
            switch (opType) {
                case CONTINUOUS:
                    value = this.calculateContinuousTarget(predictedField, nearestInstanceResults, table);
                    break;
                case CATEGORICAL:
                    value = this.calculateCategoricalTarget(predictedField, nearestInstanceResults, table);
                    break;
                default:
                    throw new UnsupportedFeatureException(dataField, opType);
            }
            // The distance map covers all training instances, not only the k nearest.
            result.put(predictedField, this.createDistanceMap(value, instanceResults, function));
        }
        return result;
    }

    /**
     * Reports the distance to every training instance under the target field.
     *
     * <p>No prediction value is produced.
     *
     * @throws InvalidFeatureException if the model declares no instance id variable
     */
    private Map<FieldName, InstanceClassificationMap> evaluateClustering(ModelManagerEvaluationContext context) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        List<InstanceResult> instanceResults = this.evaluate(context);
        Table<Integer, FieldName, FieldValue> table = this.getTrainingInstances();
        String idField = nearestNeighborModel.getInstanceIdVariable();
        if (idField == null) {
            throw new InvalidFeatureException(nearestNeighborModel);
        }
        Function<Integer, String> function = this.createIdentifierResolver(FieldName.create(idField), table);
        return Collections.singletonMap(this.getTargetField(), this.createDistanceMap(null, instanceResults, function));
    }

    /**
     * Computes the distance between the current record and every training instance.
     *
     * <p>Distances are computed over the model's KNN inputs. Results are in ascending row-key
     * order.
     *
     * @throws UnsupportedFeatureException if the comparison measure is not a distance measure
     */
    private List<InstanceResult> evaluate(ModelManagerEvaluationContext context) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        ComparisonMeasure comparisonMeasure = nearestNeighborModel.getComparisonMeasure();
        Measure measure = comparisonMeasure.getMeasure();
        if (!MeasureUtil.isDistance(measure)) {
            throw new UnsupportedFeatureException(measure);
        }

        KNNInputs knnInputs = nearestNeighborModel.getKNNInputs();
        List<FieldValue> values = Lists.newArrayList();
        for (KNNInput knnInput : knnInputs) {
            values.add(ExpressionUtil.evaluate(knnInput.getField(), (EvaluationContext) context));
        }
        Double adjustment = MeasureUtil.calculateAdjustment(values);

        Table<Integer, FieldName, FieldValue> table = this.getTrainingInstances();
        List<InstanceResult> result = Lists.newArrayList();
        for (Integer rowKey : ImmutableSortedSet.copyOf(table.rowKeySet())) {
            Map<FieldName, FieldValue> rowValues = table.row(rowKey);
            List<FieldValue> instanceValues = Lists.newArrayList();
            for (KNNInput knnInput : knnInputs) {
                instanceValues.add(rowValues.get(knnInput.getField()));
            }
            Double distance = MeasureUtil.evaluateDistance(comparisonMeasure, knnInputs.getKNNInputs(), values, instanceValues, adjustment);
            result.add(new InstanceResult(rowKey, distance));
        }
        return result;
    }

    /**
     * Aggregates a continuous target over the given neighbors.
     *
     * <p>AVERAGE returns the arithmetic mean. WEIGHTED_AVERAGE returns
     * {@code sum(w_i * y_i) / sum(w_i)}, where {@code w_i = 1 / (distance_i + threshold)}.
     *
     * @throws MissingFieldException if a neighbor has no value for {@code name}
     * @throws UnsupportedFeatureException for any other continuous scoring method
     */
    private Double calculateContinuousTarget(FieldName name, List<InstanceResult> instanceResults, Table<Integer, FieldName, FieldValue> table) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        ContinuousScoringMethodType continuousScoringMethod = nearestNeighborModel.getContinuousScoringMethod();
        double sum = 0.0;
        // For AVERAGE every neighbor weighs 1.0, so this equals the neighbor count.
        double weightSum = 0.0;
        for (InstanceResult instanceResult : instanceResults) {
            FieldValue value = table.get(instanceResult.getId(), name);
            if (value == null) {
                throw new MissingFieldException(name);
            }
            double number = value.asNumber().doubleValue();
            switch (continuousScoringMethod) {
                case AVERAGE: {
                    sum += number;
                    weightSum += 1.0;
                    break;
                }
                case WEIGHTED_AVERAGE: {
                    double weight = instanceResult.getWeight(nearestNeighborModel.getThreshold());
                    sum += weight * number;
                    // Normalize by the total weight, not by the neighbor count.
                    weightSum += weight;
                    break;
                }
                default:
                    throw new UnsupportedFeatureException(nearestNeighborModel, continuousScoringMethod);
            }
        }
        return sum / weightSum;
    }

    /**
     * Votes for a categorical target among the given neighbors.
     *
     * <p>MAJORITY_VOTE counts one vote per neighbor. WEIGHTED_MAJORITY_VOTE counts
     * {@code 1 / (distance + threshold)} per neighbor. A tie is broken by how often each tied
     * category occurs in the whole training set. If that is still tied, an
     * {@link EvaluationException} is thrown.
     *
     * <p>NOTE(review): no further (e.g. lexical) tie-break is implemented.
     *
     * @return the winning category, or {@code null} if there were no votes
     * @throws MissingFieldException if a neighbor has no value for {@code name}
     * @throws UnsupportedFeatureException for any other categorical scoring method
     */
    private Object calculateCategoricalTarget(FieldName name, List<InstanceResult> instanceResults, Table<Integer, FieldName, FieldValue> table) {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        CategoricalScoringMethodType categoricalScoringMethod = nearestNeighborModel.getCategoricalScoringMethod();
        VoteCounter<Object> counter = new VoteCounter<Object>();
        for (InstanceResult instanceResult : instanceResults) {
            FieldValue value = table.get(instanceResult.getId(), name);
            if (value == null) {
                throw new MissingFieldException(name);
            }
            // Votes are keyed by the unwrapped value, not by the FieldValue wrapper.
            Object object = value.getValue();
            switch (categoricalScoringMethod) {
                case MAJORITY_VOTE:
                    counter.increment(object);
                    break;
                case WEIGHTED_MAJORITY_VOTE:
                    counter.increment(object, instanceResult.getWeight(nearestNeighborModel.getThreshold()));
                    break;
                default:
                    throw new UnsupportedFeatureException(nearestNeighborModel, categoricalScoringMethod);
            }
        }

        Set<Object> winners = counter.getWinners();
        if (winners.size() > 1) {
            // Copy before clear() so the iteration below cannot observe a cleared counter.
            List<Object> tiedCategories = Lists.newArrayList(winners);

            // Unwrap the training column the same way the votes were keyed. Counting FieldValue
            // wrappers would never match the unwrapped winners, so every tie would end in an
            // exception.
            LinkedHashMultiset<Object> multiset = LinkedHashMultiset.create();
            for (FieldValue trainingValue : table.column(name).values()) {
                if (trainingValue != null) {
                    multiset.add(trainingValue.getValue());
                }
            }

            counter.clear();
            for (Object category : tiedCategories) {
                counter.increment(category, Double.valueOf(multiset.count(category)));
            }
            winners = counter.getWinners();
            if (winners.size() > 1) {
                throw new EvaluationException();
            }
        }
        return Iterables.getFirst(winners, null);
    }

    /**
     * Returns a function that maps a training table row key to the string value of the
     * identifier field {@code name} in that row.
     *
     * <p>The function throws {@link MissingFieldException} if the row has no such value.
     */
    private Function<Integer, String> createIdentifierResolver(final FieldName name, final Table<Integer, FieldName, FieldValue> table) {
        return new Function<Integer, String>() {

            @Override
            public String apply(Integer row) {
                FieldValue value = table.get(row, name);
                if (value == null) {
                    throw new MissingFieldException(name);
                }
                return value.asString();
            }
        };
    }

    /**
     * Builds a result map holding {@code value} plus one entry per instance.
     *
     * <p>Each entry maps the identifier produced by {@code function} to that instance's
     * distance.
     */
    private InstanceClassificationMap createDistanceMap(Object value, List<InstanceResult> instanceResults, Function<Integer, String> function) {
        InstanceClassificationMap result = new InstanceClassificationMap(value);
        for (InstanceResult instanceResult : instanceResults) {
            result.put(function.apply(instanceResult.getId()), instanceResult.getDistance());
        }
        return result;
    }

    /**
     * Returns the (cached) training instance table of this evaluator's model.
     *
     * <p>The PMML document is registered in {@link #pmmlCache} first, because the
     * {@link #trainingInstanceCache} loader looks it up there.
     */
    private Table<Integer, FieldName, FieldValue> getTrainingInstances() {
        NearestNeighborModel nearestNeighborModel = this.getModel();
        Callable<PMML> callable = new Callable<PMML>() {

            @Override
            public PMML call() {
                return NearestNeighborModelEvaluator.this.getPMML();
            }
        };
        try {
            pmmlCache.get(nearestNeighborModel, callable);
        } catch (ExecutionException ee) {
            // Keep the underlying failure attached instead of discarding it.
            EvaluationException exception = new EvaluationException();
            exception.initCause(ee);
            throw exception;
        }
        return this.getValue(trainingInstanceCache);
    }

    /**
     * Parses the model's inline training instances into a table.
     *
     * <p>Each {@code InstanceField} gets a loader:
     * <ul>
     * <li>the instance id variable gets an identifier loader;</li>
     * <li>a field with both a DataField and a MiningField gets a data-field loader;</li>
     * <li>a field that resolves to a DerivedField gets a derived-field loader.</li>
     * </ul>
     * Any other field is invalid.
     *
     * <p>Afterwards, every KNN input that resolves to a DerivedField is computed from the row's
     * loaded values for rows that lack it.
     *
     * @throws UnsupportedFeatureException if the instances come from a {@code TableLocator}
     * @throws InvalidFeatureException if an instance field cannot be classified
     */
    private static Table<Integer, FieldName, FieldValue> parseTrainingInstances(PMML pmml, NearestNeighborModel nearestNeighborModel) {
        TrainingInstances trainingInstances = nearestNeighborModel.getTrainingInstances();
        TableLocator tableLocator = trainingInstances.getTableLocator();
        if (tableLocator != null) {
            throw new UnsupportedFeatureException(tableLocator);
        }
        ModelManager modelManager = new ModelManager(pmml, (Model)nearestNeighborModel);
        String idField = nearestNeighborModel.getInstanceIdVariable();

        List<FieldLoader> fieldLoaders = Lists.newArrayList();
        InstanceFields instanceFields = trainingInstances.getInstanceFields();
        for (InstanceField instanceField : instanceFields) {
            String field = instanceField.getField();
            String column = instanceField.getColumn();
            FieldName name = FieldName.create(field);
            if (idField != null && idField.equals(field)) {
                fieldLoaders.add(new IdentifierLoader(name, column));
                continue;
            }
            DataField dataField = modelManager.getDataField(name);
            MiningField miningField = modelManager.getMiningField(name);
            if (dataField != null && miningField != null) {
                fieldLoaders.add(new DataFieldLoader(name, column, dataField, miningField));
                continue;
            }
            DerivedField derivedField = modelManager.resolveField(name);
            if (derivedField != null) {
                fieldLoaders.add(new DerivedFieldLoader(name, column, derivedField));
                continue;
            }
            throw new InvalidFeatureException(instanceField);
        }

        Table<Integer, FieldName, FieldValue> result = HashBasedTable.create();
        InlineTable inlineTable = trainingInstances.getInlineTable();
        if (inlineTable != null) {
            Table<Integer, String, String> table = InlineTableUtil.getContent(inlineTable);
            for (Integer rowKey : table.rowKeySet()) {
                Map<String, String> rowValues = table.row(rowKey);
                for (FieldLoader fieldLoader : fieldLoaders) {
                    result.put(rowKey, fieldLoader.getName(), fieldLoader.load(rowValues));
                }
            }
        }

        // Fill in derived KNN inputs that the inline table did not provide directly.
        ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(modelManager);
        KNNInputs knnInputs = nearestNeighborModel.getKNNInputs();
        for (KNNInput knnInput : knnInputs) {
            FieldName name = knnInput.getField();
            DerivedField derivedField = modelManager.resolveField(name);
            if (derivedField == null) {
                continue;
            }
            for (Integer rowKey : result.rowKeySet()) {
                Map<FieldName, FieldValue> rowValues = result.row(rowKey);
                if (rowValues.containsKey(name)) {
                    continue;
                }
                context.pushFrame(rowValues);
                try {
                    result.put(rowKey, name, ExpressionUtil.evaluate(derivedField, (EvaluationContext) context));
                } finally {
                    context.popFrame();
                }
            }
        }
        return result;
    }

    /**
     * Distance from the current record to one training instance (identified by its row key).
     *
     * <p>Natural order is ascending distance.
     */
    private static class InstanceResult
    implements Comparable<InstanceResult> {
        private final Integer id;
        private final Double distance;

        public InstanceResult(Integer id, Double distance) {
            this.id = id;
            this.distance = distance;
        }

        @Override
        public int compareTo(InstanceResult that) {
            return this.getDistance().compareTo(that.getDistance());
        }

        /**
         * Inverse-distance vote weight: {@code 1 / (distance + threshold)}.
         */
        public double getWeight(double threshold) {
            return 1.0 / (this.getDistance() + threshold);
        }

        public Integer getId() {
            return this.id;
        }

        public Double getDistance() {
            return this.distance;
        }
    }

    /**
     * Loads a column as a value of a DerivedField.
     */
    private static class DerivedFieldLoader
    extends FieldLoader {
        private final DerivedField derivedField;

        public DerivedFieldLoader(FieldName name, String column, DerivedField derivedField) {
            super(name, column);
            this.derivedField = derivedField;
        }

        @Override
        public FieldValue prepare(String value) {
            return FieldValueUtil.create((Field)this.getDerivedField(), value);
        }

        public DerivedField getDerivedField() {
            return this.derivedField;
        }
    }

    /**
     * Loads a column as an input value of a DataField.
     *
     * <p>The value is prepared according to the field's DataField and MiningField.
     */
    private static class DataFieldLoader
    extends FieldLoader {
        private final DataField dataField;
        private final MiningField miningField;

        private DataFieldLoader(FieldName name, String column, DataField dataField, MiningField miningField) {
            super(name, column);
            this.dataField = dataField;
            this.miningField = miningField;
        }

        @Override
        public FieldValue prepare(String value) {
            return ArgumentUtil.prepare(this.getDataField(), this.getMiningField(), value);
        }

        public DataField getDataField() {
            return this.dataField;
        }

        public MiningField getMiningField() {
            return this.miningField;
        }
    }

    /**
     * Loads the instance id column as a categorical string value.
     */
    private static class IdentifierLoader
    extends FieldLoader {
        private IdentifierLoader(FieldName name, String column) {
            super(name, column);
        }

        @Override
        public FieldValue prepare(String value) {
            return FieldValueUtil.create(DataType.STRING, OpType.CATEGORICAL, value);
        }
    }

    /**
     * Reads one inline-table column and converts its raw string into a {@link FieldValue} for
     * field {@code name}.
     */
    private static abstract class FieldLoader {
        private final FieldName name;
        private final String column;

        private FieldLoader(FieldName name, String column) {
            this.name = name;
            this.column = column;
        }

        /**
         * Converts the raw cell text (possibly {@code null} if the column is absent in the row).
         */
        public abstract FieldValue prepare(String value);

        /**
         * Looks up this loader's column in {@code values} and prepares it.
         */
        public FieldValue load(Map<String, String> values) {
            return this.prepare(values.get(this.getColumn()));
        }

        public FieldName getName() {
            return this.name;
        }

        public String getColumn() {
            return this.column;
        }
    }
}

