package org.aika.corpus;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.aika.network.Iteration;
import org.aika.network.neuron.Activation;
import org.aika.utils.SetUtils;
import org.aika.network.neuron.Activation.Key;

import java.util.*;


/**
 * The <code>Option</code> class represents a node within the options lattice. Each option has a key that uniquely
 * identifies it. Such a key consists of a set of option ids that are emitted by a <code>NegativeInputNode</code>.
 * There are different types of option nodes such as <code>ACTIVATION</code>, <code>CONFLICT</code>, <code>TOP</code>
 * and <code>BOTTOM</code>. Option nodes of type <code>ACTIVATION</code> contain a set of all activations associated
 * with this option node. Each option node maintains a weight, which {@link #computeAccumulatedWeight(long)}
 * accumulates over the option and every option reachable through its parents.
 *
 * <p>Every non-inverted option is paired with an inverted option (<code>inv == true</code>) via
 * {@link #negation}; links added with {@link #addLink(Option, Option)} are mirrored on the negations. The bottom
 * option is the parent of every primitive option (see {@link #addPrimitive(Document, int)}), and containment is
 * decided by walking the <code>parents</code> lists (see {@link #contains(Option)}).
 *
 * <p>Graph traversals use the <code>visited*</code> fields as markers: each traversal draws a fresh value from
 * the static {@link #visitedCounter}, so a node is processed at most once per traversal. This shared static
 * counter is not synchronized.
 *
 * <p>Note: {@link #compareTo(Option)} orders by length and then id, and <code>equals</code> is not overridden.
 *
 * @author Lukas Molzberger
 */
public class Option implements Comparable<Option> {

    public static final Option MIN = new Option(null, false, 0, 0, 0, 0, 0);
    public static final Option MAX = new Option(null, true, Integer.MAX_VALUE, Integer.MAX_VALUE, 0, 0, Integer.MAX_VALUE);

    /** Orders options by descending length; ties are broken by ascending id. */
    public static final Comparator<Option> SIZE_COMPARATOR = new Comparator<Option>() {
        @Override
        public int compare(Option n1, Option n2) {
            int r = Integer.compare(n2.length, n1.length);
            if(r != 0) return r;
            return Integer.compare(n1.id, n2.id);
        }
    };

    /** Orders options by ascending length; ties are broken by ascending id. */
    public static final Comparator<Option> SMALLEST_FIRST_COMPARATOR = new Comparator<Option>() {
        @Override
        public int compare(Option n1, Option n2) {
            int r = Integer.compare(n1.length, n2.length);
            if(r != 0) return r;
            return Integer.compare(n1.id, n2.id);
        }
    };

    /** True for the inverted (negation) half of an option pair. */
    public final boolean inv;
    /** Primitive id for options created by addPrimitive; add() passes -1 for composite options. */
    public final int primId;
    public final int id;
    /**
     * Size used for ordering and containment pruning: 0 marks the bottom option, Integer.MAX_VALUE marks the
     * top option, and a primitive option has length 1.
     */
    public final int length;

    public final int minPos;
    public final int maxPos;

    /** The paired option with the opposite <code>inv</code> flag; set to null when this option is removed. */
    public Option negation;

    // Traversal markers: each holds the visitedCounter value of the last traversal that reached this node.
    private long visitedComputeRelations = -1;
    private long visitedOptimize = -1;
    private long visitedLinkRelations = -1;
    private long visitedLinkPrimitive = -1;
    private long visitedContains = -1;
    private long visitedCollect = -1;
    private long visitedCount = -1;
    public long visitedAccumulatedWeight = -1;
    private long visitedComputeWeights = -1;
    private long visitedCollectConflicts = -1;
    private long visitedExpandActivations = -1;
    private long visitedRemoveActivations = -1;
    public long visitedMarkCovered = -1;
    public long markedCovered = -1;
    public long containedInUpperBound = -1;
    public long markedSelected = -1;

    public int clusterId = -1;
    public long clusterVisitedUp = -1;
    public long clusterVisitedDown = -1;

    /** Source of fresh traversal marker values; shared by all options and documents. */
    public static long visitedCounter = 0;

    public final Document doc;

    public boolean isRemoved;
    public int removedId;
    public static int removedIdCounter = 1;

    public ArrayList<Option> parents;
    public ArrayList<Option> children;

    /** Memoized results of top-level {@link #contains(Option)} calls, keyed by the tested option. */
    public Map<Option, Boolean> cache = new HashMap<>();

    /** A value &gt;= 0 marks this option as a conflict option; -1 otherwise. */
    public int isConflict = -1;
    public Conflicts conflicts = new Conflicts();

    public NavigableMap<Key, Activation> activations;
    public NavigableSet<Activation> activationsComplete;

    public double weight;
    public double accumulatedWeight;

    /** Number of outstanding references; the option is removed when this drops to zero (see releaseRef). */
    public int refCount = 0;


    /** Relation between two options, evaluated via {@link Option#contains(Option)}. */
    public enum Relation {
        EQUALS,
        CONTAINS,
        CONTAINED_IN;

        /**
         * Evaluates this relation.
         *
         * @return for EQUALS, reference identity; for CONTAINS, {@code a.contains(b)};
         *         for CONTAINED_IN, {@code b.contains(a)}
         */
        public boolean compare(Option a, Option b) {
            switch(this) {
                case EQUALS:
                    return a == b;
                case CONTAINS:
                    return a.contains(b);
                case CONTAINED_IN:
                    return b.contains(a);
                default:
                    return false;
            }
        }
    }


    public Option(Document doc, boolean inv, int primId, int id, int minPos, int maxPos, int length) {
        this.doc = doc;
        this.inv = inv;
        this.primId = primId;
        this.id = id;
        this.minPos = minPos;
        this.maxPos = maxPos;
        this.length = length;
        parents = new ArrayList<>();
        children = new ArrayList<>();
    }


    /** Increments the reference count. Top and bottom options are not reference counted. */
    public void countRef() {
        if(isBottom() || isTop()) return;
        refCount++;
    }


    /**
     * Decrements the reference count and removes this option from the lattice once no references remain.
     * Top and bottom options are not reference counted.
     */
    public void releaseRef() {
        if(isBottom() || isTop()) return;
        assert refCount > 0;
        refCount--;
        if(refCount == 0) {
            remove();
        }
    }


    /**
     * Collects options related to this one.
     *
     * @param or EQUALS returns only this option; CONTAINS returns this option plus everything reachable through
     *           <code>children</code>; CONTAINED_IN returns this option plus everything reachable through
     *           <code>parents</code>
     * @return a new list of the selected options
     */
    public List<Option> select(Relation or) {
        List<Option> results = new ArrayList<>();
        switch(or) {
            case EQUALS:
                results.add(this);
                break;
            case CONTAINS:
                collect(results, false, true, visitedCounter++);
                break;
            case CONTAINED_IN:
                collect(results, true, true, visitedCounter++);
                break;
            default:
        }
        return results;
    }


    /**
     * Depth-first collection of this option and everything reachable from it.
     *
     * @param results      output list
     * @param dir          true walks <code>parents</code>, false walks <code>children</code>
     * @param includeFirst whether this option itself is added (reachable options are always added)
     * @param v            traversal marker; nodes already visited with {@code v} are skipped
     */
    public void collect(List<Option> results, boolean dir, boolean includeFirst, long v) {
        if(visitedCollect == v) return;
        visitedCollect = v;

        if(includeFirst) {
            results.add(this);
        }

        for(Option r: dir ? parents : children) {
            r.collect(results, dir, true, v);
        }
    }


    /** Creates a new primitive option at this option's <code>minPos</code> and copies its conflicts. */
    public Option clonePrimitive(Iteration t) {
        assert primId != -1;
        Option no = addPrimitive(doc, minPos);

        Conflicts.copy(t, this, no);
        return no;
    }


    /** Removes all conflicts attached to this primitive option. */
    public void removePrimitive(Iteration t) {
        assert primId >= 0;

        conflicts.removeAll();
    }


    /**
     * Re-propagates every complete activation of this option and of all options reachable through
     * <code>parents</code>, passing {@code conflict} along.
     */
    void expandActivationsRecursiveStep(Iteration t, Option conflict, long v) {
        if(v == visitedExpandActivations) return;
        visitedExpandActivations = v;

        // NOTE(review): if propagateAddedActivation can add to this option's activationsComplete, this
        // iteration over the live set would fail fast — confirm against Neuron/Node.
        for (Activation act : getActivationsComplete()) {
            act.key.n.propagateAddedActivation(t, act, act.key.pos, conflict);
        }

        for(Option p: parents) {
            p.expandActivationsRecursiveStep(t, conflict, v);
        }
    }


    /**
     * Removes, with propagation, every complete activation whose option contains {@code conflict}. Recurses into
     * all non-removed options reachable through <code>children</code>.
     */
    void removeActivationsRecursiveStep(Iteration t, Option conflict, long v) {
        if(v == visitedRemoveActivations) return;
        visitedRemoveActivations = v;

        // NOTE(review): if removeActivationAndPropagate removes act from this option's activationsComplete,
        // iterating the live set would fail fast — confirm against Neuron/Node.
        for (Activation act : getActivationsComplete()) {
            if(act.key.o.contains(conflict)) {
                act.key.n.removeActivationAndPropagate(t, false, act, act.key.pos);
            }
        }

        if(children != null) {
            for (Option c : children) {
                if (!c.isRemoved) {
                    c.removeActivationsRecursiveStep(t, conflict, v);
                }
            }
        }
    }


    /**
     * Returns the complete activations of this option, or an immutable empty set if none have been recorded.
     */
    public Collection<Activation> getActivationsComplete() {
        return activationsComplete != null ? activationsComplete : Collections.<Activation>emptyNavigableSet();
    }


    /**
     * Returns the option representing the combination of the given inputs, creating and linking it into the
     * lattice if it does not yet exist. The returned option's reference count is incremented (no-op for
     * top/bottom).
     *
     * @param doc             document owning all inputs
     * @param nonConflicting  if true, returns null when the result, or a computed parent of a newly created
     *                        option, is marked as a conflict
     * @param input           input options; null entries are ignored, and inputs contained in another input are
     *                        dropped
     * @return the combined option, {@code doc.bottom} when there are no (non-null) inputs, or null (see
     *         {@code nonConflicting})
     */
    public static Option add(Document doc, boolean nonConflicting, Option... input) {
        if(input.length == 0) return doc.bottom;

        ArrayList<Option> in = new ArrayList<>();
        int minPos = Integer.MAX_VALUE;
        int maxPos = 0;
        for(int i = 0; i < input.length; i++) {
            Option n = input[i];
            if(n == null) continue;

            assert doc == n.doc;
            assert !n.inv;
            assert !n.isRemoved;

            // Keep n only if no other input contains it; for duplicate entries only the last occurrence survives.
            boolean f = true;
            for(int j = 0; j < input.length; j++) {
                Option x = input[j];
                if(x == null) continue;

                if(i != j && x.contains(n) && (x != n || i < j)) {
                    f = false;
                    break;
                }
            }
            if(f) {
                in.add(n);
                minPos = Math.min(minPos, n.minPos);
                maxPos = Math.max(maxPos, n.maxPos);
            }
        }

        // All inputs were null: treat like an empty input instead of failing inside computeRelations.
        if(in.isEmpty()) return doc.bottom;

        if(in.size() == 1) {
            Option n = in.get(0);
            if(nonConflicting && n.isConflict >= 0) return null;
            n.countRef();
            return n;
        }

        ArrayList<Option> children = new ArrayList<>();
        computeRelations(children, false, in, in, visitedCounter++);

        // TODO: check if this is correct
        if(children.size() == 1) {
            Option n = children.get(0);
            if(!n.inv) {
                if(nonConflicting && n.isConflict >= 0) return null;
                n.countRef();
                return n;
            }
        } else if(children.size() == 0) {
            children.add(doc.top);
        }

        ArrayList<Option> parents = new ArrayList<>();
        computeRelations(parents, true, in, children, visitedCounter++);

        if(nonConflicting) {
            for (Option p : parents) {
                if (p.isConflict >= 0) {
                    return null;
                }
            }
        }

        Option n = new Option(doc, false, -1, doc.optionIdCounter++, minPos, maxPos, computeLength(parents));
        Option nn = new Option(doc, true, -1, doc.optionIdCounter++, n.minPos, n.maxPos, Integer.MAX_VALUE - n.length);
        n.negation = nn;
        nn.negation = n;

        n.linkRelations(parents, children, visitedCounter++);

        n.countRef();

        return n;
    }


    /** Computes the length of a new option from its parents by summing the lengths of non-overlapping ancestors. */
    public static int computeLength(List<Option> parents) {
        ArrayList<Option> count = new ArrayList<>();
        computeLengthRecursiveStep(count, new ArrayList<>(), parents);

        int l = 0;
        for(Option n: count) {
            l += n.length;
        }
        return l;
    }


    /**
     * Visits {@code parents} largest-first and appends to {@code count} the options whose length should be
     * counted. A parent that overlaps already-seen options is decomposed into its own parents.
     */
    private static void computeLengthRecursiveStep(List<Option> count, List<Option> parentOverlapping, List<Option> parents) {
        Option[] sortedParents = new Option[parents.size()];
        parents.toArray(sortedParents);
        Arrays.sort(sortedParents, SIZE_COMPARATOR);

        ArrayList<Option> overlapping = new ArrayList<>(parentOverlapping);
        for(Option p: sortedParents) {
            if(!p.containedIn(parentOverlapping)) {
                List<Option> no = p.filterOutside(overlapping);
                if(no.isEmpty()) {
                    count.add(p);
                } else {
                    computeLengthRecursiveStep(count, no, p.parents);
                }
                overlapping.add(p);
            }
        }
    }


    /** Returns the options in {@code options} that are not contained in this option's negation. */
    public List<Option> filterOutside(List<Option> options) {
        ArrayList<Option> results = new ArrayList<>();
        for (Option n : options) {
            if (!negation.contains(n)) {
                results.add(n);
            }
        }
        return results;
    }


    /**
     * Starting from this option, follows inverted parents and collects the options that have no inverted
     * parent.
     */
    private void computePrimitiveRelations(List<Option> results, long v) {
        if(v == visitedLinkPrimitive) return;
        visitedLinkPrimitive = v;

        boolean f = false;
        for(Option p: parents) {
            if(p.inv) {
                f = true;
                p.computePrimitiveRelations(results, v);
            }
        }

        if(!f) {
            results.add(this);
        }
    }


    /**
     * Computes the closest lattice neighbours of the combination {@code input}.
     *
     * @param dir   true searches through parents (starting from the shortest option in {@code start}); false
     *              searches through children (starting from the longest option in {@code start})
     * @param start must be non-empty
     */
    private static void computeRelations(List<Option> results, boolean dir, List<Option> input, List<Option> start, long v) {
        Option bestN = null;
        int bestLength = dir ? Integer.MAX_VALUE : 0;
        for(Option s: start) {
            // dir == true: keep the shortest start option; dir == false: keep the longest one.
            if(dir == (bestLength >= s.length)) {
                bestN = s;
                bestLength = s.length;
            }
        }
        bestN.computeRelationsRecursiveStep(results, dir, input, v);
    }


    /** Recursive search step of computeRelations; matching neighbours are refined by optimize. */
    private void computeRelationsRecursiveStep(List<Option> results, boolean dir, List<Option> input, long v) {
        if(v == visitedComputeRelations) return;
        visitedComputeRelations = v;

        for(Option r: dir ? parents : children) {
            if((!dir && r.containsAll(input)) || (dir && r.containedIn(input))) {
                r.optimize(results, dir, input, v);
            } else if((!dir && !r.negation.containedIn(input)) || (dir && !r.outsideOfAll(input))) {
                r.computeRelationsRecursiveStep(results, dir, input, v);
            }
        }
    }


    /**
     * Walks back toward {@code input} (children if {@code dir}, parents otherwise) while the relation still
     * holds. Adds this option to {@code results} if no such neighbour exists.
     */
    private void optimize(List<Option> results, boolean dir, List<Option> input, long v) {
        if(v == visitedOptimize) return;
        visitedOptimize = v;

        boolean f = false;
        for(Option r: dir ? children : parents) {
            if(v != r.visitedComputeRelations) {
                if ((!dir && r.containsAll(input)) || (dir && r.containedIn(input))) {
                    f = true;
                    r.optimize(results, dir, input, v);
                }
            }
        }

        if(!f) {
            results.add(this);
        }
    }


    /**
     * Inserts this option between {@code pSet} (parents) and {@code cSet} (children). Direct links from a parent
     * to one of the children are removed, because they now go through this option.
     */
    private void linkRelations(List<Option> pSet, List<Option> cSet, long v) {
        for(Option p: pSet) {
            addLink(p, this);
        }
        for(Option c: cSet) {
            c.visitedLinkRelations = v;
            addLink(this, c);
        }

        for(Option p: pSet) {
            ArrayList<Option> tmp = new ArrayList<>();
            for(Option c: p.children) {
                if(c.visitedLinkRelations == v) {
                    tmp.add(c);
                }
            }

            for(Option c: tmp) {
                removeLink(p, c);
            }
        }
    }


    /** Returns whether this (non-inverted) option is contained in the combination of {@code input}. */
    public boolean containedIn(Collection<Option> input) {
        if(inv) return false;
        if(containedInAny(input)) {
            return true;
        } else if(outsideOfAll(input)) {
            return false;
        }
        for(Option p: parents) {
            if(!p.containedIn(input)) return false;
        }
        return true;
    }


    /** Links {@code a} as a parent of {@code b} and mirrors the link between their negations. */
    public static void addLink(Option a, Option b) {
        a.children.add(b);
        b.parents.add(a);
        if(a != b.negation) {
            b.negation.children.add(a.negation);
            a.negation.parents.add(b.negation);
        }
    }


    /** Removes the link created by {@link #addLink(Option, Option)}, including the mirrored negation link. */
    public static void removeLink(Option a, Option b) {
        a.children.remove(b);
        b.parents.remove(a);
        if(a != b.negation) {
            b.negation.children.remove(a.negation);
            a.negation.parents.remove(b.negation);
        }
    }


    /**
     * Creates a new primitive option (and its negation) at position {@code pos}, links it below
     * {@code doc.bottom}, and increments its reference count.
     */
    public static Option addPrimitive(Document doc, int pos) {
        int primId = doc.primOptionIdCounter++;
        Option n = new Option(doc, false, primId, doc.optionIdCounter++, pos, pos, 1);
        Option nn = new Option(doc, true, primId, doc.optionIdCounter++, pos, pos, Integer.MAX_VALUE - 1);
        nn.negation = n;
        n.negation = nn;

        ArrayList<Option> children = new ArrayList<>();
        doc.top.computePrimitiveRelations(children, visitedCounter++);

        n.linkRelations(Arrays.asList(doc.bottom), children, visitedCounter++);

        n.countRef();

        return n;
    }


    /**
     * Unlinks this option (and the mirrored negation links) from the lattice. Each former parent is reconnected
     * to each former child that is not already reachable from it, and the option is marked as removed.
     */
    private void remove() {
        assert !inv;
        assert !isRemoved;
        isRemoved = true;
        removedId = removedIdCounter++;

        for(Option p: parents) {
            p.children.remove(this);
            if(p != negation) {
                p.negation.parents.remove(negation);
            }
        }
        for(Option c: children) {
            c.parents.remove(this);
            if(this != c.negation) {
                c.negation.children.remove(negation);
            }
        }
        for(Option p: parents) {
            for(Option c: children) {
                if(!c.isLinked(p, visitedCounter++)) {
                    addLink(p, c);
                }
            }
        }

        parents = null;
        children = null;
        conflicts = null;
        negation.negation = null;
        negation = null;
    }


    /**
     * Calls countActivation, with this option as the target, for every complete activation of this option and
     * of its ancestors.
     */
    public void count() {
        countRecursiveStep(this, visitedCounter++);
    }


    private void countRecursiveStep(Option so, long v) {
        if(v == visitedCount) return;
        visitedCount = v;

        for(Activation act: getActivationsComplete()) {
            act.key.n.countActivation(act.key, so);
        }

        for(Option p: parents) {
            p.countRecursiveStep(so, v);
        }
    }


    /**
     * Recomputes {@link #weight} for this option and all of its ancestors. The weight is the sum of the weights
     * of complete activations whose node has a non-null neuron.
     */
    public void computeWeights(long v) {
        if(v == visitedComputeWeights) return;
        visitedComputeWeights = v;

        weight = 0.0;
        for (Activation act : getActivationsComplete()) {
            if(act.key.n.neuron != null) {
                weight += act.weight;
            }
        }

        for(Option p: parents) {
            p.computeWeights(v);
        }
    }


    public boolean isTop() {
        return length == Integer.MAX_VALUE;
    }


    public boolean isBottom() {
        return length == 0;
    }


    /** Returns true if every negation of the {@code input} options contains this option. */
    public boolean outsideOfAll(Collection<Option> input) {
        for(Option n: input) {
            if(!n.negation.contains(this)) return false;
        }
        return true;
    }


    /** Returns true if at least one option of {@code input} contains this option. */
    public boolean containedInAny(Collection<Option> input) {
        for(Option n: input) {
            if(n.contains(this)) return true;
        }
        return false;
    }


    /** Returns true if this option contains every option of {@code input}. */
    public boolean containsAll(List<Option> input) {
        for(Option n: input) {
            if(!contains(n)) return false;
        }
        return true;
    }


    /** Returns whether this option contains {@code n}. Results are memoized in {@link #cache}. */
    public boolean contains(Option n) {
        return contains(n, true, visitedCounter++);
    }


    /** Tests {@code this.contains(n)} when {@code dir} is false, otherwise {@code n.contains(this)}. */
    public boolean contains(boolean dir, Option n) {
        boolean r;
        if(!dir) {
            r = contains(n, true, visitedCounter++);
        } else {
            r = n.contains(this, true, visitedCounter++);
        }
        return r;
    }


    /**
     * Containment test implementation: walks <code>parents</code> looking for {@code n}, pruning by length and
     * position range.
     *
     * @param start true only for the top-level call, whose result is cached
     */
    private boolean contains(Option n, boolean start, long v) {
        assert visitedContains <= v;
        assert !isRemoved;
        assert !n.isRemoved;

        if(!inv && n.inv) {
            return false;
        }

        /*
        When checking whether a (inv) contains b (!inv), b.contains(a.negation) can be used to test whether the
        check makes sense at all.
        */

        if(this == n || isTop() || n.isBottom()) {
            return true;
        }

        visitedContains = v;
        if(length < n.length) return false;

        if(maxPos < n.minPos || n.maxPos < minPos) {
            return inv != n.inv;
        }

        Boolean cv = cache.get(n);
        if(cv != null) {
            return cv;
        }

        boolean result = false;
        for(Option p: parents) {
            if(p.visitedContains != v && p.contains(n, false, v)) {
                result = true;
                break;
            }
        }
        if(start) {
            cache.put(n, result);
        }
        return result;
    }


    /** Returns whether {@code n} is this option or reachable from it through <code>parents</code> links. */
    private boolean isLinked(Option n, long v) {
        assert visitedContains <= v;
        assert !isRemoved;
        assert !n.isRemoved;

        if(this == n) {
            return true;
        }

        visitedContains = v;
        if(length < n.length) return false;

        for(Option p: parents) {
            if(p.visitedContains != v && p.isLinked(n, v)) return true;
        }
        return false;
    }


    /**
     * Collects the primitive ids underlying this option. Inverted options walk <code>children</code>;
     * non-inverted options walk <code>parents</code>.
     */
    private void collectPrimitiveIds(Set<Integer> results, long v) {
        if(v == visitedCollect) return;
        visitedCollect = v;

        if(primId >= 0) {
            results.add(primId);
        } else {
            for(Option n: inv ? children : parents) {
                n.collectPrimitiveIds(results, v);
            }
        }
    }


    /** Returns true if every parent's {@code containedInUpperBound} marker equals {@code v}. */
    public boolean containedInUpperBound(long v) {
        for(Option p: parents) {
            if(p.containedInUpperBound != v) return false;
        }
        return true;
    }


    /**
     * Adds to {@code conflicts} every ancestor (reached through <code>parents</code>) that is marked as a
     * conflict option. This option itself is not added.
     */
    public void collectConflicts(Set<Option> conflicts, long v) {
        if(visitedCollectConflicts == v) return;
        visitedCollectConflicts = v;

        for(Option n: parents) {
            // Test the parent being collected; the previous check of this.isConflict was loop-invariant.
            if(n.isConflict >= 0) {
                conflicts.add(n);
            }
            n.collectConflicts(conflicts, v);
        }
    }


    /**
     * Builds (via {@link #add}) the option combining the direct children of {@code doc.bottom} whose
     * primitive id is listed in {@code ids}.
     */
    public static Option create(Document doc, Integer... ids) {
        HashSet<Integer> tmp = new HashSet<>(Arrays.asList(ids));

        ArrayList<Option> primOptions = new ArrayList<>();
        for(Option p: doc.bottom.children) {
            if(tmp.contains(p.primId)) {
                primOptions.add(p);
            }
        }

        return Option.add(doc, false, primOptions.toArray(new Option[primOptions.size()]));
    }


    /**
     * Sums {@link #weight} over this option and all options reachable through <code>parents</code>. Each option
     * is counted once per marker {@code v}.
     */
    public double computeAccumulatedWeight(long v) {
        if(v == visitedAccumulatedWeight) return 0.0;
        visitedAccumulatedWeight = v;

        double accWeight = weight;
        for(Option p: parents) {
            accWeight += p.computeAccumulatedWeight(v);
        }
        return accWeight;
    }


    /** Sets the {@code markedSelected} marker to {@code v} on this option and all of its ancestors. */
    public void markSelected(long v) {
        if(markedSelected == v) return;
        markedSelected = v;

        for(Option p: parents) {
            p.markSelected(v);
        }
    }


    /** Formats the option as its sorted primitive ids, e.g. {@code (1,2)}, prefixed with '!' when inverted. */
    @Override
    public String toString() {
        SortedSet<Integer> ids = new TreeSet<>();
        collectPrimitiveIds(ids, visitedCounter++);

        StringBuilder sb = new StringBuilder();
        sb.append("(");
        if(inv) sb.append("!");
        boolean first = true;
        for(Integer id: ids) {
            if(!first) sb.append(",");
            first = false;
            sb.append(id);
        }

        sb.append(")");
        return sb.toString();
    }


    /** Orders by ascending length, then by ascending id. */
    @Override
    public int compareTo(Option n) {
        int r = Integer.compare(length, n.length);
        if(r != 0) return r;
        return Integer.compare(id, n.id);
    }


    /** Null-tolerant comparison: null sorts before any option; otherwise delegates to compareTo. */
    public static int compare(Option oa, Option ob) {
        if(oa == ob) return 0;
        if(oa == null && ob != null) return -1;
        if(oa != null && ob == null) return 1;
        return oa.compareTo(ob);
    }
}
