/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.classifier.knn;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import org.maochen.nlp.datastructure.Tuple;
import org.maochen.nlp.utils.VectorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Brute-force k-nearest-neighbour voter for a single query tuple.
 *
 * <p>Intended usage: construct with the query tuple, the labelled training tuples and {@code k};
 * call {@link #getDistance(BiFunction)} with one of the distance functions below to fill in
 * {@code distance} on every training tuple; then call {@link #getResult()} to take a majority
 * vote among the {@code k} nearest tuples.
 *
 * <p>Instances are not synchronized: both methods mutate the supplied tuples, and
 * {@link #getResult()} re-orders the training list in place.
 */
final class KNNEngine {
    private static final Logger LOG = LoggerFactory.getLogger(KNNEngine.class);

    /** The tuple being classified; its {@code label} is overwritten by {@link #getResult()}. */
    private final Tuple predict;

    /** Labelled training tuples; sorted in place by ascending distance in {@link #getResult()}. */
    private final List<Tuple> trainingData;

    /** Number of nearest neighbours that get a vote (capped at the training-set size). */
    private final int k;

    /** Euclidean (L2) distance: {@code sqrt(sum((x - y)^2))}. */
    public BiFunction<double[], double[], Double> euclideanDistance = (v1, v2) ->
            Math.sqrt(Arrays.stream(VectorUtils.zip(v1, v2, (x, y) -> Math.pow(x - y, 2.0))).sum());

    /**
     * Chebyshev (L-infinity) distance: {@code max(|x - y|)}. Returns 0 for zero-length vectors,
     * matching what the other distances return for empty input.
     */
    public BiFunction<double[], double[], Double> chebyshevDistance = (v1, v2) ->
            Arrays.stream(VectorUtils.zip(v1, v2, (x, y) -> Math.abs(x - y))).max().orElse(0.0);

    /** Manhattan (L1) distance: {@code sum(|x - y|)}. */
    public BiFunction<double[], double[], Double> manhattanDistance = (v1, v2) ->
            Arrays.stream(VectorUtils.zip(v1, v2, (x, y) -> Math.abs(x - y))).sum();

    /**
     * Creates an engine for one query.
     *
     * @param predict      the tuple to classify; its {@code label} is set by {@link #getResult()}
     * @param trainingData labelled training tuples; this list is sorted in place later
     * @param k            number of nearest neighbours that vote
     */
    public KNNEngine(Tuple predict, List<Tuple> trainingData, int k) {
        this.predict = predict;
        this.trainingData = trainingData;
        this.k = k;
    }

    /**
     * Computes the distance from the query tuple to every training tuple and stores it in each
     * training tuple's {@code distance} field.
     *
     * <p>All feature vectors are checked against the query's dimension before any distance is
     * written. On a mismatch, an error is logged and no tuple is modified, so the training data
     * is never left half-updated.
     *
     * @param distanceFunction function mapping (query vector, training vector) to a distance,
     *                         e.g. {@link #euclideanDistance}
     */
    public void getDistance(BiFunction<double[], double[], Double> distanceFunction) {
        int dimension = this.predict.featureVector.length;
        for (Tuple tuple : this.trainingData) {
            if (tuple.featureVector.length != dimension) {
                LOG.error("2 Vectors must has same dimension.");
                return;
            }
        }
        for (Tuple tuple : this.trainingData) {
            tuple.distance = distanceFunction.apply(this.predict.featureVector, tuple.featureVector);
        }
    }

    /**
     * Sorts the training tuples by ascending distance, lets the {@code k} nearest vote with their
     * labels, and stores the winning label in the query tuple.
     *
     * <p>Ties are broken in favour of the tied label that first appears among the nearest
     * neighbours. Labels are counted in insertion order, so the result is deterministic. If
     * {@code k} exceeds the number of training tuples, every tuple votes. If nobody votes, the
     * result is the empty string.
     *
     * @return the winning label (also assigned to the query tuple's {@code label})
     */
    public String getResult() {
        this.trainingData.sort((a, b) -> Double.compare(a.distance, b.distance));

        int voters = Math.min(this.k, this.trainingData.size());
        if (this.k > voters) {
            LOG.warn("k={} exceeds training data size {}, using all tuples.", this.k, voters);
        }

        // LinkedHashMap keeps labels in order of first appearance among the nearest neighbours,
        // which makes "take the first max" below well defined.
        Map<String, Integer> votes = new LinkedHashMap<>();
        for (int i = 0; i < voters; ++i) {
            votes.merge(this.trainingData.get(i).label, 1, Integer::sum);
        }

        String maxVote = "";
        int maxCount = 0;
        int maxCountEntryNumber = 0;
        for (Map.Entry<String, Integer> entry : votes.entrySet()) {
            int currentCount = entry.getValue();
            if (currentCount > maxCount) {
                maxCount = currentCount;
                maxVote = entry.getKey();
                maxCountEntryNumber = 1;
            } else if (currentCount == maxCount) {
                ++maxCountEntryNumber;
            }
        }
        if (maxCountEntryNumber > 1) {
            LOG.info("Equal Max Vote, take the first max!");
        }
        this.predict.label = maxVote;
        return maxVote;
    }
}

