package org.aika.network.neuron.recurrent;


/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


import org.aika.corpus.Conflicts;
import org.aika.corpus.Document;
import org.aika.corpus.Option;
import org.aika.corpus.Range;
import org.aika.network.Iteration;
import org.aika.network.Model;
import org.aika.network.neuron.Activation;
import org.aika.network.neuron.Node;
import org.aika.network.neuron.simple.lattice.AndNode;
import org.aika.utils.SetUtils;

import java.util.*;


/**
 * A recurrent node whose activations are ranges delimited by the end positions of
 * activations of its {@link #clockNode} and {@link #terminationNode}.
 *
 * <p>New activations ("segments") are created by {@code addNextSegment}. A segment starts at
 * a given position and ends at the boundary position found by {@code getSignal} among
 * the clock/termination activations.
 *
 * @author Lukas Molzberger
 */
public class ClockTerminationNode extends RecurrentNode {

    /** Node whose activations are used as segment boundaries (see {@code getSignal}). */
    public ClockNode clockNode;
    /**
     * Node whose activations are used as segment boundaries (see {@code getSignal}). An incoming
     * activation of this node additionally starts a new segment with rid 0.
     */
    public TerminationNode terminationNode;

    /** Output nodes of this node; their values are returned by {@link #getChildren()}. */
    public Map<InputNode, OutputNode> outputNodes = new TreeMap<>();
    /** Output nodes that are notified directly when activations of this node are added or removed. */
    public Map<InputNode, OutputNode> outputNodesWithinDocument = new TreeMap<>();


    /**
     * Creates a new clock/termination node with {@code countingMode} enabled.
     *
     * @param m         the model the node belongs to
     * @param direction the direction flag this node passes to {@code Range.getBegin} / {@code Range.getEnd}
     */
    public ClockTerminationNode(Model m, boolean direction) {
        super(m, 1);
        countingMode = true;
        this.direction = direction;
    }


    /** Always returns {@code 0}; this node contributes no forward weight. */
    @Override
    public double computeForwardWeight(Activation act) {
        return 0;
    }


    /** Returns this node's {@code weight}, independent of the given activation. */
    @Override
    public double getNodeWeight(Activation act) {
        return weight;
    }


    /** Returns a live view of the values of {@link #outputNodes}. */
    public Collection<OutputNode> getChildren() {
        return outputNodes.values();
    }


    /**
     * Returns the end position of an activation, evaluated in the direction of the
     * (recurrent) node that owns the activation rather than in this node's direction.
     */
    private static int signalEnd(Activation act) {
        return act.key.pos.getEnd(((RecurrentNode) act.key.n).direction);
    }


    /** Returns whether the activation belongs to a {@link ClockNode} or a {@link TerminationNode}. */
    private static boolean isClockOrTerminationSignal(Activation act) {
        return act.key.n instanceof ClockNode || act.key.n instanceof TerminationNode;
    }


    /**
     * Handles a new activation coming from a clock or termination node.
     *
     * <p>Every activation selected by {@code selectActivationsToReplace} is removed. If the range from its
     * begin to the input signal's end is not the document's bottom range, the activation is replaced by
     * one covering that range. If the input comes from a {@link TerminationNode}, a new segment with
     * rid 0 is started at the input signal's end.
     */
    private void addActivationFromClockOrTerminationNode(Iteration t, Activation inputAct) {
        int inputEnd = signalEnd(inputAct);

        for(Activation act: selectActivationsToReplace(inputAct, inputEnd)) {
            Set<Activation> inputs = new TreeSet<>();

            addInput(inputs, getBeginSignal(act));
            addInput(inputs, getPreviousAct(act));
            inputs.add(inputAct);

            Range r = Range.create(t.doc, act.key.pos.getBegin(direction), inputEnd);

            if(r != t.doc.bottomRange) {
                addActivationAndPropagate(t, true, new Activation.Key(this, r, act.key.rid, act.key.o, 0), r, act.initialOption, inputs, inputs);
            }
            removeActivationAndPropagate(t, true, act, act.key.pos);
        }

        if(inputAct.key.n instanceof TerminationNode) {
            addNextSegment(t, inputEnd, 0, inputAct.key.o, inputAct.newOption, inputAct);
        }
    }


    /**
     * Collects this node's activations that overlap {@code inputAct} and do not already end at
     * {@code inputEnd}.
     *
     * <p>An activation whose option contains the input's option is always kept. An activation whose option
     * is contained in the input's option is kept only if no already-collected activation's option contains
     * it. When it is kept, collected activations whose options it contains are dropped.
     *
     * @return the collected activations; fully built before the caller starts modifying the network
     */
    private Set<Activation> selectActivationsToReplace(Activation inputAct, int inputEnd) {
        Set<Activation> result = new TreeSet<>();
        for(Activation act: Activation.select(this, null, inputAct.key.pos, Range.Relation.OVERLAPS, null, null, null, null, null, false)) {
            if(inputEnd == act.key.pos.getEnd(direction)) continue;

            if(act.key.o.contains(inputAct.key.o)) {
                result.add(act);
            } else if(inputAct.key.o.contains(act.key.o)) {
                boolean covered = false;
                for(Iterator<Activation> it = result.iterator(); it.hasNext(); ) {
                    Activation collected = it.next();
                    if(act.key.o.contains(collected.key.o)) it.remove();
                    else if(collected.key.o.contains(act.key.o)) covered = true;
                }
                if(!covered) {
                    result.add(act);
                }
            }
        }
        return result;
    }


    /** Adds {@code iAct} to {@code inputs} unless it is {@code null}. */
    private static void addInput(Set<Activation> inputs, Activation iAct) {
        if(iAct != null) {
            inputs.add(iAct);
        }
    }


    /**
     * Handles the removal of an activation coming from a clock or termination node.
     *
     * <p>An output activation ending where the removed signal ended is removed and replaced by a new
     * segment starting at its begin. An output activation beginning there is removed as well, but only if
     * it has the same option (by reference) and the removed signal came from {@link #terminationNode}.
     */
    private void removeActivationFromClockOrTerminationNode(Iteration t, Activation inputAct) {
        int inputEnd = signalEnd(inputAct);

        // Iterate over a snapshot: removeActivationAndPropagate / addNextSegment may modify
        // inputAct.outputs, which would otherwise invalidate a fail-fast iterator.
        for(Activation act: new ArrayList<>(inputAct.outputs)) {
            if(act.key.pos.getEnd(direction) == inputEnd) {
                addNextSegment(t, act.key.pos.getBegin(direction), act.key.rid, act.key.o, act.newOption, act.inputs.toArray(new Activation[act.inputs.size()]));
                removeActivationAndPropagate(t, true, act, act.key.pos);
            } else if(act.key.pos.getBegin(direction) == inputEnd && inputAct.key.o == act.key.o && inputAct.key.n == terminationNode) {
                removeActivationAndPropagate(t, true, act, act.key.pos);
            }
        }
    }


    /**
     * Handles a new activation of this node by starting the following segment.
     *
     * <p>Nothing happens if the activation already ends at a termination signal it is compatible with, or
     * if it has no clock/termination signal at its end. If the activation's option contains that signal's
     * option, the next segment (rid + 1) simply starts at the activation's end. Otherwise a new primitive
     * option conflicting with the signal's option is created, unless an alternative activation already
     * covers the signal's option. The current segment is then continued under the combined option.
     */
    private void addActivationCTNode(Iteration t, Activation inputAct) {
        if(getTerminationSignal(inputAct) != null) return;

        // The clock/termination activation at which inputAct ends; the next segment begins there.
        Activation endSignal = getEndSignal(inputAct);
        if(endSignal == null) return;

        if(inputAct.key.o.contains(endSignal.key.o)) {
            addNextSegment(t, inputAct.key.pos.getEnd(direction), inputAct.key.rid + 1, inputAct.key.o, inputAct.newOption, inputAct, endSignal);
            return;
        }

        boolean covered = false;
        for(Activation altPathAct: Activation.select(this, null, inputAct.key.pos, direction ? Range.Relation.BEGIN_EQUALS : Range.Relation.END_EQUALS, inputAct.key.o, Option.Relation.CONTAINS, null, null, null, true)) {
            if(inputAct != altPathAct && endSignal.key.o.contains(altPathAct.key.o)) {
                covered = true;
                break;
            }
        }

        Option no = null;
        if(!covered) {
            int signalPos = signalEnd(endSignal);
            no = Option.addPrimitive(t.doc, signalPos);
            Conflicts.add(t, this, no, endSignal.key.o);

            if(endSignal.key.n instanceof ClockNode) {
                addNextSegment(t, signalPos, inputAct.key.rid + 1, endSignal.key.o, inputAct.newOption, inputAct, endSignal);
            }
        }

        addNextSegment(t, inputAct.key.pos.getEnd(direction), inputAct.key.rid, inputAct.key.o, Option.add(t.doc, true, inputAct.newOption, no), inputAct, endSignal);
    }


    /**
     * Adds a segment activation of this node that starts at {@code signalPos} and ends at the next
     * boundary signal returned by {@code getSignal}.
     *
     * @param signalPos start position of the new segment
     * @param rid       rid of the new activation key
     * @param o         option of the new activation key; also used to filter boundary signals
     * @param no        option passed on to {@code addActivationAndPropagate}
     * @param inputs    input activations; the boundary signal's activations are added to them
     */
    private void addNextSegment(Iteration t, int signalPos, int rid, Option o, Option no, Activation... inputs) {
        Signal ns = getSignal(t.doc, signalPos, o);
        // getSignal returns null for positions outside the document: there is no segment to add there.
        if(ns == null) return;

        Range r = Range.create(t.doc, signalPos, ns.pos);

        SortedSet<Activation> inputsSet = SetUtils.asSortedSet(inputs);
        inputsSet.addAll(ns.acts);

        Node.addActivationAndPropagate(t, true, new Activation.Key(this, r, rid, o, 0), r, no, inputsSet, inputsSet);
    }


    /** Removes every output activation of {@code inputAct} that belongs to this node. */
    private void removeActivationCTNode(Iteration t, Activation inputAct) {
        // Snapshot: removing an activation may modify inputAct.outputs.
        for(Activation act: new ArrayList<>(inputAct.outputs)) {
            if(act.key.n == this) {
                removeActivationAndPropagate(t, true, act, act.key.pos);
            }
        }
    }


    /**
     * Handles an added input activation. Activations of {@code ClockTerminationNode}s go to
     * {@code addActivationCTNode}; all others (presumably clock or termination node activations) go to
     * {@code addActivationFromClockOrTerminationNode}.
     *
     * @param addedRange must equal the input activation's range (checked by assertion only)
     */
    public void addActivation(Iteration t, Activation inputAct, Range addedRange) {
        assert addedRange.compareTo(inputAct.key.pos) == 0;

        if(inputAct.key.n instanceof ClockTerminationNode) {
            addActivationCTNode(t, inputAct);
        } else {
            addActivationFromClockOrTerminationNode(t, inputAct);
        }
    }


    /**
     * Handles a removed input activation, dispatching like {@link #addActivation}.
     *
     * @param removedRange must equal the input activation's range (checked by assertion only)
     */
    public void removeActivation(Iteration t, Activation inputAct, Range removedRange) {
        assert removedRange.compareTo(inputAct.key.pos) == 0;

        if(inputAct.key.n instanceof ClockTerminationNode) {
            removeActivationCTNode(t, inputAct);
        } else {
            removeActivationFromClockOrTerminationNode(t, inputAct);
        }
    }


    /** Returns the first input of {@code act} that is an activation of this node, or {@code null}. */
    private Activation getPreviousAct(Activation act) {
        for(Activation pAct: act.inputs) {
            if(pAct.key.n == this) return pAct;
        }
        return null;
    }


    /** Returns the first clock/termination input of {@code act} ending where {@code act} begins, or {@code null}. */
    private Activation getBeginSignal(Activation act) {
        assert act.key.n == this;

        for(Activation nAct: act.inputs) {
            if(isClockOrTerminationSignal(nAct) && signalEnd(nAct) == act.key.pos.getBegin(direction)) return nAct;
        }
        return null;
    }


    /** Returns the first clock/termination input of {@code act} ending where {@code act} ends, or {@code null}. */
    private Activation getEndSignal(Activation act) {
        assert act.key.n == this;

        for(Activation nAct: act.inputs) {
            if(isClockOrTerminationSignal(nAct) && signalEnd(nAct) == act.key.pos.getEnd(direction)) return nAct;
        }
        return null;
    }


    /**
     * Returns the first termination input of {@code act} that ends where {@code act} ends and whose option
     * is contained in {@code act}'s option, or {@code null}.
     */
    private Activation getTerminationSignal(Activation act) {
        assert act.key.n == this;

        for(Activation nAct: act.inputs) {
            if(nAct.key.n instanceof TerminationNode && signalEnd(nAct) == act.key.pos.getEnd(direction) && act.key.o.contains(nAct.key.o)) return nAct;
        }
        return null;
    }


    /** Result of {@code getSignal}: a boundary position and the clock/termination activations ending there. */
    public static class Signal {
        int pos;
        Set<Activation> acts = new TreeSet<>();
    }


    /**
     * Finds the boundary signal for a segment starting at {@code currentPos}.
     *
     * <p>For each of the clock node and the termination node, only the first activation (in
     * {@code Activation.select} order, relative to the range {@code currentPos - 1 .. currentPos}) whose
     * option contains, or is contained in, {@code o} is considered. The returned signal holds the smallest
     * of those end positions together with the considered activations that end there. If nothing matches,
     * the signal has no activations and its position is {@code 0} when {@code direction} is set, otherwise
     * {@code doc.length()}.
     *
     * @return the signal, or {@code null} if {@code currentPos} is negative or {@code >= doc.length()}
     */
    private Signal getSignal(Document doc, int currentPos, Option o) {
        if(currentPos < 0 || currentPos >= doc.length()) return null;

        Signal s = null;

        for(RecurrentNode n: new RecurrentNode[] {clockNode, terminationNode}) {
            Range.Relation rr = n.direction ? (direction ? Range.Relation.BEGIN_BEFORE : Range.Relation.BEGIN_AFTER) : (direction ? Range.Relation.END_BEFORE : Range.Relation.END_AFTER);

            for(Activation act: Activation.select(n, null, Range.create(doc, currentPos - 1, currentPos), rr, null, null, null, null, null, false)) {
                if(act.key.o.contains(o) || o.contains(act.key.o)) {
                    int pos = signalEnd(act);
                    if(s == null || s.pos > pos) {
                        s = new Signal();
                        s.pos = pos;
                        s.acts.add(act);
                    } else if(s.pos == pos) {
                        s.acts.add(act);
                    }
                    // Only the first matching activation per node is considered.
                    break;
                }
            }
        }

        if(s == null) {
            s = new Signal();
            s.pos = direction ? 0 : doc.length();
        }
        return s;
    }


    /**
     * Continues the segment chain after an added activation of this node, then forwards the
     * activation to every node in {@link #outputNodesWithinDocument}. The {@code conflict} argument
     * is not used.
     */
    public void propagateAddedActivation(Iteration t, Activation act, Range addedRange, Option conflict) {
        addActivationCTNode(t, act);

        for(RecurrentNode rn: outputNodesWithinDocument.values()) {
            rn.addActivation(t, act, addedRange);
        }
    }


    /**
     * Removes this node's activations that depend on a removed activation, then forwards the removal
     * to every node in {@link #outputNodesWithinDocument}.
     */
    public void propagateRemovedActivation(Iteration t, Activation act, Range removedRange) {
        removeActivationCTNode(t, act);

        for(RecurrentNode rn: outputNodesWithinDocument.values()) {
            rn.removeActivation(t, act, removedRange);
        }
    }


    /**
     * Adds one activation per clock/termination node of the model. It covers the range
     * {@code 0 .. doc.length()} with rid 0, option {@code doc.bottom} and no inputs.
     */
    public static void addInitialActivations(Iteration t) {
        for(ClockTerminationNode n: t.m.clockTerminationNodes) {
            Activation.Key ak = new Activation.Key(n, Range.create(t.doc, 0, t.doc.length()), 0, t.doc.bottom, 0);

            Node.addActivationAndPropagate(t, false, ak, ak.pos, null, Collections.<Activation>emptySet(), Collections.<Activation>emptySet());
        }
    }


    /** Returns {@code CT[<clock logic>,<termination logic>]}. */
    public String logicToString() {
        StringBuilder sb = new StringBuilder();
        sb.append("CT[");

        sb.append(clockNode.logicToString());
        sb.append(",");
        sb.append(terminationNode.logicToString());

        sb.append("]");
        return sb.toString();
    }
}
