package org.aika.network.neuron;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.aika.corpus.Document;
import org.aika.corpus.Option;
import org.aika.corpus.Range;
import org.aika.network.Iteration;
import org.aika.network.neuron.simple.lattice.InputNode;

import java.util.*;

/**
 * An activation of a {@link Node} over a {@link Range} within an {@link Option}.
 *
 * <p>Each activation is identified and ordered by its {@link Key}. It is linked to its input
 * activations ({@link #inputs}, {@link #directInputs}) and output activations ({@link #outputs},
 * {@link #directOutputs}). It may shadow other activations (see {@link #shadows(Activation)}). It
 * carries a {@link #weight} that is recomputed by {@link #computeWeight()}.
 *
 * @author Lukas Molzberger
 */
public class Activation implements Comparable<Activation> {

    /** Coefficient of the node weight term in {@link #computeWeight()}. */
    public static double NEURON_WEIGHT = 0.4;
    /** Coefficient of the forward weight term in {@link #computeWeight()}. */
    public static double FORWARD_WEIGHT = 0.3;
    /** Coefficient of the strongest output activation's weight in {@link #computeWeight()}. */
    public static double BACKWARD_WEIGHT = 0.3;


    /** Identifies this activation and defines its sort order. */
    public final Key key;

    // NOTE(review): the removal bookkeeping below is not read or written inside this class;
    // presumably it is maintained by the code that removes activations — confirm with callers.
    public boolean isRemoved;
    public int removedId;
    public static int removedIdCounter = 1;

    /** Weight as last computed by {@link #computeWeight()}. */
    public double weight;

    public Option initialOption;
    /** Compared by identity against the {@code no} argument of {@link #select}. */
    public Option newOption;
    /** Option used for shadowing decisions and registered in {@link #register(Iteration)}. */
    public Option completeOption;

    /** Activations shadowed by this one, keyed by their keys. */
    public Map<Key, Activation> shadows = new TreeMap<>();
    /** Activations that shadow this one, keyed by their keys. */
    public Map<Key, Activation> shadowedBy = new TreeMap<>();

    public TreeSet<Activation> inputs = new TreeSet<>();
    public TreeSet<Activation> directInputs = new TreeSet<>();
    public Set<Activation> outputs = new TreeSet<>();
    public TreeSet<Activation> directOutputs = new TreeSet<>();


    /**
     * Creates an activation for an existing key.
     *
     * @param key the key identifying this activation
     */
    public Activation(Key key) {
        this.key = key;
    }


    /**
     * Creates an activation with a new key whose {@code id} is 0.
     *
     * @param n     the activated node
     * @param pos   the range covered by the activation
     * @param rid   the rid stored in the key
     * @param o     the option the activation belongs to
     * @param fired the fired value stored in the key
     */
    public Activation(Node n, Range pos, int rid, Option o, int fired) {
        key = new Key(n, pos, rid, o, fired);
    }


    /**
     * Creates an activation with a new key.
     *
     * @param n     the activated node
     * @param pos   the range covered by the activation
     * @param rid   the rid stored in the key
     * @param o     the option the activation belongs to
     * @param fired the fired value stored in the key
     * @param id    the id stored in the key (see {@link Key#id})
     */
    public Activation(Node n, Range pos, int rid, Option o, int fired, int id) {
        key = new Key(n, pos, rid, o, fired, id);
    }


    /**
     * Tests whether this activation shadows {@code act}.
     *
     * <p>This requires that {@code act}'s complete option contains this activation's complete
     * option, and that this activation did not fire later than {@code act}. When both complete
     * options have the same length and both fired at the same value, ties are broken by key id.
     * In that case this activation does not shadow {@code act} if its id is larger.
     *
     * @param act the potentially shadowed activation
     * @return {@code true} if this activation shadows {@code act}
     */
    public boolean shadows(Activation act) {
        if (!act.completeOption.contains(completeOption) || key.fired > act.key.fired) {
            return false;
        }
        boolean tie = act.completeOption.length == completeOption.length && key.fired == act.key.fired;
        return !(tie && key.id > act.key.id);
    }


    /**
     * Records that this activation shadows {@code act}, updating both sides of the relation.
     *
     * @param act the activation shadowed by this one
     */
    public void addShadows(Activation act) {
        shadows.put(act.key, act);
        act.shadowedBy.put(key, this);
    }


    /**
     * Returns the union of the range segments of all activations that shadow this one.
     *
     * @return the shadowed segments, as produced by {@link Range#union}
     */
    public List<int[]> getShadowedRange() {
        List<List<int[]>> segmentLists = new ArrayList<>();
        for (Key shadowingKey : shadowedBy.keySet()) {
            segmentLists.add(shadowingKey.pos.getSegments());
        }
        return Range.union(segmentLists);
    }


    /**
     * Returns the segments of this activation's range that are not shadowed by another activation.
     *
     * @return this activation's segments minus {@link #getShadowedRange()}
     */
    public List<int[]> getNonShadowedRange() {
        return Range.complement(key.pos.getSegments(), getShadowedRange());
    }


    /**
     * Links this activation with the given activations in both directions.
     *
     * <p>This activation is added to the corresponding reverse set of every given activation, and
     * the given activations are added to this activation's own sets.
     *
     * @param inputActs        activations to add to {@link #inputs}
     * @param directInputActs  activations to add to {@link #directInputs}
     * @param outputActs       activations to add to {@link #outputs}
     * @param directOutputActs activations to add to {@link #directOutputs}
     */
    public void link(Collection<Activation> inputActs, Set<Activation> directInputActs, Set<Activation> outputActs, Set<Activation> directOutputActs) {
        for (Activation inputAct : inputActs) {
            inputAct.outputs.add(this);
        }
        for (Activation outputAct : outputActs) {
            outputAct.inputs.add(this);
        }
        for (Activation directInputAct : directInputActs) {
            directInputAct.directOutputs.add(this);
        }
        for (Activation directOutputAct : directOutputActs) {
            directOutputAct.directInputs.add(this);
        }

        inputs.addAll(inputActs);
        directInputs.addAll(directInputActs);
        outputs.addAll(outputActs);
        directOutputs.addAll(directOutputActs);
    }


    /**
     * Removes this activation from the reverse sets of all activations it is linked to.
     *
     * <p>This activation's own {@link #inputs}, {@link #outputs}, {@link #directInputs} and
     * {@link #directOutputs} sets are left unchanged.
     */
    public void unlink() {
        for (Activation act : inputs) {
            act.outputs.remove(this);
        }
        for (Activation act : outputs) {
            act.inputs.remove(this);
        }
        for (Activation act : directInputs) {
            act.directOutputs.remove(this);
        }
        for (Activation act : directOutputs) {
            act.directInputs.remove(this);
        }
    }


    /**
     * Recomputes {@link #weight} as the sum of three weighted terms.
     *
     * <p>The terms are the largest weight among {@link #outputs} (times {@link #BACKWARD_WEIGHT}),
     * the node's forward weight (times {@link #FORWARD_WEIGHT}) and the node weight (times
     * {@link #NEURON_WEIGHT}). When the key has no node, both node-dependent terms are 0.0.
     *
     * @return the change of the weight (new weight minus previous weight)
     */
    public double computeWeight() {
        Node n = key.n;
        // Guard both node-dependent terms consistently.
        double forwardsWeight = n != null ? n.computeForwardWeight(this) : 0.0;

        double maxBackwardsWeight = 0.0;
        for (Activation act : outputs) {
            if (maxBackwardsWeight < act.weight) {
                maxBackwardsWeight = act.weight;
            }
        }

        double nodeWeight = n != null ? n.getNodeWeight(this) : 0.0;
        double newWeight = (maxBackwardsWeight * BACKWARD_WEIGHT) + (forwardsWeight * FORWARD_WEIGHT) + (NEURON_WEIGHT * nodeWeight);
        double delta = newWeight - weight;
        weight = newWeight;

        return delta;
    }


    /**
     * Adds this activation to the indexes of its node, option, complete option and range.
     *
     * <p>When this is the node's first activation, the node is also added to
     * {@code t.hasActivations}. The option's and complete option's activation collections are
     * created lazily. Activations of an {@link InputNode} are additionally stored in the range's
     * input activations.
     *
     * @param t the iteration that tracks which nodes have activations
     */
    public void register(Iteration t) {
        if (key.n.activations.isEmpty()) {
            t.hasActivations.add(key.n);
        }
        key.n.activations.put(key, this);

        if (key.o.activations == null) {
            key.o.activations = new TreeMap<>();
        }
        key.o.activations.put(key, this);

        if (completeOption.activationsComplete == null) {
            completeOption.activationsComplete = new TreeSet<>();
        }
        completeOption.activationsComplete.add(this);

        if (key.n instanceof InputNode) {
            key.pos.inputActivations.put(key, this);
        }

        key.pos.activations.put(key, this);
    }


    /**
     * Removes this activation from the indexes populated by {@link #register(Iteration)}.
     *
     * <p>When the node has no activations left, it is removed from {@code t.hasActivations} and
     * its {@code clearActivations()} is called. This method expects a prior {@code register}: the
     * option's and complete option's collections are dereferenced without null checks.
     *
     * @param t the iteration that tracks which nodes have activations
     */
    public void unregister(Iteration t) {
        assert !completeOption.activationsComplete.isEmpty();

        key.n.activations.remove(key);

        if (key.n.activations.isEmpty()) {
            t.hasActivations.remove(key.n);

            key.n.clearActivations();
        }

        key.o.activations.remove(key);
        completeOption.activationsComplete.remove(this);

        if (key.n instanceof InputNode) {
            key.pos.inputActivations.remove(key);
        }

        key.pos.activations.remove(key);
    }


    /**
     * Returns the arithmetic mean of the weights of {@link #inputs}.
     *
     * @return the mean input weight, or 0.0 when there are no inputs
     */
    public double computeAverageInputWeight() {
        if (inputs.isEmpty()) {
            return 0.0;
        }

        double sum = 0.0;
        for (Activation input : inputs) {
            sum += input.weight;
        }

        return sum / inputs.size();
    }


    /**
     * Returns the first non-shadowed activation with the given criteria.
     *
     * <p>The activation must have rid 0, a range overlapping {@code r}, and an option equal to
     * {@code o}.
     *
     * @param n the node whose activations are searched
     * @param r the range the activation must overlap
     * @param o the option the activation's option must equal
     * @return the first match, or {@code null} if there is none
     */
    public static Activation get(Node n, Range r, Option o) {
        return get(n, 0, r, Range.Relation.OVERLAPS, o, Option.Relation.EQUALS, null, null, false);
    }


    /**
     * Returns the first activation returned by {@link #select} for the given criteria.
     * No {@code no} (new option) constraint is applied.
     *
     * @return the first match, or {@code null} if there is none
     */
    public static Activation get(Node n, Integer rid, Range r, Range.Relation rr, Option o, Option.Relation or, Integer fired, Integer id, boolean includeShadowed) {
        List<Activation> matches = select(n, rid, r, rr, o, or, null, fired, id, includeShadowed);
        return matches.isEmpty() ? null : matches.get(0);
    }


    /**
     * Returns the first activation of {@code n} whose range contains the key's range, whose option
     * equals the key's option and whose fired value equals the key's fired value, including
     * shadowed activations.
     *
     * <p>The key's {@code rid} and {@code id} are not used for matching.
     *
     * @param n  the node whose activations are searched
     * @param ak the key providing range, option and fired value
     * @return the first match, or {@code null} if there is none
     */
    // TODO: check id
    public static Activation get(Node n, Key ak) {
        return get(n, null, ak.pos, Range.Relation.CONTAINS, ak.o, Option.Relation.EQUALS, ak.fired, null, true);
    }


    /**
     * Returns all activations that satisfy the given criteria. A {@code null} argument leaves that
     * criterion unconstrained.
     *
     * <p>Candidates are taken from one index, chosen in this order:
     * <ul>
     *   <li>the range index, when {@code r} is given, {@code rr} supports collecting and the
     *       estimated count is lower than the node's activation count (or no node is given);</li>
     *   <li>otherwise the node's activations;</li>
     *   <li>otherwise the activations of the options selected by {@code o} and {@code or}.</li>
     * </ul>
     * At least one of {@code r}, {@code n} or {@code o} must be non-null.
     *
     * @param n               required node, or {@code null}
     * @param rid             required rid, or {@code null}
     * @param r               range to compare against, or {@code null}
     * @param rr              relation between the activation's range and {@code r}
     * @param o               option to compare against, or {@code null}
     * @param or              relation between the activation's option and {@code o}
     * @param no              if non-null, must be identical to the activation's {@code newOption}
     * @param fired           required fired value, or {@code null}
     * @param id              required id, or {@code null}
     * @param includeShadowed if {@code false}, only the non-shadowed part of an activation's range
     *                        is compared with {@code r}
     * @return the matching activations, in index order
     */
    public static List<Activation> select(Node n, Integer rid, Range r, Range.Relation rr, Option o, Option.Relation or, Option no, Integer fired, Integer id, boolean includeShadowed) {
        // Sub-map bounds: unconstrained criteria fall back to the MIN/MAX sentinels. The bounds
        // only pre-narrow the scan; every candidate is re-checked by filter().
        // NOTE(review): each Key constructor calls countRef() on its option, and these two bound
        // keys are never released — confirm whether Option reference counting expects that.
        Key bk = new Key(n != null ? n : Node.MIN_NODE, rr == Range.Relation.EQUALS ? r : Range.MIN, rid != null ? rid : Integer.MIN_VALUE, or == Option.Relation.EQUALS ? o : Option.MIN, fired != null ? fired : Integer.MIN_VALUE, id != null ? id : Integer.MIN_VALUE);
        Key ek = new Key(n != null ? n : Node.MAX_NODE, rr == Range.Relation.EQUALS ? r : Range.MAX, rid != null ? rid : Integer.MAX_VALUE, or == Option.Relation.EQUALS ? o : Option.MAX, fired != null ? fired : Integer.MAX_VALUE, id != null ? id : Integer.MAX_VALUE);

        List<Activation> candidates = new ArrayList<>();
        if (r != null && rr.supportsCollect && (n == null || rr.estimateNodeCount(r) < n.activations.size())) { // TODO: estimateNodeCount for Options
            for (Range rn : r.select(rr)) {
                candidates.addAll(rn.activations.subMap(bk, true, ek, true).values());
            }
        } else if (n != null) {
            candidates.addAll(n.activations.subMap(bk, true, ek, true).values());
        } else if (o != null) {
            for (Option on : o.select(or)) {
                candidates.addAll(on.activations.subMap(bk, true, ek, true).values());
            }
        } else {
            // No range, node or option: there is no index to scan.
            assert false;
        }

        List<Activation> results = new ArrayList<>();
        for (Activation act : candidates) {
            if (act.filter(n, rid, r, rr, o, or, no, fired, id, includeShadowed)) {
                results.add(act);
            }
        }
        return results;
    }


    /**
     * Checks this activation against the criteria of {@link #select}; {@code null} arguments are
     * unconstrained.
     *
     * <p>When {@code r} is given and {@code includeShadowed} is {@code false}, an activation whose
     * range is entirely shadowed never matches.
     */
    private boolean filter(Node n, Integer rid, Range r, Range.Relation rr, Option o, Option.Relation or, Option no, Integer fired, Integer id, boolean includeShadowed) {
        List<int[]> cr = r != null ? (includeShadowed ? key.pos.getSegments() : getNonShadowedRange()) : null;
        return (n == null || key.n == n)
                && (rid == null || key.rid == rid)
                && (r == null || (!cr.isEmpty() && rr.compare(cr, r.getSegments())))
                && (o == null || or.compare(key.o, o))
                && (no == null || no == newOption)
                && (fired == null || key.fired == fired)
                && (id == null || key.id == id);
    }


    /**
     * Returns a debug representation of this activation.
     *
     * <p>It contains the range, a document snippet extending up to three characters beyond each
     * side of the range (clamped to the document bounds), and the node.
     *
     * @param doc the document the range refers to
     * @return the debug string
     */
    public String toString(Document doc) {
        StringBuilder sb = new StringBuilder();
        sb.append("<ACT ");
        sb.append(",(");
        sb.append(key.pos);
        sb.append("),");
        sb.append(doc.getContent().substring(Math.max(0, key.pos.getBegin() - 3), Math.min(doc.length(), key.pos.getEnd() + 3)));
        sb.append(",");
        sb.append(key.n);
        sb.append(">");
        return sb.toString();
    }


    /**
     * Orders activations by their keys (see {@link Key#compareTo(Key)}). A {@code null} node
     * sorts first.
     */
    @Override
    public int compareTo(Activation act) {
        return key.compareTo(act.key);
    }


    /**
     * Identity and sort key of an {@link Activation}.
     *
     * <p>Keys are ordered by node, range, rid, option, fired and id, in that order. A {@code null}
     * node, range or option sorts before any non-null value.
     */
    public static final class Key implements Comparable<Key> {
        public final Node n;
        public final Range pos;
        public final int rid;
        public final Option o;
        public final int fired;

        /**
         * For a NegativeInputNode this id field contains the id of the second input node of the AndNode that is using this negation.
         * For an OrNode the id field contains the id of the input node that caused the current OrNode activation.
         */
        public final int id;

        /** Number of references to this key; the constructor counts the first one. */
        private int refCount = 0;


        /** Creates a key whose {@code id} is 0. */
        public Key(Node n, Range pos, int rid, Option o, int fired) {
            this(n, pos, rid, o, fired, 0);
        }


        /**
         * Creates a key, counting one reference on itself and, if non-null, one on {@code o}.
         * The range, if given, must be gap-less (checked by an assertion).
         */
        public Key(Node n, Range pos, int rid, Option o, int fired, int id) {
            assert pos == null || pos.isGapLess();
            this.n = n;
            this.pos = pos;
            this.rid = rid;
            this.o = o;
            this.fired = fired;
            this.id = id;
            countRef();
            if (o != null) {
                o.countRef();
            }
        }


        /** Adds a reference to this key. */
        public void countRef() {
            refCount++;
        }


        /**
         * Removes a reference from this key. When the last reference is removed, the reference on
         * the option (if any) that was counted in the constructor is released.
         */
        public void releaseRef() {
            assert refCount > 0;
            refCount--;
            // Mirror the constructor, which only counts a reference on a non-null option.
            if (refCount == 0 && o != null) {
                o.releaseRef();
            }
        }


        @Override
        public int compareTo(Key k) {
            int r = compareNullity(n, k.n);
            if (r != 0) {
                return r;
            }
            if (n != null) {
                r = n.compareTo(k.n);
                if (r != 0) {
                    return r;
                }
            }

            r = compareNullity(pos, k.pos);
            if (r != 0) {
                return r;
            }
            if (pos != null) {
                r = pos.compareTo(k.pos);
                if (r != 0) {
                    return r;
                }
            }

            r = Integer.compare(rid, k.rid);
            if (r != 0) {
                return r;
            }

            r = compareNullity(o, k.o);
            if (r != 0) {
                return r;
            }
            if (o != null) {
                r = o.compareTo(k.o);
                if (r != 0) {
                    return r;
                }
            }

            r = Integer.compare(fired, k.fired);
            if (r != 0) {
                return r;
            }
            return Integer.compare(id, k.id);
        }


        /**
         * Orders {@code null} before non-null. Returns 0 when both or neither are {@code null}; in
         * the latter case the caller compares the values themselves.
         */
        private static int compareNullity(Object a, Object b) {
            if (a == null) {
                return b == null ? 0 : -1;
            }
            return b == null ? 1 : 0;
        }


        @Override
        public String toString() {
            return pos + " " + rid + " " + o + " " + fired + " " + id;
        }
    }
}
