package org.aika.network.neuron.recurrent;


/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


import org.aika.corpus.Option;
import org.aika.corpus.Range;
import org.aika.network.Iteration;
import org.aika.network.Model;
import org.aika.network.neuron.Activation;
import org.aika.network.neuron.Node;

import java.util.*;


/**
 * Recurrent output node that combines activations of an {@link InputNode} with activations of a
 * {@link ClockTerminationNode} and carries the resulting output activation forward from one
 * clock-termination activation to the next.
 *
 * <p>Output activations of this node are created from three kinds of input activation
 * (see {@link #addActivation}):
 * <ul>
 *     <li>an activation of {@link #inputNode}, joined with each clock-termination activation
 *     selected around the input activation's end position (recurrence id 0);</li>
 *     <li>an activation of {@link #ctNode}, joined with a matching input-node activation
 *     (recurrence id 0) and/or continuing the output activation attached to its previous
 *     clock-termination activation;</li>
 *     <li>an activation of this node itself, continued onto each following clock-termination
 *     activation.</li>
 * </ul>
 * Continuation activations get a recurrence id advanced by the rid difference of the involved
 * clock-termination activations and are only created while that id stays below
 * {@link #maxLength}.
 *
 * @author Lukas Molzberger
 */
public class OutputNode extends RecurrentNode {

    /** Input node whose activations are joined with clock-termination activations. */
    public InputNode inputNode;
    /** Clock-termination node whose activations drive the recurrence steps of this node. */
    public ClockTerminationNode ctNode;

    /** Exclusive upper bound for the recurrence id of continuation activations. */
    public int maxLength;


    /**
     * @param m         the model this node belongs to
     * @param maxLength exclusive upper bound for the recurrence id of continuation activations
     */
    public OutputNode(Model m, int maxLength) {
        // NOTE(review): the meaning of the constant 2 is defined by RecurrentNode -- presumably the
        // number of inputs (inputNode and ctNode); confirm.
        super(m, 2);
        this.maxLength = maxLength;
    }


    /**
     * Delegates the forward weight computation to the first input activation that does not
     * originate from a {@link ClockTerminationNode}.
     *
     * @return that input's forward weight, or 0 (with an assertion failure when assertions are
     *         enabled) if every input stems from a clock-termination node
     */
    @Override
    public double computeForwardWeight(Activation act) {
        for(Activation inputAct: act.inputs) {
            if(!(inputAct.key.n instanceof ClockTerminationNode)) {
                return inputAct.key.n.computeForwardWeight(inputAct);
            }
        }
        assert false;
        return 0;
    }


    /** Returns this node's {@code weight}, independent of the given activation. */
    @Override
    public double getNodeWeight(Activation act) {
        return weight;
    }


    /**
     * Returns {@code null}; this node does not expose any children.
     */
    public Collection<RecurrentNode> getChildren() {
        // Kept as null (not an empty collection) because callers may test for null.
        return null;
    }


    /**
     * Creates output activations of this node in response to a new input activation.
     * Dispatches on the node the input activation belongs to: {@link #inputNode},
     * {@link #ctNode}, or this node itself. Activations of any other node are ignored.
     *
     * @param t          the current iteration
     * @param inputAct   the newly added activation
     * @param addedRange unused by this implementation
     */
    @Override
    public void addActivation(Iteration t, Activation inputAct, Range addedRange) {
        if(inputAct.key.n == inputNode) {
            addFromInputNodeActivation(t, inputAct);
        } else if(inputAct.key.n == ctNode) {
            addFromClockTerminationActivation(t, inputAct);
            continueFromPreviousTick(t, inputAct);
        } else if(inputAct.key.n == this) {
            continueToNextTicks(t, inputAct);
        }
    }


    /**
     * Joins an input-node activation with each clock-termination activation selected (via
     * {@code Range.Relation.CONTAINS}) against the single-position range ending at the input
     * activation's end in {@code inputNode.direction}. Creates output activations with rid 0.
     */
    private void addFromInputNodeActivation(Iteration t, Activation inputAct) {
        int end = inputAct.key.pos.getEnd(inputNode.direction);
        Range r = Range.create(t.doc, end - 1, end);
        for(Activation ctAct: Activation.select(ctNode, null, r, Range.Relation.CONTAINS, null, null, null, null, null, false)) {
            Option o = Option.add(t.doc, true, ctAct.completeOption, inputAct.key.o);
            if(o != null) {
                Node.addActivationAndPropagate(t, false, new Activation.Key(this, ctAct.key.pos, 0, o, Math.max(ctAct.key.fired, inputAct.key.fired)), ctAct.key.pos, null, prepareIActs(inputAct, ctAct), prepareDIActs(inputAct, ctAct));
            }
        }
    }


    /**
     * Joins a clock-termination activation with an input-node activation looked up (via
     * {@code Range.Relation.CONTAINED_IN}) against the clock-termination activation's range.
     * Creates an output activation with rid 0.
     */
    private void addFromClockTerminationActivation(Iteration t, Activation inputAct) {
        Activation inputNodeAct = Activation.get(inputNode, null, inputAct.key.pos, Range.Relation.CONTAINED_IN, null, null, null, null, false);
        if(inputNodeAct != null) {
            Option o = Option.add(t.doc, true, inputNodeAct.key.o, inputAct.completeOption);
            if(o != null) {
                Node.addActivationAndPropagate(t, false, new Activation.Key(this, inputAct.key.pos, 0, o, Math.max(inputNodeAct.key.fired, inputAct.key.fired)), inputAct.key.pos, null, prepareIActs(inputAct, inputNodeAct), prepareDIActs(inputAct, inputNodeAct));
            }
        }
    }


    /**
     * Continues the output activation attached to the previous clock-termination activation
     * (the ctNode activation among {@code inputAct}'s direct inputs) onto {@code inputAct},
     * as long as the resulting rid stays below {@link #maxLength}.
     */
    private void continueFromPreviousTick(Iteration t, Activation inputAct) {
        Activation prevCtAct = getCtActivation(inputAct);
        if(prevCtAct == null) {
            return;
        }

        Activation prevOutAct = getOutputActivation(prevCtAct);
        if(prevOutAct == null) {
            return;
        }

        int nRid = prevOutAct.key.rid + (inputAct.key.rid - prevCtAct.key.rid);
        if(nRid >= maxLength) {
            return;
        }

        Option o = Option.add(t.doc, true, prevOutAct.key.o, inputAct.completeOption);
        if(o != null) {
            Node.addActivationAndPropagate(t, false, new Activation.Key(this, inputAct.key.pos, nRid, o, Math.max(prevOutAct.key.fired, inputAct.key.fired)), inputAct.key.pos, null, prepareIActs(inputAct, prevOutAct), prepareDIActs(inputAct, prevOutAct));
        }
    }


    /**
     * Continues an output activation of this node onto every clock-termination activation that
     * follows its current one, as long as the resulting rid stays below {@link #maxLength}.
     */
    private void continueToNextTicks(Iteration t, Activation inputAct) {
        Activation currentCtAct = getCtActivation(inputAct);
        if(currentCtAct == null) {
            // Output activations created in this class are given a ctNode activation as a direct
            // input; guard anyway so a missing one is skipped instead of raising an NPE below.
            return;
        }

        for(Activation nextCtAct: getNextCtActivations(currentCtAct)) {
            int nRid = inputAct.key.rid + (nextCtAct.key.rid - currentCtAct.key.rid);

            if(nRid < maxLength) {
                Option o = Option.add(t.doc, true, nextCtAct.completeOption, inputAct.key.o);
                if(o != null) {
                    Node.addActivationAndPropagate(t, true, new Activation.Key(this, nextCtAct.key.pos, nRid, o, Math.max(nextCtAct.key.fired, inputAct.key.fired)), nextCtAct.key.pos, null, prepareIActs(inputAct, nextCtAct), prepareDIActs(inputAct, nextCtAct));
                }
            }
        }
    }


    /**
     * Removes the output activation of this node that is a direct output of {@code inputAct},
     * if there is one.
     */
    @Override
    public void removeActivation(Iteration t, Activation inputAct, Range removedRange) {
        // NOTE(review): only the first matching output activation is removed, although one input
        // activation can feed several (e.g. one per clock-termination activation in
        // addFromInputNodeActivation) -- confirm the others are cleaned up elsewhere.
        Activation act = getOutputActivation(inputAct);
        if(act != null) {
            Node.removeActivationAndPropagate(t, true, act, act.key.pos);
        }
    }


    /** Returns the union of both activations' {@code inputs}, ordered in a {@link TreeSet}. */
    private static Set<Activation> prepareIActs(Activation actA, Activation actB) {
        Set<Activation> iActs = new TreeSet<>();
        iActs.addAll(actA.inputs);
        iActs.addAll(actB.inputs);
        return iActs;
    }


    /** Returns a {@link TreeSet} containing the two given activations as direct inputs. */
    private static Set<Activation> prepareDIActs(Activation actA, Activation actB) {
        Set<Activation> diActs = new TreeSet<>();
        diActs.add(actA);
        diActs.add(actB);
        return diActs;
    }


    /**
     * Returns the first direct input of {@code act} that belongs to {@link #ctNode}, or
     * {@code null} if there is none.
     */
    private Activation getCtActivation(Activation act) {
        for(Activation ctAct: act.directInputs) {
            if(ctAct.key.n == ctNode) {
                return ctAct;
            }
        }
        return null;
    }


    /**
     * Returns the first direct output of {@code act} that belongs to this node, or {@code null}
     * if there is none.
     */
    private Activation getOutputActivation(Activation act) {
        for(Activation outAct: act.directOutputs) {
            if(outAct.key.n == this) {
                return outAct;
            }
        }
        return null;
    }


    /** Returns all direct outputs of {@code act} that belong to {@link #ctNode}. */
    private List<Activation> getNextCtActivations(Activation act) {
        ArrayList<Activation> results = new ArrayList<>();
        for(Activation nextCtAct: act.directOutputs) {
            if(nextCtAct.key.n == ctNode) {
                results.add(nextCtAct);
            }
        }
        return results;
    }


    /**
     * Handles a newly added activation unless {@code conflict} is non-null. It is processed by
     * {@link #addActivation} (with the activation's own range) and then forwarded to
     * {@code neuron}, if one is set.
     */
    public void propagateAddedActivation(Iteration t, Activation act, Range addedRange, Option conflict) {
        if(conflict != null) return;

        addActivation(t, act, act.key.pos);
        if(neuron != null) {
            neuron.propagateAddedActivation(t, act, addedRange);
        }
    }


    /**
     * Handles a removed activation via {@link #removeActivation} and then forwards it to
     * {@code neuron}, if one is set.
     */
    public void propagateRemovedActivation(Iteration t, Activation act, Range removedRange) {
        removeActivation(t, act, removedRange);
        if(neuron != null) {
            neuron.propagateRemovedActivation(t, act, removedRange);
        }
    }


    /** Returns {@code "RN[<inputNode logic>,<ctNode logic>]"}. */
    public String logicToString() {
        StringBuilder sb = new StringBuilder();
        sb.append("RN[");

        sb.append(inputNode.logicToString());
        sb.append(",");
        sb.append(ctNode.logicToString());

        sb.append("]");
        return sb.toString();
    }
}
