package org.aika.corpus;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


import java.util.*;

/**
 * A node of the binary search tree that {@link #computeSelectedOption(Document)} builds in order to
 * choose a set of options for a document.
 * <p>
 * Each node picks one {@link Option} as its {@link #refinement}. Its {@link #nonConflictingBranch}
 * continues the search with that refinement selected (the child's {@code selectedParent} is this
 * node), while its {@link #conflictingBranch} continues the search with the refinement excluded
 * (the child's {@code excludedParent} is this node). The best node found so far for the current
 * cluster is kept in {@code doc.selectedExpandNode}.
 */
public class ExpandNode {

    /**
     * Limit on the number of {@link #search} calls per cluster. The counter is checked before it is
     * incremented, so up to {@code MAX_SEARCH_STEPS + 1} calls succeed and the next one throws
     * {@link ExpandNodeException}.
     */
    public static int MAX_SEARCH_STEPS = 10000;

    /** Nearest ancestor on an excluding path; refinements along its excluded chain may not be chosen here. */
    ExpandNode excludedParent;
    /** Nearest ancestor whose refinement is selected on the path to this node. */
    ExpandNode selectedParent;
    /** Direct parent in the search tree ({@code null} for the root). */
    ExpandNode parent;
    /** Child that continues the search with {@link #refinement} selected. */
    ExpandNode nonConflictingBranch;
    /** Child that continues the search with {@link #refinement} excluded. */
    ExpandNode conflictingBranch;

    /** Stamp used when marking the options covered by {@link #refinement}; compared in {@link #isCovered}. */
    long visited;
    /** Option chosen by this node, or {@code null} if no candidate was found. */
    Option refinement;
    /** Sum of the weights of the positive-weight, non-conflicting options reached while generating the candidate. */
    double upperBound;
    /** Weight newly covered by this node's refinement plus the accumulated weight of its selected parent. */
    double accumulatedWeight;
    /** Options that child nodes found to conflict while generating their candidates. */
    HashSet<Option> conflicts = new HashSet<>();

    long debugId = debugIdCounter++;
    static long debugIdCounter = 0;


    /**
     * @param parent         direct parent in the search tree
     * @param selectedParent nearest ancestor whose refinement is selected
     * @param excludedParent nearest ancestor whose refinement is excluded
     */
    public ExpandNode(ExpandNode parent, ExpandNode selectedParent, ExpandNode excludedParent) {
        this.parent = parent;
        this.selectedParent = selectedParent;
        this.excludedParent = excludedParent;
    }


    /**
     * Searches for the highest-weighted selection of options in {@code doc} and stores it in
     * {@code doc.selectedOption}.
     * <p>
     * The options below {@code doc.bottom} are split into clusters, and each cluster is searched
     * separately. The refinements of the best node of every cluster are combined via
     * {@code Option.add}. If a cluster's search exceeds {@link #MAX_SEARCH_STEPS}, a message is
     * printed to stderr and that cluster contributes nothing to the result.
     *
     * @param doc the document whose options are searched
     */
    public static void computeSelectedOption(Document doc) {
        ArrayList<Option> results = new ArrayList<>();
        results.add(doc.bottom);

        int numClusters = computeClusters(doc);

        for(int clusterId = 0; clusterId < numClusters; clusterId++) {
            try {
                // NOTE(review): doc.selectedMark is not reset per cluster, so the mark of the
                // previous cluster's best node stays in effect for search(); confirm intended.
                doc.selectedExpandNode = null;
                ExpandNode root = new ExpandNode(null, null, null);
                root.refinement = doc.bottom;
                root.accumulatedWeight = 0.0;

                root.nonConflictingBranch = new ExpandNode(root, root, null);
                int[] searchSteps = new int[1];
                root.nonConflictingBranch.search(doc, true, clusterId, searchSteps);

                if (doc.selectedExpandNode != null) {
                    doc.selectedExpandNode.collectResults(results);
                }
            } catch(ExpandNodeException e) {
                // Best effort: give up on this cluster only and keep the results of the others.
                System.err.println("Too many search steps!");
            }
        }

        doc.selectedOption = Option.add(doc, true, results.toArray(new Option[results.size()]));
    }


    /**
     * Assigns a cluster id to every option reachable from the children of {@code doc.bottom}.
     * Options are linked through non-inverted child edges and through parent edges.
     *
     * @return the number of clusters found
     */
    private static int computeClusters(Document doc) {
        long v = Option.visitedCounter++;
        // NOTE(review): only the direct children of bottom are reset; deeper options keep the id
        // of an earlier run until reassigned, which the assertion in markClusterDown may trip on.
        for(Option n: doc.bottom.children) {
            n.clusterId = -1;
        }

        int clusterIdCounter = 0;
        for(Option n: doc.bottom.children) {
            if(n.clusterId < 0) {
                markClusterUp(n, clusterIdCounter++, v);
            }
        }
        return clusterIdCounter;
    }


    /** Visits the non-inverted children of {@code n}, then labels {@code n} via {@link #markClusterDown}. */
    private static void markClusterUp(Option n, int clusterId, long v) {
        if(n.clusterVisitedUp == v) return;
        n.clusterVisitedUp = v;

        for(Option cn: n.children) {
            if(!cn.inv) {
                markClusterUp(cn, clusterId, v);
            }
        }
        markClusterDown(n, clusterId, v);
    }


    /** Labels {@code n} with {@code clusterId} and spreads the label to its parents (never to bottom). */
    private static void markClusterDown(Option n, int clusterId, long v) {
        if(n.isBottom() || n.clusterVisitedDown == v) return;
        n.clusterVisitedDown = v;

        assert n.clusterId < 0 || n.clusterId == clusterId;

        n.clusterId = clusterId;

        for(Option pn: n.parents) {
            markClusterDown(pn, clusterId, v);
        }
        markClusterUp(n, clusterId, v);
    }


    /** Adds the refinement of this node and of every selected ancestor to {@code results}. */
    private void collectResults(Collection<Option> results) {
        results.add(refinement);
        if(selectedParent != null) selectedParent.collectResults(results);
    }


    /** Calls {@code Option.markSelected(v)} on the refinement of this node and of every selected ancestor. */
    // TODO: mark also selected nodes that consist of several expand nodes.
    private void markSelected(long v) {
        refinement.markSelected(v);
        if(selectedParent != null) selectedParent.markSelected(v);
    }


    /**
     * Expands this node by choosing a refinement and recursively exploring the branch that selects
     * it. Where needed, it also explores the branch that excludes it.
     *
     * @param doc         the document being searched
     * @param branch      {@code true} for a selecting (non-conflicting) branch, {@code false} for an
     *                    excluding (conflicting) branch
     * @param clusterId   the cluster whose options may be chosen
     * @param searchSteps single-element counter of search calls made for this cluster
     * @throws ExpandNodeException if the counter exceeds {@link #MAX_SEARCH_STEPS}
     */
    private void search(Document doc, boolean branch, int clusterId, int[] searchSteps) {
        if(searchSteps[0] > MAX_SEARCH_STEPS) {
            throw new ExpandNodeException();
        }
        searchSteps[0]++;

        generateCandidate(doc, branch, clusterId);

        if(refinement == null) return;

        // Skip the excluding branch when the refinement carries the current selection mark,
        // i.e. (presumably) when it is part of the best selection found so far.
        boolean f = doc.selectedMark == -1 || refinement.markedSelected != doc.selectedMark;

        nonConflictingBranch = new ExpandNode(this, this, excludedParent);
        nonConflictingBranch.search(doc, true, clusterId, searchSteps);

        if(f) {
            conflictingBranch = new ExpandNode(this, selectedParent, this);
            conflictingBranch.search(doc, false, clusterId, searchSteps);
        }
    }


    /**
     * Chooses this node's {@link #refinement} and computes {@link #upperBound} and
     * {@link #accumulatedWeight}.
     * <p>
     * Options of the cluster are visited starting at {@code doc.bottom}, ordered by
     * {@code Option.SMALLEST_FIRST_COMPARATOR}. Excluded options are skipped. For each positive-weight
     * option, the weight it would newly cover is computed. An option whose coverage hits a conflict
     * is recorded in {@code parent.conflicts}. Otherwise the option becomes the candidate if it
     * improves the accumulated weight and is not already covered. In an excluding branch, the option
     * must also be a recorded conflict.
     * <p>
     * The candidate is dropped if the upper bound cannot beat the best node found so far. If this
     * node beats the best node, it becomes {@code doc.selectedExpandNode}.
     */
    private void generateCandidate(Document doc, boolean branch, int clusterId) {
        upperBound = 0.0;
        accumulatedWeight = 0.0;

        TreeSet<Option> queue = new TreeSet<>(Option.SMALLEST_FIRST_COMPARATOR);

        long ubVisited = Option.visitedCounter++;
        doc.bottom.containedInUpperBound = ubVisited;
        queue.add(doc.bottom);

        while(!queue.isEmpty()) {
            Option n = queue.pollFirst();

            // Option.containedInUpperBound(long) is defined in Option; presumably it checks that the
            // option's parents were admitted in this pass (stamp ubVisited) - confirm.
            if(n.containedInUpperBound(ubVisited) && !coveredConflicting(n)) {
                n.containedInUpperBound = ubVisited;
                if(n.weight > 0.0) {
                    long v = Option.visitedCounter++;
                    Double naw = markCovered(v, n);
                    if(naw == null) {
                        parent.conflicts.add(n);
                        continue;
                    }

                    double aw = (selectedParent != null ? selectedParent.accumulatedWeight : 0.0) + naw;

                    upperBound += n.weight;

                    if (accumulatedWeight < aw && (branch || isConflicting(n)) && !isCovered(n.markedCovered)) {
                        accumulatedWeight = aw;
                        refinement = n;
                    }
                }
                for(Option c: n.children) {
                    if(!c.inv && c.clusterId == clusterId && c.isConflict < 0) {
                        queue.add(c);
                    }
                }
            }
        }

        // Bound: this subtree cannot improve on the best node found so far.
        if(doc.selectedExpandNode != null && upperBound < doc.selectedExpandNode.accumulatedWeight) {
            refinement = null;
        }

        if(refinement == null) return;

        // Re-mark the coverage of the chosen refinement under this node's own stamp.
        visited = Option.visitedCounter++;
        Double refinementWeight = markCovered(visited, refinement);
        if(refinementWeight == null) {
            // Coverage hit a conflict on re-marking; drop the candidate instead of failing when
            // unboxing null.
            refinement = null;
            return;
        }
        accumulatedWeight = (selectedParent != null ? selectedParent.accumulatedWeight : 0.0) + refinementWeight;

        if(doc.selectedExpandNode == null || accumulatedWeight > doc.selectedExpandNode.accumulatedWeight) {
            doc.selectedExpandNode = this;
            doc.selectedMark = Option.visitedCounter++;
            markSelected(doc.selectedMark);
        }
    }


    /** Returns whether {@code n} is the refinement of any node in the excluded-parent chain. */
    public boolean coveredConflicting(Option n) {
        return excludedParent != null && (excludedParent.refinement == n || excludedParent.coveredConflicting(n));
    }


    /** Returns whether stamp {@code g} is the {@link #visited} stamp of any selected ancestor. */
    private boolean isCovered(long g) {
        return selectedParent != null && (g == selectedParent.visited || selectedParent.isCovered(g));
    }


    /**
     * Stamps {@code n} with {@code v} and recursively covers its parents and those non-inverted
     * children whose parents are all covered.
     *
     * @return the summed weight of the newly covered options. Returns {@code 0.0} if {@code n} was
     *         already visited in this pass or is covered by a selected ancestor. Returns {@code null}
     *         if a child that would be covered has {@code isConflict >= 0}.
     */
    private Double markCovered(long v, Option n) {
        if(n.visitedMarkCovered == v) return 0.0;
        n.visitedMarkCovered = v;

        if(isCovered(n.markedCovered)) return 0.0;

        n.markedCovered = v;

        if(n.isBottom()) {
            return n.weight;
        }

        double result = n.weight;

        for(Option p: n.parents) {
            Double r = markCovered(v, p);
            if(r == null) return null;
            result += r;
        }

        for(Option c: n.children) {
            if(c.inv || c.visitedMarkCovered == v) continue;

            if(!containedInSelectedBranch(v, c)) continue;

            if(c.isConflict >= 0) return null;

            c.markedCovered = v;

            Double r = markCovered(v, c);
            if(r == null) return null;
            result += r;
        }

        return result;
    }


    /** Returns whether every parent of {@code n} is stamped with {@code v} or covered by a selected ancestor. */
    public boolean containedInSelectedBranch(long v, Option n) {
        for(Option p: n.parents) {
            if(p.markedCovered != v && !isCovered(p.markedCovered)) return false;
        }
        return true;
    }


    /**
     * Returns whether {@code n} was recorded as a conflict by the parent (or its selected ancestors).
     * The check moves further up the tree as long as this node is its parent's conflicting branch.
     */
    private boolean isConflicting(Option n) {
        return parent != null && (parent.checkForConflict(n) || (this == parent.conflictingBranch && parent.isConflicting(n)));
    }


    /** Returns whether {@code n} is in the conflicts of this node or of any selected ancestor. */
    public boolean checkForConflict(Option n) {
        return conflicts.contains(n) || (selectedParent != null && selectedParent.checkForConflict(n));
    }


    /** Thrown by {@link #search} when a cluster's search exceeds {@link #MAX_SEARCH_STEPS}. */
    public static class ExpandNodeException extends RuntimeException {
        private static final long serialVersionUID = 1L;
    }
}
