/* 
 * Copyright (C) 2016 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */
package dulab.adap.common.types;

import dulab.adap.common.distances.Distance;
import dulab.adap.common.distances.WeightedDotProductDistance;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * Binary ball tree built over the rows of a {@link SparseMatrix}.
 * <p>
 * Every {@link BallNode} stores the ID (row index) of the data vector used as
 * its center and a radius equal to the largest distance from that center to
 * any row stored in its subtree. The tree is queried for the rows that lie
 * within a distance threshold of a query vector.
 * <p>
 * NOTE: this class is {@link Serializable} without an explicit
 * {@code serialVersionUID}. The default one is derived from the field
 * declarations and non-private members, so changing those breaks loading of
 * previously serialized trees.
 *
 * @author aleksandrsmirnov
 */
public class BallTree implements Serializable
{
    /** Data vectors; node IDs are row indices into this matrix. */
    private final SparseMatrix data;
    /** Distance used both to build and to query the tree. */
    private final WeightedDotProductDistance distanceFunction;
    
    /** Root of the tree; {@code null} before {@link #build()} or for empty data. */
    private BallNode root;
    
    /** Fraction of {@link #build()} completed, in [0, 1]. */
    private double progress;
    
    // ------------------------------------------------------------------------
    // ----- Constructors -----------------------------------------------------
    // ------------------------------------------------------------------------
    
    /**
     * Creates an unbuilt tree; call {@link #build()} before querying it.
     *
     * @param data data vectors, one per row
     * @param distanceFunction distance measure; must be a
     *        {@link WeightedDotProductDistance}
     * @throws ClassCastException if {@code distanceFunction} is not a
     *         {@link WeightedDotProductDistance}
     */
    public BallTree(SparseMatrix data, Distance distanceFunction) {
        this(data, distanceFunction, null);
    }

    /**
     * Creates a tree with the given root node.
     *
     * @param data data vectors, one per row; node IDs index these rows
     * @param distanceFunction distance measure; must be a
     *        {@link WeightedDotProductDistance}
     * @param root root node, or {@code null} for an unbuilt tree
     * @throws ClassCastException if {@code distanceFunction} is not a
     *         {@link WeightedDotProductDistance}
     */
    public BallTree(SparseMatrix data, Distance distanceFunction, 
            BallNode root)
    {
        this.progress = 0f;
        
        this.data = data;
        this.distanceFunction = (WeightedDotProductDistance) distanceFunction;
        this.root = root;
    }
    
    // ------------------------------------------------------------------------
    // ----- Methods ----------------------------------------------------------
    // ------------------------------------------------------------------------
    
    /**
     * Builds the tree over all rows of the data matrix, replacing any existing
     * root. Progress is reported through {@link #getProgress()}.
     */
    public void build() {
        this.root = buildTree();
    }
    
    /**
     * Finds the entry of {@code indices} whose scaled data vector is closest
     * to {@code vector}.
     *
     * @param vector query vector, compared against the scaled data rows
     * @param indices row indices to search
     * @return the POSITION within {@code indices} (not the row index) of the
     *         closest vector; 0 if {@code indices} is empty
     */
    private int closestIndex(SparseVector vector, List <Integer> indices)
    {
        final int size = indices.size();
        
        int result = 0;
        double min = Double.MAX_VALUE;
        
        for (int i = 0; i < size; ++i) 
        {   
            final double d = this.distanceFunction.call(
                    vector, 
                    this.distanceFunction.scale(this.data.get(indices.get(i))));
            
            if (d < min) {
                min = d;
                result = i;
            }
        }
        return result;
    }
    
    /**
     * Builds the tree breadth-first, one layer at a time.
     * <p>
     * The root is the row closest to the data mean. The rows remaining under
     * each node are split around two pivots (the row farthest from the node's
     * center, and the row farthest from that pivot). Each half becomes a child
     * centered at the row closest to the half's mean.
     *
     * @return the root node, or {@code null} if the data matrix has no rows
     */
    private BallNode buildTree() 
    {
        this.progress = 0.0;
        
        final int nrows = this.data.nrows();
        if (nrows <= 0) {
            // Nothing to index: the tree is empty
            this.progress = 1.0;
            return null;
        }
        
        final List <Integer> indices = new ArrayList <> (nrows);
        for (int i = 0; i < nrows; ++i)
            indices.add(i);
        
        // closestIndex returns a position; remove(int) hands back the row ID
        final int rootID = indices.remove(closestIndex(this.data.mean(), indices));
        final BallNode root = new BallNode(rootID, null);
        
        // Nodes of the current layer and, at the same positions, the rows that
        // still have to be distributed among each node's descendants
        final List <BallNode> layer = new ArrayList <> ();
        final List <List <Integer>> layerIndices = new ArrayList <> ();
        layer.add(root);
        layerIndices.add(indices);
        
        final List <BallNode> nextLayer = new ArrayList <> ();
        final List <List <Integer>> nextLayerIndices = new ArrayList <> ();
        
        // Rough depth estimate, used only for progress reporting
        final int expectedNumberOfLayers = java.lang.Math.max(1,
                2 * (int) java.lang.Math.round(
                        java.lang.Math.log(nrows) / java.lang.Math.log(2)));
        int layerCount = 0;
        
        while (!layer.isEmpty()) 
        {
            for (int i = 0; i < layer.size(); ++i)
                splitNode(layer.get(i), layerIndices.get(i),
                        nextLayer, nextLayerIndices);
            
            layer.clear();
            layer.addAll(nextLayer);
            nextLayer.clear();
            
            layerIndices.clear();
            layerIndices.addAll(nextLayerIndices);
            nextLayerIndices.clear();
            
            // Set maximum possible progress to 0.99 to make sure that
            // progress hits value 1.0 only when the loop is finished
            this.progress = java.lang.Math.min(
                    (double) ++layerCount / expectedNumberOfLayers,
                    0.99);
        }
        
        this.progress = 1.0;
        
        return root;
    }
    
    /**
     * Sets the radius of {@code node} and distributes its rows between its
     * left and right children. Children that still have rows to distribute
     * are appended to {@code nextLayer}, and their rows to
     * {@code nextLayerIndices}, at the same position.
     *
     * @param node node to split
     * @param indices rows to distribute among the descendants of {@code node}
     */
    private void splitNode(BallNode node, List <Integer> indices,
            List <BallNode> nextLayer, List <List <Integer>> nextLayerIndices)
    {
        if (indices.isEmpty()) {
            // Only the root of a single-row matrix gets here: nothing to split
            node.setRadius(0.0);
            return;
        }
        
        final SparseVector center = this.data.get(node.getID());
        final int size = indices.size();
        
        // --------------------------------------------------------
        // Step 1. Find the vector that is farthest from the center
        // --------------------------------------------------------
        // Fall back to a row of THIS node (not row 0 of the matrix) when no
        // distance is positive.
        
        int index1 = indices.get(0);
        double maxDistance = 0.0;
        
        for (final int index : indices) 
        {
            final double d = this.distanceFunction.call(
                    center, this.data.get(index));
            
            if (d > maxDistance) {
                maxDistance = d;
                index1 = index;
            }
        }
        
        final SparseVector vector1 = this.data.get(index1);
        
        if (vector1 == null) return;
        
        // The farthest row of the subtree bounds the ball around the center
        node.setRadius(maxDistance);
        
        // -----------------------------------------------------
        // Step 2. Find the vector that is farthest from vector1
        // -----------------------------------------------------
        // The distances are kept per node for Step 3. If no row is farther
        // from vector1 than vector1 itself, vector2 is vector1.
        
        final double[] distancesToVector1 = new double[size];
        int index2 = index1;
        maxDistance = 0.0;
        
        for (int i = 0; i < size; ++i) 
        {
            final int index = indices.get(i);
            final double d = this.distanceFunction.call(
                    vector1, this.data.get(index));
            
            distancesToVector1[i] = d;
            if (d > maxDistance) {
                maxDistance = d;
                index2 = index;
            }
        }
        
        final SparseVector vector2 = this.data.get(index2);
        
        if (vector2 == null) return;
        
        // ---------------------------------------------------
        // Step 3. Assign vectors to either vector1 or vector2
        // ---------------------------------------------------
        
        final List <Integer> leftIndices = new ArrayList <> (size / 2 + 1);
        final List <Integer> rightIndices = new ArrayList <> (size / 2 + 1);
        
        if (index1 != index2)
        {
            for (int i = 0; i < size; ++i) 
            {
                final int index = indices.get(i);
                
                final double d = this.distanceFunction.call(
                        vector2, this.data.get(index));
                
                if (distancesToVector1[i] < d) // Closer to vector1 than to vector2
                    leftIndices.add(index);
                else // Closer to (or as close to) vector2 than vector1
                    rightIndices.add(index);
            }
        } else // Every row is at distance zero from vector1
        {
            // Peel off one row so that the split still makes progress
            leftIndices.add(indices.get(0));
            rightIndices.addAll(indices.subList(1, size));
        }
        
        // ------------------------------------------
        // Step 4. Create BallNodes for both halves
        // ------------------------------------------
        
        if (!leftIndices.isEmpty())
            node.left = createChild(node, leftIndices, nextLayer, nextLayerIndices);
        
        if (!rightIndices.isEmpty())
            node.right = createChild(node, rightIndices, nextLayer, nextLayerIndices);
    }
    
    /**
     * Creates a child of {@code parent} centered at the row of {@code indices}
     * that is closest to their scaled mean. That row is removed from
     * {@code indices}. If any rows remain, the child and its rows are queued
     * for the next layer.
     *
     * @param parent parent of the new node
     * @param indices non-empty rows assigned to the child; modified in place
     * @return the new child node
     */
    private BallNode createChild(BallNode parent, List <Integer> indices,
            List <BallNode> nextLayer, List <List <Integer>> nextLayerIndices)
    {
        final SparseVector mean = new SparseVector();
        for (final int index : indices)
            mean.add(this.distanceFunction.scale(this.data.get(index)), false);
        mean.multiply(1f / indices.size(), false);
        
        final int childID = indices.remove(closestIndex(mean, indices));
        final BallNode child = new BallNode(childID, parent);
        
        if (!indices.isEmpty()) {
            nextLayer.add(child);
            nextLayerIndices.add(indices);
        }
        return child;
    }
    
    /**
     * Collects the IDs of all nodes whose center is closer than
     * {@code threshold} to {@code vector}.
     * <p>
     * A child subtree is skipped when the distance from {@code vector} to the
     * child's center is at least {@code threshold + radius}. This is a
     * triangle-inequality bound, so the pruning is exact only for metric
     * distances.
     *
     * @param vector query vector
     * @param threshold exclusive distance threshold
     * @return matching row IDs ordered by decreasing distance (equidistant
     *         matches are all kept); empty if the tree has not been built.
     *         NOTE(review): the farthest match comes first; confirm this
     *         ordering is intended.
     */
    private List <Integer> propagateTree(SparseVector vector, double threshold) 
    {
        final List <Integer> listResult = new ArrayList <> ();
        if (this.root == null) return listResult;
        
        // Group IDs by distance so that equidistant vectors are not overwritten
        final Map <Double, List <Integer>> idsByDistance = new TreeMap <> ();
        
        final List <BallNode> layer = new ArrayList <> ();
        layer.add(this.root);
        
        final List <BallNode> nextLayer = new ArrayList <> ();
        
        while (!layer.isEmpty())
        {
            for (final BallNode node : layer) 
            {
                final double d = this.distanceFunction.call(
                        vector, this.data.get(node.getID()));
                
                if (d < threshold) {
                    List <Integer> ids = idsByDistance.get(d);
                    if (ids == null) {
                        ids = new ArrayList <> ();
                        idsByDistance.put(d, ids);
                    }
                    ids.add(node.getID());
                }
                
                queueIfReachable(vector, threshold, node.left, nextLayer);
                queueIfReachable(vector, threshold, node.right, nextLayer);
            }
            layer.clear();
            layer.addAll(nextLayer);
            nextLayer.clear();
        }
        
        // Ascending by distance, then reversed: farthest match first
        for (final List <Integer> ids : idsByDistance.values())
            listResult.addAll(ids);
        Collections.reverse(listResult);
        
        return listResult;
    }
    
    /**
     * Appends {@code child} to {@code nextLayer} unless it is {@code null} or
     * its ball lies entirely beyond {@code threshold} of {@code vector}.
     */
    private void queueIfReachable(SparseVector vector, double threshold,
            BallNode child, List <BallNode> nextLayer)
    {
        if (child != null && this.distanceFunction.call(
                vector, this.data.get(child.getID()))
                < threshold + child.getRadius())
        {
            nextLayer.add(child);
        }
    }
    
    // ------------------------------------------------------------------------
    // ----- Properties -------------------------------------------------------
    // ------------------------------------------------------------------------
    
    /** @return the root node, or {@code null} if the tree is not built */
    public BallNode root() {return this.root;}
    
    /** @return the data matrix whose rows the tree indexes */
    public SparseMatrix data() {return this.data;}
    
    /** @return the distance used to build and query the tree */
    public Distance distanceFunction() {return this.distanceFunction;}
    
    /**
     * Returns the data vectors closer than {@code threshold} to
     * {@code vector}, ordered by decreasing distance.
     *
     * @param vector query vector
     * @param threshold exclusive distance threshold
     * @return matching vectors; empty if none match or the tree is not built
     */
    public List <SparseVector> getCloseVectors(SparseVector vector, 
            double threshold)
    {
        List <SparseVector> result = new ArrayList <> ();
        
        for (int index : propagateTree(vector, threshold))
            result.add(this.data.get(index));
        
        return result;
    }
    
    /**
     * Appends to {@code result} the data vectors closer than
     * {@code threshold} to {@code vector}, ordered by decreasing distance.
     *
     * @param vector query vector
     * @param threshold exclusive distance threshold
     * @param result list that receives the matching vectors
     */
    public void getCloseVectors(SparseVector vector, 
            float threshold, List <SparseVector> result)
    {
        for (int index : propagateTree(vector, threshold))
            result.add(this.data.get(index));
    }
    
    /** @return fraction of the last {@link #build()} completed, in [0, 1] */
    public double getProgress() {return this.progress;}
    
    /**
     * Renders a subtree in pre-order (node, left subtree, right subtree),
     * one {@code node.toString()} per line.
     *
     * @param node subtree root; {@code null} yields an empty string
     * @return the rendered subtree
     */
    static public String printTree(final BallNode node) 
    {
        // The traversal is iterative because degenerate splits (all rows
        // identical) yield chains as deep as the number of rows.
        final StringBuilder result = new StringBuilder();
        final Deque <BallNode> stack = new ArrayDeque <> ();
        if (node != null) stack.push(node);
        
        while (!stack.isEmpty())
        {
            final BallNode current = stack.pop();
            result.append(current.toString()).append('\n');
            
            // Push right first so that the left subtree is printed first
            if (current.right != null) stack.push(current.right);
            if (current.left != null) stack.push(current.left);
        }
        
        return result.toString();
    }
}
