package ch.sahits.game.openpatrician.engine.sea;

import ch.sahits.game.openpatrician.utilities.annotation.ClassCategory;
import ch.sahits.game.openpatrician.utilities.annotation.DependentInitialisation;
import ch.sahits.game.openpatrician.utilities.annotation.EClassCategory;
import ch.sahits.game.openpatrician.utilities.annotation.MapType;
import ch.sahits.game.openpatrician.engine.sea.model.GraphAStar;
import ch.sahits.game.openpatrician.engine.sea.model.NodeData;
import ch.sahits.game.openpatrician.model.city.ICity;
import ch.sahits.game.openpatrician.model.map.IMap;
import ch.sahits.game.openpatrician.model.initialisation.StartNewGameBean;
import com.carrotsearch.hppc.ObjectDoubleMap;
import com.carrotsearch.hppc.cursors.ObjectDoubleCursor;
import com.google.common.eventbus.AsyncEventBus;
import javafx.geometry.Point2D;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;

/**
 * Implementation of the A* pathfinding algorithm on the graph supplied by {@link AStarGraphProvider}.
 * Implementation is based on this code review:
 * http://codereview.stackexchange.com/questions/38376/a-search-algorithm
 * <p>
 * Found paths are cached per (destination, source) pair. The cache only holds unmodifiable
 * snapshots; every caller receives its own mutable copy, so modifying a returned path cannot
 * affect later lookups.
 * @author Andi Hotz, (c) Sahits GmbH, 2016
 *         Created on Jan 01, 2016
 */
@ClassCategory(EClassCategory.SINGLETON_BEAN)
@Component
@Lazy
@DependentInitialisation(StartNewGameBean.class)
public class AStar {
    private final Logger logger = LogManager.getLogger(getClass());
    private GraphAStar<Point2D> graph;
    @Autowired
    private AStarGraphProvider graphProvider;
    @Autowired
    private IMap map;
    @Autowired
    @Qualifier("serverThreadPool")
    private Executor serverThreadPool;
    @Autowired
    @Qualifier("paralleizationExecutor")
    private ExecutorService paralleizationExecutor;
    @Autowired
    @Qualifier("serverClientEventBus")
    private AsyncEventBus clientServerEventBus;
    /**
     * Caching the paths with destination (outer key) and source (inner key).
     * Inner maps are created atomically via {@code computeIfAbsent} and are concurrent maps
     * themselves, so concurrent cache writes do not lose entries. Stored lists are unmodifiable.
     */
    @MapType(key = Point2D.class, value = Map.class)
    private Map<Point2D, Map<Point2D, List<Point2D>>> pathCache = new ConcurrentHashMap<>();

    /**
     * Implements the A-star algorithm and returns the path from source to destination.
     * A previously found path for the same pair is served from the cache.
     *
     * @param source        the source nodeid
     * @param destination   the destination nodeid
     * @return              a new, mutable list of nodes leading from source to destination
     *                      (both included), or {@code null} if there is no path
     */
    public List<Point2D> findPath(final Point2D source, final Point2D destination) {
        // Spin until the provider has a graph available.
        while (graph == null) {
            graph = graphProvider.getGraph();
            Thread.yield();
        }

        List<Point2D> cached = findCached(destination, source);
        if (cached != null) {
            return cached;
        }

        ensureSetup(source, destination);

        // Initial capacity 11 is the PriorityQueue default, see
        // http://stackoverflow.com/questions/20344041/why-does-priority-queue-has-default-initial-capacity-of-11
        final Queue<NodeData<Point2D>> openQueue = new PriorityQueue<>(11, new NodeComparator());
        NodeData<Point2D> sourceNodeData = graph.getNodeData(source);
        sourceNodeData.setG(source, 0);
        sourceNodeData.calcF(destination, source);
        openQueue.add(sourceNodeData);

        // Maps every reached node id to the node id it was reached from.
        final Map<Point2D, Point2D> predecessors = new HashMap<>();
        final Set<NodeData<Point2D>> closedSet = new HashSet<>();
        logger.trace("Find path from {} -> {}", source, destination);

        while (!openQueue.isEmpty()) {
            final NodeData<Point2D> nodeData = openQueue.poll();
            logger.trace("Inspect node: {}", nodeData.getNodeId());

            if (nodeData.getNodeId().equals(destination)) { // termination
                List<Point2D> pathList = path(predecessors, destination);
                cachePath(source, destination, pathList);
                if (logger.isTraceEnabled()) {
                    logger.trace(describe("Found path: ", pathList));
                }
                // pathList is freshly built and not shared with the cache, so the caller may own it.
                return pathList;
            }

            closedSet.add(nodeData);

            ObjectDoubleMap<NodeData<Point2D>> edges = graph.edgesFrom(nodeData.getNodeId());
            for (ObjectDoubleCursor<NodeData<Point2D>> neighborEntry : edges) {
                NodeData<Point2D> neighbor = neighborEntry.key;
                if (closedSet.contains(neighbor)) {
                    continue;
                }

                double distanceBetweenTwoNodes = neighborEntry.value;
                // cost of reaching the neighbor from the source via the current node
                double tentativeG = distanceBetweenTwoNodes + nodeData.getG(source);

                if (tentativeG <= neighbor.getG(source)) {
                    // A PriorityQueue never re-sorts an element whose priority changes while it is
                    // queued. Take the neighbor out before updating G/F and insert it again afterwards,
                    // otherwise the heap ordering is corrupted.
                    boolean alreadyQueued = openQueue.remove(neighbor);
                    neighbor.setG(source, tentativeG);
                    neighbor.calcF(destination, source);

                    predecessors.put(neighbor.getNodeId(), nodeData.getNodeId());
                    openQueue.add(neighbor);
                    if (!alreadyQueued) {
                        logger.trace("Add neighbor {}", neighbor.getNodeId());
                    }
                }
            }
        }
        if (logger.isTraceEnabled()) {
            List<Point2D> visited = new ArrayList<>(closedSet.size());
            for (NodeData<Point2D> nodeData : closedSet) {
                visited.add(nodeData.getNodeId());
            }
            logger.trace(describe("Partial path: ", visited));
        }
        return null; // there is no path
    }

    /**
     * Looks up a previously found path.
     *
     * @param destination the destination nodeid (outer cache key)
     * @param source      the source nodeid (inner cache key)
     * @return a mutable copy of the cached path, or {@code null} if nothing is cached
     */
    private List<Point2D> findCached(Point2D destination, Point2D source) {
        Map<Point2D, List<Point2D>> sourceMap = pathCache.get(destination);
        if (sourceMap == null) {
            return null;
        }
        List<Point2D> cached = sourceMap.get(source);
        // Copy so that callers cannot modify the cached snapshot.
        return cached == null ? null : new ArrayList<>(cached);
    }

    /**
     * Stores an unmodifiable snapshot of {@code pathList} under (destination, source).
     * The inner map is created atomically, so concurrent writers do not overwrite each other's maps.
     */
    private void cachePath(Point2D source, Point2D destination, List<Point2D> pathList) {
        pathCache.computeIfAbsent(destination, key -> new ConcurrentHashMap<>())
                .put(source, Collections.unmodifiableList(new ArrayList<>(pathList)));
    }

    /**
     * Makes sure both end points are nodes of the graph, adding missing ones through the provider.
     * A missing source point is added with {@code addDestinationPoint}, the same way as a destination.
     */
    private void ensureSetup(Point2D source, Point2D destination) {
        if (!graph.containsNode(source)) {
            graphProvider.addDestinationPoint(source, isCity(source));
        }
        if (!graph.containsNode(destination)) {
            graphProvider.addDestinationPoint(destination, isCity(destination));
        }
        // NOTE(review): the returned graph is discarded (unchanged behaviour); confirm whether
        // the `graph` field should be refreshed with it.
        graphProvider.getGraph();
    }

    /**
     * @param point the coordinate to check
     * @return true if {@code point} equals the coordinates of one of the map's cities
     */
    private boolean isCity(Point2D point) {
        for (ICity city : map.getCities()) {
            if (point.equals(city.getCoordinates())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Reconstructs a path by following the predecessor map back from the destination.
     *
     * @param predecessors node id to the node id it was reached from
     * @param destination  the end point of the path
     * @return the nodes ordered from the first node of the chain to {@code destination}
     */
    private List<Point2D> path(Map<Point2D, Point2D> predecessors, Point2D destination) {
        assert predecessors != null;
        assert destination != null;

        final List<Point2D> pathList = new ArrayList<>();
        Point2D current = destination;
        pathList.add(current);
        while (predecessors.containsKey(current)) {
            current = predecessors.get(current);
            pathList.add(current);
        }
        Collections.reverse(pathList);
        return pathList;
    }

    /** Builds a trace message: the prefix followed by every point, each followed by a space. */
    private static String describe(String prefix, Iterable<Point2D> points) {
        StringBuilder sb = new StringBuilder(prefix);
        for (Point2D point : points) {
            sb.append(point).append(' ');
        }
        return sb.toString();
    }

    /** Orders nodes by ascending F value, so the open queue polls the lowest F first. */
    private static class NodeComparator implements Comparator<NodeData<Point2D>> {
        @Override
        public int compare(NodeData<Point2D> nodeFirst, NodeData<Point2D> nodeSecond) {
            return Double.compare(nodeFirst.getF(), nodeSecond.getF());
        }
    }
}
