/*
 * Decompiled with CFR 0.152.
 */
package org.tinspin.index.kdtree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.tinspin.index.PointEntry;
import org.tinspin.index.PointEntryDist;
import org.tinspin.index.PointIndex;
import org.tinspin.index.QueryIterator;
import org.tinspin.index.QueryIteratorKNN;
import org.tinspin.index.kdtree.KDEntryDist;
import org.tinspin.index.kdtree.KDIterator;
import org.tinspin.index.kdtree.Node;

/**
 * A k-d tree mapping {@code double[]} keys to values of type {@code T}.
 *
 * <p>Each {@link Node} holds one key/value pair and splits on dimension {@link Node#getDim()}.
 * Exact-match lookups descend into the "hi" child when the key's coordinate in the split
 * dimension is {@code >=} the node's coordinate, and into the "lo" child otherwise.
 * Removals that take a replacement from a "lo" subtree may violate that rule; in that case
 * {@link #invariantBroken} is set and exact-match lookups switch to a slower search that also
 * explores "lo" subtrees on ties.
 *
 * <p>This class performs no synchronization.
 *
 * @param <T> value type
 */
public class KDTree<T>
implements PointIndex<T> {
    private static final String NL = System.lineSeparator();
    public static final boolean DEBUG = false;
    private final int dims;
    private int size = 0;
    private int modCount = 0;
    /** Set once a removal may have placed equal split coordinates in a "lo" subtree. */
    private boolean invariantBroken = false;
    private Node<T> root;

    /** Orders kNN candidates by ascending distance. */
    private static final Comparator<KDEntryDist<?>> compKnn = (point1, point2) -> {
        double deltaDist = point1.dist() - point2.dist();
        return deltaDist < 0.0 ? -1 : (deltaDist > 0.0 ? 1 : 0);
    };

    /**
     * Ad-hoc self test: runs {@link #test(int)} with random seeds 0..9.
     *
     * @param args ignored
     */
    public static void main(String ... args) {
        for (int seed = 0; seed < 10; ++seed) {
            try {
                KDTree.test(seed);
            } catch (Exception e) {
                System.out.println("Failed with r=" + seed);
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Inserts 500,000 random 14-dimensional points, then checks exact lookup, 1-NN queries and
     * removal of every point, and finally that the tree is empty.
     *
     * @param r seed for the random point generator
     * @throws IllegalStateException if any check fails
     */
    private static void test(int r) {
        double[][] pointList = new double[500000][14];
        Random rnd = new Random(r);
        for (double[] p : pointList) {
            Arrays.setAll(p, i -> rnd.nextInt(100));
        }
        KDTree<double[]> tree = KDTree.create(pointList[0].length);
        for (double[] data : pointList) {
            tree.insert(data, data);
        }
        for (double[] key : pointList) {
            if (!tree.containsExact(key)) {
                throw new IllegalStateException("" + Arrays.toString(key));
            }
        }
        for (double[] key : pointList) {
            System.out.println(Arrays.toString(tree.queryExact(key)));
        }
        for (double[] key : pointList) {
            System.out.println("kNN query: " + Arrays.toString(key));
            KDQueryIteratorKNN<double[]> iter = tree.queryKNN(key, 1);
            if (!iter.hasNext()) {
                throw new IllegalStateException("kNN() failed: " + Arrays.toString(key));
            }
            double[] answer = iter.next().point();
            if (answer != key && !Arrays.equals(answer, key)) {
                throw new IllegalStateException("Expected " + Arrays.toString(key) + " but got " + Arrays.toString(answer));
            }
        }
        for (double[] key : pointList) {
            System.out.println("Removing: " + Arrays.toString(key));
            if (!tree.containsExact(key)) {
                throw new IllegalStateException("containsExact() failed: " + Arrays.toString(key));
            }
            double[] answer = tree.remove(key);
            if (answer != key && !Arrays.equals(answer, key)) {
                throw new IllegalStateException("Expected " + Arrays.toString(key) + " but got " + Arrays.toString(answer));
            }
        }
        // Every inserted point was removed once, so nothing may be left behind.
        if (tree.size() != 0 || tree.root != null) {
            throw new IllegalStateException("Tree not empty after removing all points: " + tree);
        }
    }

    private KDTree(int dims) {
        this.dims = dims;
    }

    /**
     * Creates an empty tree.
     *
     * @param dims number of dimensions of the keys
     * @param <T> value type
     * @return a new, empty tree
     */
    public static <T> KDTree<T> create(int dims) {
        return new KDTree<T>(dims);
    }

    /**
     * Inserts a key/value pair. Duplicate keys are not merged; {@link #size()} always grows by one.
     *
     * @param key point coordinates
     * @param value associated value
     */
    @Override
    public void insert(double[] key, T value) {
        ++this.size;
        ++this.modCount;
        if (this.root == null) {
            this.root = new Node<T>(key, value, 0);
            return;
        }
        Node<T> n = this.root;
        while ((n = n.getClosestNodeOrAddPoint(key, value, this.dims)) != null) {
            // Loop until getClosestNodeOrAddPoint() returns null; presumably that happens once it
            // has attached the new point (TODO confirm in Node).
        }
    }

    /**
     * @param key point coordinates
     * @return whether an entry whose key equals {@code key} element-wise exists
     */
    public boolean containsExact(double[] key) {
        return this.findNodeExact(key, new RemoveResult<>()) != null;
    }

    /**
     * @param key point coordinates
     * @return the value of an entry whose key equals {@code key}, or {@code null} if none
     */
    @Override
    public T queryExact(double[] key) {
        Node<T> e = this.findNodeExact(key, new RemoveResult<>());
        return e == null ? null : e.getValue();
    }

    /**
     * Finds a node whose key equals {@code key}, using the slow search if the ordering
     * invariant may have been broken by a removal.
     *
     * @param key point coordinates
     * @param resultDepth receives the found node's split dimension and parent
     * @return the matching node, or {@code null}
     */
    private Node<T> findNodeExact(double[] key, RemoveResult<T> resultDepth) {
        if (this.root == null) {
            return null;
        }
        return this.invariantBroken
                ? this.findNodeExactSlow(key, this.root, null, resultDepth)
                : this.findNodeExactFast(key, resultDepth);
    }

    /** Single root-to-leaf descent; valid only while the ordering invariant holds. */
    private Node<T> findNodeExactFast(double[] key, RemoveResult<T> resultDepth) {
        Node<T> parent = null;
        Node<T> n = this.root;
        while (n != null) {
            double[] nodeKey = n.getKey();
            int dim = n.getDim();
            double nodeX = nodeKey[dim];
            double keyX = key[dim];
            if (keyX == nodeX && Arrays.equals(key, nodeKey)) {
                resultDepth.pos = dim;
                resultDepth.nodeParent = parent;
                return n;
            }
            parent = n;
            n = keyX >= nodeX ? n.getHi() : n.getLo();
        }
        return null;
    }

    /**
     * Descent that, on a tie in the split dimension, also searches the "lo" subtree, because a
     * broken invariant may have left equal coordinates there.
     */
    private Node<T> findNodeExactSlow(double[] key, Node<T> n, Node<T> parent, RemoveResult<T> resultDepth) {
        while (n != null) {
            double[] nodeKey = n.getKey();
            int dim = n.getDim();
            double nodeX = nodeKey[dim];
            double keyX = key[dim];
            if (keyX == nodeX) {
                if (Arrays.equals(key, nodeKey)) {
                    resultDepth.pos = dim;
                    resultDepth.nodeParent = parent;
                    return n;
                }
                if (n.getLo() != null) {
                    // n is the parent of n.getLo(); passing anything else breaks remove().
                    Node<T> found = this.findNodeExactSlow(key, n.getLo(), n, resultDepth);
                    if (found != null) {
                        return found;
                    }
                }
            }
            parent = n;
            n = keyX >= nodeX ? n.getHi() : n.getLo();
        }
        return null;
    }

    /**
     * Removes one entry whose key equals {@code key} element-wise.
     *
     * <p>A non-leaf node is not unlinked directly. Its key/value is overwritten with a
     * replacement taken from its hi subtree (minimum in the node's split dimension). If it has no
     * hi child, the replacement comes from its lo subtree (maximum). This repeats on the
     * replacement node until a leaf is reached, which is then detached from its parent.
     *
     * @param key point coordinates
     * @return the removed value, or {@code null} if no entry matched (or the stored value was null)
     */
    @Override
    public T remove(double[] key) {
        if (this.root == null) {
            return null;
        }
        RemoveResult<T> removeResult = new RemoveResult<>();
        Node<T> eToRemove = this.findNodeExact(key, removeResult);
        if (eToRemove == null) {
            return null;
        }
        ++this.modCount;
        T value = eToRemove.getValue();
        if (eToRemove == this.root && this.size == 1) {
            this.root = null;
            this.size = 0;
            this.invariantBroken = false;
            return value;
        }
        // removeResult.nodeParent holds eToRemove's parent; it must survive if eToRemove is a leaf.
        while (eToRemove != null && !eToRemove.isLeaf()) {
            int pos = removeResult.pos;
            removeResult.node = null;
            if (eToRemove.getHi() != null) {
                removeResult.best = Double.POSITIVE_INFINITY;
                this.removeMinLeaf(eToRemove.getHi(), eToRemove, pos, removeResult);
            } else if (eToRemove.getLo() != null) {
                removeResult.best = Double.NEGATIVE_INFINITY;
                this.removeMaxLeaf(eToRemove.getLo(), eToRemove, pos, removeResult);
            }
            eToRemove.setKeyValue(removeResult.node.getKey(), removeResult.node.getValue());
            eToRemove = removeResult.node;
        }
        Node<T> parent = removeResult.nodeParent;
        if (parent != null) {
            if (parent.getLo() == eToRemove) {
                parent.setLeft(null);
            } else if (parent.getHi() == eToRemove) {
                parent.setRight(null);
            } else {
                throw new IllegalStateException();
            }
        }
        --this.size;
        return value;
    }

    /**
     * Records in {@code result} the node with the smallest coordinate in dimension {@code pos}
     * below {@code node}. Subtrees split on {@code pos} are only followed into their lo child
     * when one exists. Ties ({@code <=}) go to the node visited later.
     */
    private void removeMinLeaf(Node<T> node, Node<T> parent, int pos, RemoveResult<T> result) {
        if (pos == node.getDim()) {
            if (node.getLo() != null) {
                this.removeMinLeaf(node.getLo(), node, pos, result);
            } else if (node.getKey()[pos] <= result.best) {
                result.node = node;
                result.nodeParent = parent;
                result.best = node.getKey()[pos];
                result.pos = node.getDim();
            }
        } else {
            double localX = node.getKey()[pos];
            if (localX <= result.best) {
                result.node = node;
                result.nodeParent = parent;
                result.best = localX;
                result.pos = node.getDim();
            }
            if (node.getLo() != null) {
                this.removeMinLeaf(node.getLo(), node, pos, result);
            }
            if (node.getHi() != null) {
                this.removeMinLeaf(node.getHi(), node, pos, result);
            }
        }
    }

    /**
     * Records in {@code result} the node with the largest coordinate in dimension {@code pos}
     * below {@code node}. Subtrees split on {@code pos} are only followed into their hi child
     * when one exists. Ties ({@code >=}) go to the node visited later.
     *
     * <p>Every recorded candidate marks the invariant as broken. Moving a lo-side maximum up can
     * leave entries with an equal coordinate in the lo subtree, which the fast lookup (that sends
     * {@code >=} to hi) would miss.
     */
    private void removeMaxLeaf(Node<T> node, Node<T> parent, int pos, RemoveResult<T> result) {
        if (pos == node.getDim()) {
            if (node.getHi() != null) {
                this.removeMaxLeaf(node.getHi(), node, pos, result);
            } else if (node.getKey()[pos] >= result.best) {
                result.node = node;
                result.nodeParent = parent;
                result.best = node.getKey()[pos];
                result.pos = node.getDim();
                this.invariantBroken = true;
            }
        } else {
            double localX = node.getKey()[pos];
            if (localX >= result.best) {
                result.node = node;
                result.nodeParent = parent;
                result.best = localX;
                result.pos = node.getDim();
                this.invariantBroken = true;
            }
            if (node.getLo() != null) {
                this.removeMaxLeaf(node.getLo(), node, pos, result);
            }
            if (node.getHi() != null) {
                this.removeMaxLeaf(node.getHi(), node, pos, result);
            }
        }
    }

    /**
     * Moves an entry from {@code oldKey} to {@code newKey}.
     *
     * @param oldKey current key
     * @param newKey new key
     * @return the moved value, or {@code null} if no entry had {@code oldKey} (nothing is inserted then)
     */
    @Override
    public T update(double[] oldKey, double[] newKey) {
        if (!this.containsExact(oldKey)) {
            return null;
        }
        T value = this.remove(oldKey);
        this.insert(newKey, value);
        return value;
    }

    @Override
    public int size() {
        return this.size;
    }

    @Override
    public void clear() {
        this.size = 0;
        this.root = null;
        this.invariantBroken = false;
        ++this.modCount;
    }

    /**
     * @param min lower corner of the query box (inclusive)
     * @param max upper corner of the query box (inclusive)
     * @return an iterator over entries inside the box
     */
    public KDIterator<T> query(double[] min, double[] max) {
        return new KDIterator<T>(this, min, max);
    }

    /** @return whether {@code min[i] <= point[i] <= max[i]} holds (non-strictly) in every dimension */
    static boolean isEnclosed(double[] point, double[] min, double[] max) {
        for (int i = 0; i < point.length; ++i) {
            if (point[i] < min[i] || point[i] > max[i]) {
                return false;
            }
        }
        return true;
    }

    /** Euclidean distance between two points of equal length. */
    private static double distance(double[] p1, double[] p2) {
        double dist = 0.0;
        for (int i = 0; i < p1.length; ++i) {
            double d = p1[i] - p2[i];
            dist += d * d;
        }
        return Math.sqrt(dist);
    }

    /**
     * @param center query point
     * @return the nearest entry, or {@code null} if the tree is empty
     */
    public KDEntryDist<T> nnQuery(double[] center) {
        if (this.root == null) {
            return null;
        }
        KDEntryDist<T> candidate = new KDEntryDist<T>(null, Double.POSITIVE_INFINITY);
        this.rangeSearch1NN(this.root, center, candidate, Double.POSITIVE_INFINITY);
        return candidate;
    }

    /**
     * Depth-first 1-NN search. It visits the child on the center's side first. The node itself
     * and the far child are visited only if the split plane lies within the current best distance.
     *
     * @return the (possibly reduced) best distance
     */
    private double rangeSearch1NN(Node<T> node, double[] center, KDEntryDist<T> candidate, double maxRange) {
        int pos = node.getDim();
        if (node.getLo() != null && (center[pos] < node.getKey()[pos] || node.getHi() == null)) {
            maxRange = this.rangeSearch1NN(node.getLo(), center, candidate, maxRange);
            if (center[pos] + maxRange >= node.getKey()[pos]) {
                maxRange = this.addCandidate(node, center, candidate, maxRange);
                if (node.getHi() != null) {
                    maxRange = this.rangeSearch1NN(node.getHi(), center, candidate, maxRange);
                }
            }
        } else if (node.getHi() != null) {
            maxRange = this.rangeSearch1NN(node.getHi(), center, candidate, maxRange);
            if (center[pos] <= node.getKey()[pos] + maxRange) {
                maxRange = this.addCandidate(node, center, candidate, maxRange);
                if (node.getLo() != null) {
                    maxRange = this.rangeSearch1NN(node.getLo(), center, candidate, maxRange);
                }
            }
        } else {
            maxRange = this.addCandidate(node, center, candidate, maxRange);
        }
        return maxRange;
    }

    /** Replaces {@code candidate} if {@code node} is strictly closer; returns the new best distance. */
    private double addCandidate(Node<T> node, double[] center, KDEntryDist<T> candidate, double maxRange) {
        double dist = KDTree.distance(center, node.getKey());
        if (dist >= maxRange) {
            return maxRange;
        }
        candidate.set(node, dist);
        return dist;
    }

    /**
     * @param center query point
     * @param k maximum number of results
     * @return up to {@code k} entries ordered by ascending distance; empty if the tree is empty or {@code k <= 0}
     */
    public List<KDEntryDist<T>> knnQuery(double[] center, int k) {
        if (this.root == null || k <= 0) {
            // k <= 0 would otherwise fail in ArrayList(k) or candidates.remove(k - 1).
            return Collections.emptyList();
        }
        ArrayList<KDEntryDist<T>> candidates = new ArrayList<KDEntryDist<T>>(k);
        this.rangeSearchKNN(this.root, center, candidates, k, Double.POSITIVE_INFINITY);
        return candidates;
    }

    /**
     * Depth-first kNN search with the same visiting order as {@link #rangeSearch1NN}. Once
     * {@code k} candidates are held, {@code maxRange} is the distance of the k-th candidate.
     */
    private double rangeSearchKNN(Node<T> node, double[] center, ArrayList<KDEntryDist<T>> candidates, int k, double maxRange) {
        int pos = node.getDim();
        if (node.getLo() != null && (center[pos] < node.getKey()[pos] || node.getHi() == null)) {
            maxRange = this.rangeSearchKNN(node.getLo(), center, candidates, k, maxRange);
            if (center[pos] + maxRange >= node.getKey()[pos]) {
                maxRange = this.addCandidate(node, center, candidates, k, maxRange);
                if (node.getHi() != null) {
                    maxRange = this.rangeSearchKNN(node.getHi(), center, candidates, k, maxRange);
                }
            }
        } else if (node.getHi() != null) {
            maxRange = this.rangeSearchKNN(node.getHi(), center, candidates, k, maxRange);
            if (center[pos] <= node.getKey()[pos] + maxRange) {
                maxRange = this.addCandidate(node, center, candidates, k, maxRange);
                if (node.getLo() != null) {
                    maxRange = this.rangeSearchKNN(node.getLo(), center, candidates, k, maxRange);
                }
            }
        } else {
            maxRange = this.addCandidate(node, center, candidates, k, maxRange);
        }
        return maxRange;
    }

    /**
     * Inserts {@code node} into the distance-sorted {@code candidates} if it is close enough. When
     * the list is full, the current k-th candidate is evicted and its object reused.
     *
     * @return the distance of the k-th candidate once k are held, otherwise {@code maxRange}
     */
    private double addCandidate(Node<T> node, double[] center, ArrayList<KDEntryDist<T>> candidates, int k, double maxRange) {
        double dist = KDTree.distance(center, node.getKey());
        if (dist > maxRange) {
            return maxRange;
        }
        boolean full = candidates.size() >= k;
        if (dist == maxRange && full) {
            return maxRange;
        }
        KDEntryDist<T> cand;
        if (full) {
            cand = candidates.remove(k - 1);
            cand.set(node, dist);
        } else {
            cand = new KDEntryDist<T>(node, dist);
        }
        int insertionPos = Collections.binarySearch(candidates, cand, compKnn);
        if (insertionPos < 0) {
            insertionPos = -(insertionPos + 1);
        }
        candidates.add(insertionPos, cand);
        return candidates.size() < k ? maxRange : candidates.get(candidates.size() - 1).dist();
    }

    @Override
    public String toStringTree() {
        StringBuilder sb = new StringBuilder();
        if (this.root == null) {
            sb.append("empty tree");
        } else {
            this.toStringTree(sb, this.root, 0);
        }
        return sb.toString();
    }

    /** In-order dump: one line per node, prefixed by {@code depth} dots and a space. */
    private void toStringTree(StringBuilder sb, Node<T> node, int depth) {
        StringBuilder prefix = new StringBuilder();
        for (int i = 0; i < depth; ++i) {
            prefix.append('.');
        }
        prefix.append(' ');
        if (node.getLo() != null) {
            this.toStringTree(sb, node.getLo(), depth + 1);
        }
        sb.append(prefix).append(Arrays.toString(node.point()));
        sb.append(" v=").append(node.value());
        sb.append(" l/r=");
        sb.append(node.getLo() == null ? null : Arrays.toString(node.getLo().point()));
        sb.append("/");
        sb.append(node.getHi() == null ? null : Arrays.toString(node.getHi().point()));
        sb.append(NL);
        if (node.getHi() != null) {
            this.toStringTree(sb, node.getHi(), depth + 1);
        }
    }

    @Override
    public String toString() {
        return "KDTree;size=" + this.size + ";DEBUG=" + DEBUG + ";center="
                + (this.root == null ? "null" : Arrays.toString(this.root.getKey()));
    }

    /** @return statistics gathered by walking the tree via {@code Node.checkNode} */
    @Override
    public KDStats getStats() {
        KDStats s = new KDStats();
        if (this.root != null) {
            this.root.checkNode(s, 0);
        }
        return s;
    }

    @Override
    public int getDims() {
        return this.dims;
    }

    /**
     * @return an iterator over all entries. For a non-empty tree this is a window query over
     *     the unbounded box (-inf..+inf in every dimension).
     */
    @Override
    public QueryIterator<PointEntry<T>> iterator() {
        double[] min = new double[this.dims];
        double[] max = new double[this.dims];
        if (this.root != null) {
            Arrays.fill(min, Double.NEGATIVE_INFINITY);
            Arrays.fill(max, Double.POSITIVE_INFINITY);
        }
        return this.query(min, max);
    }

    @Override
    public KDEntryDist<T> query1NN(double[] center) {
        return this.nnQuery(center);
    }

    /**
     * @param center query point
     * @param k maximum number of results
     * @return an iterator over the {@code k} nearest entries, by ascending distance
     */
    public KDQueryIteratorKNN<T> queryKNN(double[] center, int k) {
        return new KDQueryIteratorKNN<T>(this, center, k);
    }

    @Override
    public int getNodeCount() {
        return this.getStats().getNodeCount();
    }

    @Override
    public int getDepth() {
        return this.getStats().getMaxDepth();
    }

    Node<T> getRoot() {
        return this.root;
    }

    /** Holder passed to {@code Node.checkNode} by {@link #getStats()}; presumably filled in there. */
    public static class KDStats {
        int nNodes;
        int maxDepth;

        public int getNodeCount() {
            return this.nNodes;
        }

        public int getMaxDepth() {
            return this.maxDepth;
        }
    }

    /** kNN iterator backed by an eagerly computed {@link #knnQuery} result; {@link #reset} re-runs it. */
    private static class KDQueryIteratorKNN<T>
    implements QueryIteratorKNN<PointEntryDist<T>> {
        private Iterator<? extends PointEntryDist<T>> it;
        private final KDTree<T> tree;

        public KDQueryIteratorKNN(KDTree<T> tree, double[] center, int k) {
            this.tree = tree;
            this.reset(center, k);
        }

        @Override
        public boolean hasNext() {
            return this.it.hasNext();
        }

        @Override
        public PointEntryDist<T> next() {
            return this.it.next();
        }

        @Override
        public void reset(double[] center, int k) {
            this.it = this.tree.knnQuery(center, k).iterator();
        }
    }

    /**
     * Scratch state used by lookups and replacement searches in {@link #remove}: the chosen node,
     * its parent, the best coordinate seen so far, and the node's split dimension.
     */
    private static class RemoveResult<T> {
        Node<T> node = null;
        Node<T> nodeParent = null;
        double best;
        int pos;

        private RemoveResult() {
        }
    }
}

