package org.aika.network.neuron;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.aika.corpus.Document;
import org.aika.corpus.Option;
import org.aika.corpus.Range;
import org.aika.network.Iteration;
import org.aika.network.Model;
import org.aika.network.neuron.Activation.Key;
import org.aika.network.neuron.recurrent.ClockTerminationNode;
import org.aika.network.neuron.simple.lattice.AndNode;
import org.aika.network.neuron.simple.lattice.InputNode;
import org.aika.network.neuron.simple.lattice.NegativeInputNode;

import java.util.*;

/**
 * Base class for all nodes of the network.
 *
 * <p>Each node keeps its current activations in {@link #activations} and buffers pending changes
 * in {@link #added} and {@link #removed}. The static methods
 * {@code addActivationAndPropagate} and {@code removeActivationAndPropagate} record such changes
 * and enqueue the node on the iteration's queue; {@link Queue#processChanges(Iteration)} later
 * applies them node by node.
 *
 * @author Lukas Molzberger
 */
public abstract class Node implements Comparable<Node> {
    /** Frequency threshold used by {@link #isFrequentOrPredefined()}. */
    public static int minFrequency = 5;

    // NOTE(review): currentNodeId also starts at 0, so the first real node gets id 0 and
    // compares equal to MIN_NODE — confirm this is intended.
    /** Sentinel with id 0; presumably used as a lower bound in id-ordered collections. */
    public static final Node MIN_NODE = new DummyNode(0);
    /** Sentinel with id {@code Integer.MAX_VALUE}; presumably used as an upper bound. */
    public static final Node MAX_NODE = new DummyNode(Integer.MAX_VALUE);

    /** Next id handed out by {@link #Node(Model, int)}. */
    public static int currentNodeId = 0;
    /** Id of this node; defines the natural ordering (see {@link #compareTo(Node)}). */
    public int id;

    /** Level of this node; secondary sort key of {@link Queue} entries. */
    public int level;

    public double weight = 0.0;
    /** Activation count maintained by {@link #countActivation(Key, Option)}. */
    public int frequency;

    // false: adjoined activations are counted as one.
    // true:  every activation counts.
    public boolean countingMode;

    public boolean isPredefined;
    public boolean isBlocked;

    public boolean isRemoved;
    public int isRemovedId;
    public static int isRemovedIdCounter = 0;

    /** Set to {@code true} whenever {@link #countActivation(Key, Option)} updates the frequency. */
    public boolean frequencyHasChanged = true;
    public int n;

    /**
     * Current activations of this node, ordered by range ({@code null} first), rid, option
     * ({@code null} first), fired and id.
     */
    public NavigableMap<Key, Activation> activations = new TreeMap<>(new Comparator<Key>() {
        @Override
        public int compare(Key k1, Key k2) {
            int r;
            if(k1.pos != k2.pos) {
                if(k1.pos == null) return -1;
                if(k2.pos == null) return 1;
                r = k1.pos.compareTo(k2.pos);
                if (r != 0) return r;
            }
            r = Integer.compare(k1.rid, k2.rid);
            if(r != 0) return r;
            if(k1.o != k2.o) {
                if (k1.o == null) return -1;
                if (k2.o == null) return 1;
                r = k1.o.compareTo(k2.o);
                if (r != 0) return r;
            }
            r = Integer.compare(k1.fired, k2.fired);
            if(r != 0) return r;
            return Integer.compare(k1.id, k2.id);
        }
    });

    /** Pending additions; value[0] holds input activations, value[1] direct input activations. */
    public NavigableMap<ChangeKey, Set<Activation>[]> added = new TreeMap<>();
    /** Pending removals, applied before the pending additions. */
    public NavigableMap<ChangeKey, RemovedEntry> removed = new TreeMap<>();
    /** Whether this node currently sits in a {@link Queue}; guards against double enqueueing. */
    public boolean isQueued = false;


    /** Neuron that activations of this node are propagated to, or {@code null} if none. */
    public Neuron neuron = null;


    public static long visitedCounter = 0;


    /** Subclass hook; semantics are defined by the concrete node types (TODO document). */
    public abstract boolean isAllowedOption(Option n, Activation act, long v);


    /** Subclass hook; presumably whether this node is a negative node (confirm). */
    public abstract boolean isNegative();


    /** Subclass hook; presumably whether this node depends on a negative node (confirm). */
    public abstract boolean containsNegative();


    /** Subclass hook for releasing node-specific state held by the model. */
    public abstract void cleanup(Model m);


    /**
     * Called by {@link #addActivationInternal} after the new activation has been linked and its
     * shadowing relations have been computed, but before it is registered.
     */
    public abstract void initActivation(Iteration t, Activation act);


    /**
     * Called by {@link #removeActivationInternal} after the activation has been unlinked and
     * before it is unregistered.
     */
    public abstract void deleteActivation(Iteration t, Activation act);


    public abstract double computeForwardWeight(Activation act);


    public abstract double getNodeWeight(Activation act);


    public abstract double computeSynapseWeightSum(Neuron n);


    /** Returns a description of this node's logic; appended by {@link #toString()}. */
    public abstract String logicToString();

    /** Used only by {@link DummyNode}; does not assign an id nor register with a model. */
    private Node() {
    }

    /**
     * Creates a node with the next free id and registers it in {@code m.allNodes}.
     *
     * @param m     model the node is added to
     * @param level level of the node (used for queue ordering)
     */
    public Node(Model m, int level) {
        id = currentNodeId++;
        this.level = level;
        m.allNodes.add(this);
    }


    /**
     * Looks up the init option of a pending addition with the given range, rid and option.
     *
     * @return the {@code initOption} of the first matching entry in {@link #added}, or
     *         {@code null} if there is none
     */
    public Option retrieveInitialOption(Range r, Integer rid, Option o) {
        SortedMap<ChangeKey, Set<Activation>[]> candidates = added.subMap(
                new ChangeKey(rid, o, 0, 0, r, null),
                new ChangeKey(rid, o, Integer.MAX_VALUE, Integer.MAX_VALUE, r, Option.MAX)
        );
        return candidates.isEmpty() ? null : candidates.firstKey().initOption;
    }


    /**
     * Increments {@link #frequency} for the activation {@code ak}, unless {@link #countingMode}
     * is off and another activation overlapping it (as selected by
     * {@code Range.Relation.OVERLAPS_AFTER}) exists; in that case the frequency is unchanged.
     * Always sets {@link #frequencyHasChanged}.
     */
    public void countActivation(Key ak, Option so) {
        int delta = 1;
        if (!countingMode) {
            // NOTE(review): rid is hard-coded to 0 here while other select() calls pass the
            // key's rid — confirm this is intended.
            for(Activation secondAct: Activation.select(this, 0, ak.pos, Range.Relation.OVERLAPS_AFTER, so, Option.Relation.CONTAINS, null, null, null, false)) {
                if(secondAct.key != ak) {
                    delta = 0;
                    break;
                }
            }
        }

        frequency += delta;
        frequencyHasChanged = true;
    }


    /**
     * Creates, links, initializes and registers a new activation for key {@code ak}, and records
     * shadowing relations with every overlapping activation of this node.
     *
     * @return the newly created activation
     */
    public Activation addActivationInternal(Iteration t, Key ak, Option initOption, Option newOption, Set<Activation> inputActs, Set<Activation> directInputActs, Set<Activation> outputActs, Set<Activation> directOutputActs) {
        Activation act = new Activation(ak);
        act.initialOption = initOption;
        act.newOption = newOption;
        act.completeOption = newOption != null ? Option.add(t.doc, false, ak.o, newOption) : ak.o;

        act.link(inputActs, directInputActs, outputActs, directOutputActs);

        for (Activation oa : Activation.select(this, ak.rid, ak.pos, Range.Relation.OVERLAPS, null, null, null, null, null, true)) {
            if (oa.shadows(act)) {
                oa.addShadows(act);
            }
            if (act.shadows(oa)) {
                act.addShadows(oa);
            }
        }

        initActivation(t, act);

        act.register(t);

        return act;
    }


    /**
     * Applies all pending changes of this node: first the removals, then the additions.
     *
     * <p>Unless this is a {@link ClockTerminationNode}, a pending removal is dropped when an
     * addition with the same option (by reference), rid, fired and id covers its range.
     */
    private void processChanges(Iteration t) {
        NavigableMap<ChangeKey, Set<Activation>[]> tmpAdded = added;
        NavigableMap<ChangeKey, RemovedEntry> tmpRemoved = removed;

        // Swap in fresh maps so that changes produced while processing are collected separately.
        added = new TreeMap<>();
        removed = new TreeMap<>();

        if(!(this instanceof ClockTerminationNode)) {
            for (Iterator<ChangeKey> it = tmpRemoved.keySet().iterator(); it.hasNext(); ) {
                if (isCoveredByAddition(it.next(), tmpAdded.keySet())) {
                    it.remove();
                }
            }
        }

        for(RemovedEntry re: tmpRemoved.values()) {
            processRemovedActivations(t, re.act, re.removedRange);
        }

        for(Map.Entry<ChangeKey, Set<Activation>[]> me: tmpAdded.entrySet()) {
            processAddedActivations(t, me.getKey(), me.getValue()[0], me.getValue()[1]);
        }
    }


    /** Returns whether any of {@code addedKeys} matches {@code ckr} and contains its range. */
    private static boolean isCoveredByAddition(ChangeKey ckr, Collection<ChangeKey> addedKeys) {
        for (ChangeKey cka : addedKeys) {
            if (cka.o == ckr.o && cka.rid == ckr.rid && cka.fired == ckr.fired && cka.id == ckr.id && cka.r.contains(ckr.r)) {
                return true;
            }
        }
        return false;
    }


    /**
     * Records a pending addition of {@code addedRange} for the activation key {@code ak} on node
     * {@code ak.n} and enqueues that node. Input sets of repeated additions with the same change
     * key are merged. Does nothing if {@code addedRange} is empty.
     */
    public static void addActivationAndPropagate(Iteration t, boolean processFirst, Key ak, Range addedRange, Option initOption, Set<Activation> inputActs, Set<Activation> directInputActs) {
        if(addedRange.isEmpty()) return;

        ChangeKey ck = new ChangeKey(ak.rid, ak.o, ak.fired, ak.id, addedRange, initOption);
        Set<Activation>[] iActs = ak.n.added.get(ck);
        if(iActs == null) {
            iActs = newInputSetPair();
            ak.n.added.put(ck, iActs);
        }
        iActs[0].addAll(inputActs);
        iActs[1].addAll(directInputActs);
        t.queue.add(ak.n, processFirst);
    }


    /** Creates the {input, direct input} pair of empty sets stored as values of {@link #added}. */
    @SuppressWarnings("unchecked")
    private static Set<Activation>[] newInputSetPair() {
        // Generic array creation is not allowed; the cast is safe because only
        // Set<Activation> instances are ever stored in the array.
        return (Set<Activation>[]) new Set<?>[]{new TreeSet<Activation>(), new TreeSet<Activation>()};
    }


    /**
     * Applies one pending addition. Existing activations selected as matching {@code ck} (range
     * overlapping, or in non-counting mode also adjoined) are merged with it into one new
     * activation that inherits their links. The part of {@code ck.r} not covered by those old
     * activations and not shadowed is propagated as added. For every activation the new one
     * shadows, the newly shadowed part is propagated as removed.
     */
    public void processAddedActivations(Iteration t, ChangeKey ck, Set<Activation> inputActs, Set<Activation> directInputActs) {
        Set<Activation> iActs = new TreeSet<>(inputActs);
        Set<Activation> diActs = new TreeSet<>(directInputActs);
        Set<Activation> oActs = new TreeSet<>();
        Set<Activation> doActs = new TreeSet<>();

        Option no = computeNewOption(t.doc, ck.initOption, directInputActs);
        Collection<Activation> oldActs = Activation.select(this, ck.rid, ck.r, countingMode ? Range.Relation.OVERLAPS : Range.Relation.OVERLAPS_INCLUDE_ADJOINED, ck.o, Option.Relation.EQUALS, no, ck.fired, ck.id, true);
        for(Activation oldAct: oldActs) {
            iActs.addAll(oldAct.inputs);
            diActs.addAll(oldAct.directInputs);
            oActs.addAll(oldAct.outputs);
            doActs.addAll(oldAct.directOutputs);
        }

        Range r = extractUncoveredRange(t.doc, ck.r, oldActs);

        for(Activation oldAct: oldActs) {
            removeActivationInternal(t, oldAct.key);
        }

        Activation act = addActivationInternal(t, computeNewActivationKey(t.doc, ck, oldActs), ck.initOption, no, iActs, diActs, oActs, doActs);

        Range outputRange = Range.create(t.doc, Range.intersection(r.getSegments(), act.getNonShadowedRange()));
        if(!outputRange.isEmpty()) {
            propagateAddedActivation(t, act, outputRange, null);
        }

        for(Activation shadAct: act.shadows.values()) {
            // Only the part that no other activation already shadows changes visibility.
            Range shadowedRange = shadAct.key.pos.intersection(act.key.pos);
            for(Activation sbAct: shadAct.shadowedBy.values()) {
                if(sbAct != act) {
                    shadowedRange = shadowedRange.complement(sbAct.key.pos);
                }
            }
            propagateRemovedActivation(t, shadAct, shadowedRange);
        }
    }


    /**
     * Combines {@code initOption} and the new options of the direct inputs via {@code Option.add}.
     * Returns {@code null} for nodes other than {@link AndNode} and {@link ClockTerminationNode}.
     */
    private Option computeNewOption(Document doc, Option initOption, Set<Activation> directInputActs) {
        if(!(this instanceof AndNode || this instanceof ClockTerminationNode)) return null; // TODO!
        ArrayList<Option> tmp = new ArrayList<>();
        if(initOption != null) {
            tmp.add(initOption);
        }
        for(Activation iAct: directInputActs) {
            if(iAct.newOption != null) {
                tmp.add(iAct.newOption);
            }
        }

        return Option.add(doc, false, tmp.toArray(new Option[tmp.size()]));
    }


    /** Builds a key for {@code ck} whose range spans {@code ck.r} and the ranges of all {@code acts}. */
    private Key computeNewActivationKey(Document doc, ChangeKey ck, Collection<Activation> acts) {
        int b = Integer.MAX_VALUE;
        int e = 0;
        for(Activation act: acts) {
            b = Math.min(b, act.key.pos.getBegin());
            e = Math.max(e, act.key.pos.getEnd());
        }

        return new Key(this, Range.create(doc, Math.min(ck.r.getBegin(), b), Math.max(ck.r.getEnd(), e)), ck.rid, ck.o, ck.fired, ck.id);
    }


    /** Returns the part of {@code newRange} not covered by the ranges of {@code acts}. */
    private Range extractUncoveredRange(Document doc, Range newRange, Collection<Activation> acts) {
        List<Range> r = new ArrayList<>();
        for(Activation act: acts) {
            r.add(act.key.pos);
        }

        return newRange.complement(Range.add(doc, r.toArray(new Range[r.size()])));
    }


    /**
     * Unlinks, deletes and unregisters the activation with key {@code ak}, marks it removed and
     * (except for {@link NegativeInputNode}s) detaches it from its shadowing relations.
     *
     * @return the removed activation, or {@code null} if no activation has this key
     */
    public Activation removeActivationInternal(Iteration t, Key ak) {
        Activation act = activations.get(ak);
        if(act == null) return null;

        act.unlink();

        deleteActivation(t, act);
        act.unregister(t);
        act.removedId = Activation.removedIdCounter++;
        act.isRemoved = true;

        if(!(this instanceof NegativeInputNode)) {
            for (Activation shadowingAct : act.shadowedBy.values()) {
                shadowingAct.shadows.remove(act.key);
            }

            for (Activation shadowedAct : act.shadows.values()) {
                shadowedAct.shadowedBy.remove(act.key);
            }
        }
        return act;
    }


    /**
     * Records a pending removal of {@code removedRange} from {@code act} on its node and enqueues
     * that node the first time. Repeated removals for the same activation are unioned. Does
     * nothing if {@code removedRange} is empty or {@code act} is {@code null}.
     */
    public static void removeActivationAndPropagate(Iteration t, boolean processFirst, Activation act, Range removedRange) {
        if(removedRange.isEmpty()) return;

        // TODO: check whether act can legitimately be null here (previously: assert act != null).
        if(act == null) return;

        ChangeKey ck = new ChangeKey(act.key.rid, act.key.o, act.key.fired, act.key.id, act.key.pos, null);
        RemovedEntry re = act.key.n.removed.get(ck);
        if(re == null) {
            re = new RemovedEntry();
            re.act = act;
            re.removedRange = removedRange;
            act.key.n.removed.put(ck, re);
            t.queue.add(act.key.n, processFirst);
        } else {
            assert re.act == act;
            re.removedRange = Range.add(t.doc, re.removedRange, removedRange);
        }
    }


    /**
     * Applies one pending removal: removes {@code act}, re-adds every remaining segment of its
     * range as a separate activation (with links filtered to that segment), propagates the
     * removal of the non-shadowed part of {@code removedRange}, and re-propagates activations
     * {@code act} shadowed for the part no other activation still shadows. Skips activations
     * that are already removed.
     */
    public void processRemovedActivations(Iteration t, Activation act, Range removedRange) {
        if(act.isRemoved) return;

        Key ak = act.key;
        Range remainingRange = ak.pos.complement(removedRange);

        removeActivationInternal(t, ak);

        for(int[] seg: remainingRange.getSegments()) {
            Range r = Range.create(t.doc, seg);
            addActivationInternal(t, new Key(this, r, ak.rid, ak.o, ak.fired, ak.id), act.initialOption, act.newOption, filterInputsAndOutputs(act, r, act.inputs), filterInputsAndOutputs(act, r, act.directInputs), filterInputsAndOutputs(act, r, act.outputs), filterInputsAndOutputs(act, r, act.directOutputs));
            // TODO: restore output links
        }

        Range outputRange = Range.create(t.doc, Range.intersection(removedRange.getSegments(), act.getNonShadowedRange()));
        if(!outputRange.isEmpty()) {
            propagateRemovedActivation(t, act, outputRange);
        }

        for(Activation shadAct: act.shadows.values()) {
            Range shadowedRange = shadAct.key.pos.intersection(removedRange);
            for(Activation sbAct: shadAct.shadowedBy.values()) {
                if(sbAct != act) {
                    shadowedRange = shadowedRange.complement(sbAct.key.pos);
                }
            }
            propagateAddedActivation(t, shadAct, shadowedRange, null);
        }

        act.key.releaseRef();
    }


    /**
     * Keeps the activations of {@code io} that either have a different rid than {@code act} or
     * whose range overlaps {@code r}.
     */
    private Set<Activation> filterInputsAndOutputs(Activation act, Range r, Set<Activation> io) {
        Set<Activation> results = new TreeSet<>();
        for(Activation actIO: io) {
            if(actIO.key.rid != act.key.rid || Range.overlaps(r, actIO.key.pos, false, false)) {
                results.add(actIO);
            }
        }
        return results;
    }


    /**
     * Forwards an added range to {@link #neuron}, if any. When {@code conflict} is non-null the
     * range is only forwarded if {@code act} has an initial option and its complete option
     * contains {@code conflict}.
     */
    public void propagateAddedActivation(Iteration t, Activation act, Range addedRange, Option conflict) {
        if(neuron != null && (conflict == null || (act.initialOption != null && act.completeOption.contains(conflict)))) {
            neuron.propagateAddedActivation(t, act, addedRange);
        }
    }


    /** Forwards a removed range to {@link #neuron}, if any. */
    public void propagateRemovedActivation(Iteration t, Activation act, Range removedRange) {
        if(neuron != null) {
            neuron.propagateRemovedActivation(t, act, removedRange);
        }
    }


    /** Returns a live view of all current activations, in {@link #activations} order. */
    public Collection<Activation> getActivations() {
        return activations.values();
    }


    /** Returns the activations for which {@code doc.isSelected} holds. */
    public Collection<Activation> getSelectedActivations(Document doc) {
        ArrayList<Activation> results = new ArrayList<>();
        for(Activation act: activations.values()) {
            if(doc.isSelected(act)) {
                results.add(act);
            }
        }
        return results;
    }

    /** Returns the first activation in {@link #activations} order, or {@code null} if none. */
    public Activation getFirstActivation() {
        if(activations.isEmpty()) return null;
        return activations.firstEntry().getValue();
    }


    /** Unlinks and releases every activation, then empties {@link #activations}. */
    public void clearActivations() {
        for(Activation act: activations.values()) {
            act.unlink();
            act.key.releaseRef();
        }
        activations.clear();
    }


    /** Returns whether this node is predefined or {@code freq} reaches {@link #minFrequency}. */
    public boolean isFrequentOrPredefined(int freq) {
        if(isPredefined) return true;
        return freq >= minFrequency;
    }


    /** Same as {@link #isFrequentOrPredefined(int)} applied to this node's {@link #frequency}. */
    public boolean isFrequentOrPredefined() {
        return isFrequentOrPredefined(frequency);
    }


    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(toSimpleString());
        sb.append(" - ");
        sb.append(logicToString());
        return sb.toString();
    }


    /** Returns id, optional neuron label and statistics (frequency, weight, n, MPR for AndNodes). */
    public String toSimpleString() {
        StringBuilder sb = new StringBuilder();
        sb.append(id);
        if(neuron != null && neuron.label != null) {
            sb.append(" ");
            sb.append(neuron.label);
        }
        sb.append(" (Freq:");
        sb.append(frequency);
        sb.append(", W:");
        sb.append(weight);

        if(this instanceof AndNode) {
            AndNode an = (AndNode) this;
            sb.append(", MPR:");
            sb.append(an.minPRelevance);
        }
        sb.append(", N:");
        sb.append(n);
        sb.append(")");
        return sb.toString();
    }


    /** Orders nodes by {@link #id}. */
    @Override
    public int compareTo(Node n) {
        return Integer.compare(id, n.id);
    }


    /** Null-safe comparison: {@code null} sorts before any node. */
    public static int compare(Node a, Node b) {
        if(a == b) return 0;
        if(a == null) return -1;
        if(b == null) return 1;
        return a.compareTo(b);
    }


    /**
     * Queue of nodes with pending changes, ordered by enqueue id, then level, then node id.
     * A node is contained at most once at a time (guarded by {@link Node#isQueued}).
     */
    public static class Queue {

        public final TreeSet<Entry> queue = new TreeSet<>();

        private long queueIdCounter = 0;


        /**
         * Enqueues {@code n} unless it is already queued. With {@code processFirst} and a
         * non-empty queue the entry shares the head's queue id instead of getting a fresh one,
         * so it is ordered among the head entries by level and node id.
         */
        public void add(Node n, boolean processFirst) {
            if(!n.isQueued) {
                n.isQueued = true;
                queue.add(new Entry(processFirst && !queue.isEmpty() ? queue.first().queueId : queueIdCounter++, n.level, n));
            }
        }


        /** Processes queued nodes in order until the queue is empty, including nodes enqueued meanwhile. */
        public void processChanges(Iteration t) {
            while(!queue.isEmpty()) {
                Entry e = queue.pollFirst();
                e.n.isQueued = false;
                e.n.processChanges(t);
            }
        }


        private static class Entry implements Comparable<Entry> {
            long queueId;
            int level;
            Node n;


            public Entry(long queueId, int level, Node n) {
                this.queueId = queueId;
                this.level = level;
                this.n = n;
            }


            @Override
            public int compareTo(Entry k) {
                int r = Long.compare(queueId, k.queueId);
                if(r != 0) return r;
                r = Integer.compare(level, k.level);
                if(r != 0) return r;
                return n.compareTo(k.n);
            }
        }
    }


    /**
     * Key of a pending change, ordered by range, rid (descending), option, fired, id and
     * initial option.
     */
    private static class ChangeKey implements Comparable<ChangeKey> {
        int rid;
        Option o;
        int fired;
        int id;
        Range r;
        Option initOption;

        public ChangeKey(int rid, Option o, int fired, int id, Range r, Option initOption) {
            this.rid = rid;
            this.o = o;
            this.fired = fired;
            this.id = id;
            this.r = r;
            this.initOption = initOption;
        }

        @Override
        public int compareTo(ChangeKey ck) {
            int c = r.compareTo(ck.r);
            if(c != 0) return c;
            // NOTE(review): rid is compared in reverse order, unlike every other field here and
            // unlike the activations comparator — confirm this is intended.
            c = Integer.compare(ck.rid, rid);
            if(c != 0) return c;
            c = o.compareTo(ck.o);
            if(c != 0) return c;
            c = Integer.compare(fired, ck.fired);
            if(c != 0) return c;
            c = Integer.compare(id, ck.id);
            if(c != 0) return c;
            return Option.compare(initOption, ck.initOption);
        }
    }


    /** A pending removal: the affected activation and the (accumulated) range to remove. */
    private static class RemovedEntry {
        Activation act;
        Range removedRange;
    }


    /** Inert node used only for the {@link #MIN_NODE} / {@link #MAX_NODE} sentinels. */
    private static class DummyNode extends Node {

        public DummyNode(int id) {
            this.id = id;
        }

        @Override
        public boolean isAllowedOption(Option n, Activation act, long v) {
            return false;
        }

        @Override
        public boolean isNegative() {
            return false;
        }

        @Override
        public boolean containsNegative() {
            return false;
        }

        @Override
        public void cleanup(Model m) {
        }

        @Override
        public void initActivation(Iteration t, Activation act) {
        }

        @Override
        public void deleteActivation(Iteration t, Activation act) {
        }

        @Override
        public double computeForwardWeight(Activation act) {
            return 0;
        }

        @Override
        public double getNodeWeight(Activation act) {
            return 0;
        }

        @Override
        public double computeSynapseWeightSum(Neuron n) {
            return 0;
        }

        @Override
        public String logicToString() {
            return null;
        }
    }
}
