package org.aika.network.neuron.simple.lattice;


/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


import org.aika.corpus.Option;
import org.aika.corpus.Range;
import org.aika.network.Iteration;
import org.aika.network.Model;
import org.aika.network.neuron.Activation;
import org.aika.network.neuron.Activation.Key;
import org.aika.network.neuron.Neuron;
import org.aika.network.neuron.Node;
import org.aika.network.neuron.Synapse;
import org.aika.network.neuron.simple.SimpleNeuron;
import org.aika.network.neuron.simple.lattice.AndNode.Refinement;

import java.util.*;

/**
 * Base class for the nodes of the pattern lattice.
 * <p>
 * A lattice node links to the {@link AndNode}s that extend it by one {@link Refinement} and to the
 * {@link OrNode}s that take it as input. Added activations are handed on via
 * {@link #expandToNextLevel}; removed activations are withdrawn from the children via
 * {@link #removeFromNextLevel}. The static helpers build the predefined part of the lattice for a
 * {@link SimpleNeuron} from its input synapses.
 *
 * @author Lukas Molzberger
 */
public abstract class LatticeNode extends Node {

    /**
     * AND children keyed by refinement. This is the map walked by {@link #collectPublicNodes} and
     * {@link #removeFromNextLevel}. NOTE(review): presumably restricted to the children relevant to
     * the current document; confirm where it is populated.
     */
    public TreeMap<Refinement, AndNode> andChildrenWithinDocument = new TreeMap<>();

    /** All AND children, keyed by the refinement that distinguishes the child from this node. */
    public TreeMap<Refinement, AndNode> andChildren = new TreeMap<>();

    /** OR nodes fed by this node. */
    public TreeSet<OrNode> orChildren = new TreeSet<>();

    // Visit markers: a traversal skips this node when the marker already equals its visit id
    // (see collectPublicNodes). The other two markers are presumably used by code outside this class.
    public long visitedPropagateSignificance = -1;
    public long visitedCollectPublicNodes = -1;
    public long visitedAllowedOption = -1;


    /**
     * @param m     the owning model (passed on to {@link Node})
     * @param level the lattice level of this node
     */
    public LatticeNode(Model m, int level) {
        super(m, level);
    }


    /**
     * Hands an activation of this node on to the next lattice level.
     * <p>
     * Called by {@link #train} with {@code train = true}, {@code conflict = null} and the
     * activation's own range, and by {@link #propagateAddedActivation} with {@code train = false}.
     *
     * @param t          the current iteration
     * @param act        an activation of this node
     * @param addedRange the range that was added
     * @param conflict   a conflicting option, or {@code null}
     * @param train      whether the call originates from training
     */
    public abstract void expandToNextLevel(Iteration t, Activation act, Range addedRange, Option conflict, boolean train);


    /**
     * Collects into {@code inputs} the refinements that describe this node extended by
     * {@code newRef}. {@link #computePredefinedAndNodes} feeds the result to
     * {@code AndNode.computeParents}. NOTE(review): exact contents are defined by subclasses.
     */
    protected abstract void collectNodeAndRefinements(Refinement newRef, Set<Refinement> inputs);


    /**
     * @return {@code true} when a neuron is attached to this node or it feeds at least one OR node
     */
    public boolean isPublic() {
        return neuron != null || !orChildren.isEmpty();
    }


    /**
     * Adds this node (if {@link #isPublic() public}) and, recursively, all public nodes reachable
     * through {@link #andChildrenWithinDocument} to {@code results}. Each node is visited at most
     * once per visit id {@code v}.
     *
     * @param results receives the public nodes
     * @param v       visit id of this traversal
     */
    public void collectPublicNodes(List<LatticeNode> results, long v) {
        if(visitedCollectPublicNodes == v) return;
        visitedCollectPublicNodes = v;

        if(isPublic()) results.add(this);

        for(AndNode n: andChildrenWithinDocument.values()) {
            n.collectPublicNodes(results, v);
        }
    }


    /**
     * If this node passes {@code isFrequentOrPredefined(frequency)}, expands every stored
     * activation to the next level in training mode.
     */
    public void train(Iteration t) {
        if(isFrequentOrPredefined(frequency)) {
            for(Activation act: activations.values()) {
                expandToNextLevel(t, act, act.key.pos, null, true);
            }
        }
    }


    /**
     * Runs the base-class propagation, then hands the activation on to the next lattice level
     * (non-training mode).
     */
    public void propagateAddedActivation(Iteration t, Activation act, Range addedRange, Option conflict) {
        super.propagateAddedActivation(t, act, addedRange, conflict);

        expandToNextLevel(t, act, addedRange, conflict, false);
    }


    /**
     * Runs the base-class removal, then withdraws the activation from the next lattice level.
     */
    public void propagateRemovedActivation(Iteration t, Activation act, Range removedRange) {
        super.propagateRemovedActivation(t, act, removedRange);

        removeFromNextLevel(t, act, removedRange);
    }


    /**
     * Adds all AND descendants reachable through {@link #andChildren} (transitively) to
     * {@code results}. Descendants reachable through several parents are traversed once per path.
     *
     * @return {@code results}, for chaining
     */
    public Set<AndNode> getAndChildPatterns(Set<AndNode> results) {
        for(AndNode cp: andChildren.values()) {
            results.add(cp);
            cp.getAndChildPatterns(results);
        }

        return results;
    }


    /**
     * Walks down {@link #andChildren}, applying the refinements in {@code inputs} one at a time.
     * When a single refinement remains, the node reached is recorded in {@code parents} under that
     * refinement.
     *
     * @param offset  current offset of this node within the pattern being assembled
     * @param inputs  refinements still to be applied
     * @param parents receives, per refinement, the node that this refinement extends
     * @param visited (node, offset) pairs already explored; a revisit returns {@code true} at once
     * @return {@code false} if a required AND child is missing or not frequent/predefined
     */
    public boolean computeAndParents(int offset, SortedSet<Refinement> inputs, Map<Refinement, LatticeNode> parents, Set<RSKey> visited) {
        // Set.add returns false when (this, offset) was already explored via another path.
        if(!visited.add(new RSKey(this, offset))) {
            return true;
        }

        if(inputs.size() == 1) {
            parents.put(inputs.first(), this);
            return true;
        }

        for(Refinement ref: inputs) {
            int relPos = ref.getRelativePosition();
            // The child's offset moves left when ref lies before the current offset.
            int nOffset = Math.min(offset, relPos);

            // The child is keyed by ref's position relative to the current offset.
            LatticeNode cp = andChildren.get(new Refinement(relPos - offset, ref.inferenceMode, ref.input));
            if(cp == null || !cp.isFrequentOrPredefined()) {
                return false;
            }

            SortedSet<Refinement> childInputs = new TreeSet<>(inputs);
            childInputs.remove(ref);

            if(!cp.computeAndParents(nOffset, childInputs, parents, visited)) {
                return false;
            }
        }
        return true;
    }


    /**
     * Removes the activation keyed by {@code iAct.key} from every AND child in
     * {@link #andChildrenWithinDocument}. For every OR child, it looks up and removes the activation
     * keyed by this node, the same position/rid/fired values and this node's id. When
     * {@code iAct.newOption} is set, the option is combined with it via {@code Option.add}.
     */
    public void removeFromNextLevel(Iteration t, Activation iAct, Range removedRange) {
        Key ak = iAct.key;
        for(AndNode c: andChildrenWithinDocument.values()) {
            c.removeActivation(t, ak, removedRange);
        }
        for(OrNode c: orChildren) {
            Activation act = Activation.get(c, new Key(this, ak.pos, ak.rid, iAct.newOption != null ? Option.add(t.doc, false, ak.o, iAct.newOption) : ak.o, ak.fired, id));
            c.removeActivation(t, act, removedRange);
        }
    }


    /**
     * Removes this node from the model. The method:
     * <ul>
     * <li>removes the attached neuron (if any);</li>
     * <li>recursively removes every AND and OR child, draining {@link #andChildren} and
     * {@link #orChildren};</li>
     * <li>unregisters the node from {@code m.allNodes} and clears its activations;</li>
     * <li>marks the node as removed with a fresh removal id.</li>
     * </ul>
     * Must not be called twice (asserted).
     */
    public void remove(Model m) {
        assert !isRemoved;

        if(neuron != null) {
            neuron.remove();
        }

        while(!andChildren.isEmpty()) {
            andChildren.pollFirstEntry().getValue().remove(m);
        }

        while(!orChildren.isEmpty())  {
            orChildren.pollFirst().remove(m);
        }

        m.allNodes.remove(this);

        clearActivations();

        isRemoved = true;
        isRemovedId = isRemovedIdCounter++;
    }


    /**
     * Builds the predefined lattice nodes for neuron {@code n} from its input synapses and
     * attaches the resulting output node to it.
     * <p>
     * Construction depends on the synapses:
     * <ul>
     * <li>With no inputs, a single predefined {@link InputNode} becomes the output.</li>
     * <li>Otherwise nodes are built level by level. An input or AND node whose synapse weight sum
     * for {@code n} is positive becomes an output; the remaining nodes are extended by the
     * remaining synapses.</li>
     * <li>Several outputs are joined by a new {@link OrNode}.</li>
     * </ul>
     *
     * @return the neuron attached to the output node. This is {@code n} unless the output node
     *         already carried a neuron, in which case that neuron is returned (cast to
     *         {@link SimpleNeuron}).
     */
    public static SimpleNeuron addNeuron(SimpleNeuron n, Set<Synapse> inputs) {
        SortedSet<LatticeNode> outputs = new TreeSet<>();

        if(inputs.isEmpty()) {
            InputNode node = InputNode.add(n.m, null, null, false);
            node.isPredefined = true;
            outputs.add(node);
        } else {
            Map<RSKey, Set<Synapse>> firstLevel = computePredefinedInputNodes(n, outputs, inputs);
            Map<RSKey, Set<Synapse>> nextLevel = !firstLevel.isEmpty() ? computePredefinedSecondLevelAndNodes(n, outputs, firstLevel) : null;

            while (nextLevel != null && !nextLevel.isEmpty()) {
                nextLevel = computePredefinedAndNodes(n, outputs, nextLevel);
            }
        }

        assert !outputs.isEmpty();

        Node outputNode = null;
        if(outputs.size() == 1) {
            outputNode = outputs.first();
        } else if(outputs.size() > 1) {
            OrNode orNode = new OrNode(n.m, -1);
            for(LatticeNode on: outputs) {
                orNode.addInput(on);
            }
            outputNode = orNode;
        }
        assert outputNode != null;
        assert n.node == null;
        n.node = outputNode;
        if(outputNode.neuron == null) {
            outputNode.neuron = n;
        }
        outputNode.isPredefined = true;
        return (SimpleNeuron) outputNode.neuron;
    }



    /**
     * A lattice node paired with an offset. Used as the key of the per-level maps built while
     * predefining a neuron, and as the visited marker in {@link LatticeNode#computeAndParents}.
     * <p>
     * Ordered by node ({@code Node.compareTo}), then offset. {@code equals} and {@code hashCode}
     * agree with this ordering, so the key behaves the same in sorted and hash-based collections.
     */
    public static class RSKey implements Comparable<RSKey> {
        LatticeNode pa;
        int offset;

        public RSKey(LatticeNode pa, int offset) {
            this.pa = pa;
            this.offset = offset;
        }


        @Override
        public String toString() {
            return "Offset:" + offset;
        }

        @Override
        public int compareTo(RSKey rs) {
            int r = pa.compareTo(rs.pa);
            if(r != 0) return r;
            return Integer.compare(offset, rs.offset);
        }

        /** Two keys are equal exactly when {@link #compareTo} returns 0. */
        @Override
        public boolean equals(Object o) {
            if(this == o) return true;
            if(!(o instanceof RSKey)) return false;
            return compareTo((RSKey) o) == 0;
        }

        /**
         * Hashes only the offset. Node equality here is defined by {@code Node.compareTo}, and
         * whether {@code Node.hashCode} agrees with it is not known at this point. Hashing the node
         * could therefore break the equals/hashCode contract.
         */
        @Override
        public int hashCode() {
            return Integer.hashCode(offset);
        }
    }


    /**
     * Creates (via {@code InputNode.add}) one predefined input node per synapse. The synapse's rid
     * is passed only for non-relative synapses, and a flag is set for negative weights. Each node is
     * then sorted into {@code outputs} or into the returned map (see
     * {@link #prepareResultsForPredefinedNodes}), using the synapse's rid as the offset.
     *
     * @return first-level nodes that still need further inputs, with their remaining synapses
     */
    public static Map<RSKey, Set<Synapse>> computePredefinedInputNodes(Neuron n, Set<LatticeNode> outputs, Set<Synapse> inputs) {
        Map<RSKey, Set<Synapse>> results = new TreeMap<>();
        for(Synapse s: inputs) {
            InputNode in = InputNode.add(n.m, !s.relative ? s.rid : null, s.input, s.w < 0);
            in.isPredefined = true;

            prepareResultsForPredefinedNodes(results, outputs, in, n, s, s.rid, inputs);
        }

        return results;
    }


    /**
     * Pairs each pending first-level input node with each of its remaining synapses, provided the
     * synapse's input node alone has a non-positive weight sum for {@code n}. The level-2
     * {@link AndNode} for the pair is looked up in {@code andChildren}, or created if missing.
     *
     * @return level-2 nodes that still need further inputs, with their remaining synapses
     */
    public static Map<RSKey, Set<Synapse>> computePredefinedSecondLevelAndNodes(Neuron n, Set<LatticeNode> outputs, Map<RSKey, Set<Synapse>> previousLevel) {
        Map<RSKey, Set<Synapse>> results = new TreeMap<>();

        for(Map.Entry<RSKey, Set<Synapse>> me: previousLevel.entrySet()) {
            for(Synapse s: me.getValue()) {
                InputNode pb = s.getInputNode();
                Refinement refa = new Refinement(me.getKey().offset - s.rid, n.inferenceMode, (InputNode) me.getKey().pa);
                Refinement refb = new Refinement(s.rid - me.getKey().offset, n.inferenceMode, pb);

                int nOffset = Math.min(s.rid, me.getKey().offset);

                if(pb.computeSynapseWeightSum(n) <= 0) {
                    AndNode sln = me.getKey().pa.andChildren.get(refb);
                    if (sln == null) {
                        SortedMap<Refinement, LatticeNode> parents = new TreeMap<>();
                        parents.put(refb, me.getKey().pa);
                        parents.put(refa, pb);
                        sln = new AndNode(n.m, 2, parents, n.inferenceMode);
                    }

                    prepareResultsForPredefinedNodes(results, outputs, sln, n, s, nOffset, me.getValue());
                }
            }
        }
        return results;
    }


    /**
     * Extends each pending AND node of the previous level by one of its remaining synapses. For
     * every parent of the pending node, it checks whether that parent already has a matching AND
     * child with a non-positive weight sum for {@code n}. If so, the extended node is looked up in
     * {@code andChildren}, or created from the refinements collected by
     * {@link #collectNodeAndRefinements}.
     *
     * @return next-level nodes that still need further inputs, with their remaining synapses
     */
    public static Map<RSKey, Set<Synapse>> computePredefinedAndNodes(Neuron n, Set<LatticeNode> outputs, Map<RSKey, Set<Synapse>> previousLevel) {
        Map<RSKey, Set<Synapse>> results = new TreeMap<>();

        for(Map.Entry<RSKey, Set<Synapse>> me: previousLevel.entrySet()) {
            AndNode pa = (AndNode) me.getKey().pa;
            for(Map.Entry<Refinement, LatticeNode> mea: pa.parents.entrySet()) {
                int pOffset = me.getKey().offset - mea.getKey().getOffset();
                for(Synapse s: me.getValue()) {
                    int nOffset = Math.min(s.rid, me.getKey().offset);
                    Refinement pRef = new Refinement(s.rid - pOffset, n.inferenceMode, s.getInputNode());
                    Refinement ref = new Refinement(s.rid - me.getKey().offset, n.inferenceMode, s.getInputNode());
                    AndNode pb = mea.getValue().andChildren.get(pRef);

                    if(pb != null && pb.computeSynapseWeightSum(n) <= 0) {
                        AndNode nln = pa.andChildren.get(ref);
                        if (nln == null) {
                            Set<Refinement> inputs = new TreeSet<>();

                            pa.collectNodeAndRefinements(ref, inputs);
                            nln = new AndNode(n.m, pa.level + 1, AndNode.computeParents(inputs), n.inferenceMode);
                        }

                        prepareResultsForPredefinedNodes(results, outputs, nln, n, s, nOffset, me.getValue());
                    }
                }
            }
        }
        return results;
    }


    /**
     * Marks {@code in} as predefined and sorts it as follows:
     * <ul>
     * <li>If its synapse weight sum for {@code n} is positive, it is added to {@code outputs}.</li>
     * <li>Otherwise it is recorded in {@code results} under {@code (in, offset)}, together with
     * {@code inputs} minus the synapse {@code s} that was just consumed.</li>
     * </ul>
     */
    private static void prepareResultsForPredefinedNodes(Map<RSKey, Set<Synapse>> results, Set<LatticeNode> outputs, LatticeNode in, Neuron n, Synapse s, int offset, Set<Synapse> inputs) {
        in.isPredefined = true;
        if(in.computeSynapseWeightSum(n) > 0) {
            outputs.add(in);
        } else {
            Set<Synapse> nInputs = new TreeSet<>(inputs);
            nInputs.remove(s);
            results.put(new RSKey(in, offset), nInputs);
        }
    }
}
