package org.aika.network.neuron.simple.lattice;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.aika.corpus.Conflicts;
import org.aika.corpus.Option;
import org.aika.corpus.Range;
import org.aika.network.Iteration;
import org.aika.network.Model;
import org.aika.network.neuron.Activation;
import org.aika.network.neuron.Activation.Key;
import org.aika.network.neuron.Neuron;
import org.aika.network.neuron.Node;
import org.aika.network.neuron.Synapse;
import org.aika.network.neuron.simple.SimpleNeuron;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.linear.*;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;

import java.util.*;

/**
 * Conjunctive node of the pattern lattice.
 * <p>
 * An {@code AndNode} is built from a set of parent {@link LatticeNode}s. Each parent is paired
 * with the {@link Refinement} under which this node is registered in the parent's
 * {@code andChildren}. The node combines overlapping activations of its parents into its own
 * activations.
 * <p>
 * Besides activation propagation, the node maintains:
 * <ul>
 *   <li>a binomial significance estimate ({@link #computeWeight(Model)});</li>
 *   <li>the significance relationships between patterns ({@link #propagateSignificance(long)}),
 *       which are used to decide whether the pattern should be published.</li>
 * </ul>
 *
 * @author Lukas Molzberger
 */
public class AndNode extends LogicNode {

    /**
     * Smallest value of {@code 1 - frequency / parent.frequency} over all parents.
     * Recomputed by {@link #computeWeight(Model)} when {@code level > 1}.
     */
    public double minPRelevance = 0.0;

    /**
     * Parent lattice nodes. Each is keyed by the refinement under which this node is
     * registered in that parent's {@code andChildren}.
     */
    public SortedMap<Refinement, LatticeNode> parents = new TreeMap<>();

    /**
     * Neuron published for this pattern, or {@code null}. Note: the assignment in
     * {@link #publish(Model)} is currently commented out, so this class never sets it.
     */
    public Neuron publishedPatternNeuron = null;

    /** Significance flag. Set via {@link #setSignificant(boolean)}. */
    public boolean isSignificant;

    /**
     * The significant node this node delegates to. Equals {@code this} for a node that represents
     * itself, and is {@code null} when the node is not significant.
     * Maintained by {@link #propagateSignificance(long)}.
     */
    public AndNode directSignificantAncestor = null;

    /**
     * Significant nodes collected from this node's {@code andChildren}, excluding those already
     * covered transitively. Computed by {@link #propagateSignificance(long)}; {@code null} until
     * then.
     */
    public SortedSet<AndNode> significantAncestors;

    /** Inference-mode flag. Activations in inference mode combine ranges with {@code Range.add}. */
    public boolean inferenceMode;

    /** Set by {@link #propagateSignificance(long)} in place of the (disabled) publish/unpublish calls. */
    public boolean shouldBePublished = false;


    /**
     * Creates the node and registers it in every parent's {@code andChildren} under the
     * corresponding refinement.
     *
     * @param m             the model
     * @param level         lattice level of the node
     * @param parents       parent nodes keyed by refinement (stored by reference, not copied)
     * @param inferenceMode inference-mode flag of the pattern
     */
    public AndNode(Model m, int level, SortedMap<Refinement, LatticeNode> parents, boolean inferenceMode) {
        super(m, level);
        this.parents = parents;
        this.inferenceMode = inferenceMode;
        for(Map.Entry<Refinement, LatticeNode> me: parents.entrySet()) {
            me.getValue().andChildren.put(me.getKey(), this);
        }
    }


    /**
     * Checks whether option {@code n} is allowed for {@code act}.
     * <p>
     * It is allowed if {@code n} contains the activation's initial option, or if the node of any
     * direct input activation allows it (checked recursively).
     *
     * @param n   the option to check
     * @param act the activation whose provenance is inspected
     * @param v   visit id; a node already visited with this id returns {@code false}
     * @return {@code true} if the option is allowed
     */
    @Override
    public boolean isAllowedOption(Option n, Activation act, long v) {
        if(visitedAllowedOption == v) return false;
        visitedAllowedOption = v;

        if(act.initialOption != null && n.contains(act.initialOption)) return true;

        for(Activation pAct: act.directInputs) {
            if(pAct.key.n.isAllowedOption(n, pAct, v)) return true;
        }
        return false;
    }


    /** Returns {@code true} if every parent node is a {@link NegativeInputNode}. */
    public boolean isNegative() {
        for(LatticeNode n: parents.values()) {
            if(!(n instanceof NegativeInputNode)) return false;
        }
        return true;
    }


    /** Returns {@code true} if any refinement's input is a {@link NegativeInputNode}. */
    @Override
    public boolean containsNegative() {
        for(Refinement ref: parents.keySet()) {
            if(ref.input instanceof NegativeInputNode) return true;
        }
        return false;
    }


    /**
     * Returns the average input weight of {@code act}, or {@code 0.0} if it has no inputs.
     * Note: assumes {@code act.inputs} is non-null.
     */
    @Override
    public double computeForwardWeight(Activation act) {
        if(!act.inputs.isEmpty()) {
            return act.computeAverageInputWeight();
        }

        return 0.0;
    }


    /** Returns this node's {@code weight}; the activation argument is not used. */
    public double getNodeWeight(Activation act) {
        return weight;
    }


    /** Adds an activation for this node via {@link Node#addActivationAndPropagate}. */
    public void addActivation(Iteration t, Key ak, Range addedRange, Option initOption, Set<Activation> inputActs, Set<Activation> directInputActs) {
        Node.addActivationAndPropagate(t, false, ak, addedRange, initOption, inputActs, directInputActs);
    }


    /**
     * Removes the part {@code removedRange} from this node's activations.
     * <p>
     * The affected activations are those that overlap {@code ak.pos ∩ removedRange}, selected
     * against {@code ak.o} with {@code Option.Relation.CONTAINS}. Each is removed over its own
     * intersection with {@code removedRange}.
     */
    protected void removeActivation(Iteration t, Key ak, Range removedRange) {
        for(Activation act: Activation.select(this, 0, ak.pos.intersection(removedRange), Range.Relation.OVERLAPS, ak.o, Option.Relation.CONTAINS, null, null, null, true)) {
            Node.removeActivationAndPropagate(t, false, act, act.key.pos.intersection(removedRange));
        }
    }


    /**
     * Computes a binomial significance estimate for this pattern.
     * <p>
     * The null-hypothesis probability is the product, over all refinement inputs, of
     * {@code frequency / numberOfPositions} (each factor capped at 1). {@code weight} becomes
     * {@code P(X <= frequency - 1)} for {@code X ~ Binomial(numberOfPositions, nullHyp)}. For
     * {@code level > 1}, {@link #minPRelevance} is also recomputed. The node is then marked
     * significant iff {@code weight > 0.99}.
     * <p>
     * Does nothing if the model has no positions or {@code frequency < Node.minFrequency}.
     */
    public void computeWeight(Model m) {
        if(m.numberOfPositions == 0 || frequency < Node.minFrequency) {
            return;
        }

        double nullHyp = 1.0;
        for(Refinement ref: parents.keySet()) {
            Node in = ref.input.inputNeuron.node;
            double p = (double) in.frequency / (double) m.numberOfPositions;
            if(p > 1.0) p = 1.0;
            nullHyp *= p;
        }

        BinomialDistribution binDist = new BinomialDistribution(null, m.numberOfPositions, nullHyp);

        weight = binDist.cumulativeProbability(frequency - 1);

        n = m.numberOfPositions;

        if(level > 1) {
            minPRelevance = 1.0;
            // Loop variable deliberately not named 'n' to avoid shadowing the field assigned above.
            for(LatticeNode parent: parents.values()) {
                double p = 1.0 - ((double) frequency / (double) parent.frequency);
                if(minPRelevance > p) minPRelevance = p;
            }
        }

        setSignificant(weight > 0.99);
    }


    /**
     * Returns {@code 1 - frequency / parent.frequency} for every parent node, keyed by that
     * parent.
     */
    public Map<Node, Double> computeMinPRel() {
        TreeMap<Node, Double> result = new TreeMap<>();
        for(LatticeNode parent: parents.values()) {
            double p = 1.0 - ((double) frequency / (double) parent.frequency);
            result.put(parent, p);
        }
        return result;
    }


    /**
     * Updates {@link #isSignificant}. If the value changes, significance is re-propagated with a
     * fresh visit id.
     */
    public void setSignificant(boolean sig) {
        if (isSignificant != sig) {
            isSignificant = sig;
            propagateSignificance(Node.visitedCounter++);
        }
    }


    /**
     * Adds all nodes reachable through {@code significantAncestors} links (transitively) to
     * {@code results}.
     * <p>
     * A node already in {@code results} has already been expanded, so recursion stops there.
     * Nodes whose {@code significantAncestors} was never computed contribute nothing.
     */
    private void collectCoveredSignificantAncestors(Set<AndNode> results) {
        if(significantAncestors == null) {
            return;
        }
        for(AndNode sa: significantAncestors) {
            if(results.add(sa)) {
                sa.collectCoveredSignificantAncestors(results);
            }
        }
    }


    /**
     * Builds the set of significant ancestors from the {@code directSignificantAncestor} of every
     * significant {@code andChildren} node. Nodes that are already transitively covered by one of
     * those ancestors are removed from the set.
     */
    private TreeSet<AndNode> computeSignificantAncestors() {
        TreeSet<AndNode> sa = new TreeSet<>();
        Set<AndNode> coveredSignificantAncestors = new TreeSet<>();
        for(AndNode cp: andChildren.values()) {
            if(cp.isSignificant && cp.directSignificantAncestor != null) {
                sa.add(cp.directSignificantAncestor);
                cp.directSignificantAncestor.collectCoveredSignificantAncestors(coveredSignificantAncestors);
            }
        }
        sa.removeAll(coveredSignificantAncestors);
        return sa;
    }


    /**
     * Recomputes this node's significance bookkeeping, then recurses into significant
     * {@code AndNode} parents.
     * <p>
     * A significant node works as follows:
     * <ul>
     *   <li>With exactly one significant ancestor, it delegates to that ancestor and is no longer
     *       flagged for publishing.</li>
     *   <li>Otherwise it becomes its own {@link #directSignificantAncestor} and is flagged
     *       {@link #shouldBePublished}.</li>
     * </ul>
     * A non-significant node clears its direct ancestor and its publish flag.
     *
     * @param v visit id; a node already visited with this id is skipped
     */
    public void propagateSignificance(long v) {
        if(visitedPropagateSignificance == v) return;
        visitedPropagateSignificance = v;

        if(isSignificant) {
            significantAncestors = computeSignificantAncestors();

            if(significantAncestors.size() == 1) {
                if(directSignificantAncestor == this) {
//                    unpublish();
                    shouldBePublished = false;
                }
                directSignificantAncestor = significantAncestors.first();
            } else {
                if(directSignificantAncestor != this) {
                    shouldBePublished = true;
//                    publish();
                }
                directSignificantAncestor = this;
            }
        } else {
            if(directSignificantAncestor == this) {
                shouldBePublished = false;
//                unpublish();
            }
            directSignificantAncestor = null;
        }
        for(Node pn: parents.values()) {
            if(pn instanceof AndNode) {
                AndNode apn = (AndNode) pn;
                if (apn.isSignificant) {
                    apn.propagateSignificance(v);
                }
            }
        }
    }



    /**
     * Prepares publication of this pattern as a neuron.
     * <p>
     * It computes the significant lower bound and the non-significant upper bound. Neuron
     * construction itself is still a TODO, so this method currently has no lasting effect.
     * Skipped for predefined nodes and nodes that already have a published neuron.
     */
    public void publish(Model m) {
        if(!isPredefined && publishedPatternNeuron == null) {
            TreeSet<AndNode> significantLowerBound = new TreeSet<>();
            collectSignificantLower(significantLowerBound, new TreeSet<>(), Collections.singleton(this));
            TreeSet<Node> nonSignificantUpperBound = computeNonSignificantUpperBound(significantLowerBound);

            // TODO
//            publishedPatternNeuron = computeNeuron(m, significantLowerBound, nonSignificantUpperBound);
        }
    }


    /** Unpublishes {@link #publishedPatternNeuron}, if one is set. */
    public void unpublish() {
        if(publishedPatternNeuron != null) {
            publishedPatternNeuron.unpublish();
        }
    }


    /**
     * Derives a neuron from a linear program over {@code [bias, w_1 .. w_k]}.
     * <p>
     * Each {@code w_i} corresponds to one distinct refinement of the significant lower-bound
     * patterns. The program maximizes the bias subject to:
     * <ul>
     *   <li>every significant lower-bound pattern:
     *       {@code bias + sum(sign * w) >= 0.5};</li>
     *   <li>every non-significant upper-bound node:
     *       {@code bias + sum(sign * w) <= -0.5}.</li>
     * </ul>
     * Weights may be negative.
     * <p>
     * Note: currently only referenced from commented-out code in {@link #publish(Model)}.
     */
    private Neuron computeNeuron(Model m, TreeSet<AndNode> significantLowerBound, TreeSet<Node> nonSignificantUpperBound) {
        // Assign one LP column to every distinct refinement of the lower-bound patterns.
        Map<Refinement, Integer> indexes = new TreeMap<>();
        ArrayList<Refinement> revIndexes = new ArrayList<>();
        for(AndNode n: significantLowerBound) {
            for(Refinement ref: n.parents.keySet()) {
                if(!indexes.containsKey(ref)) {
                    indexes.put(ref, indexes.size());
                    revIndexes.add(ref);
                }
            }
        }

        // Column 0 is the bias; column i + 1 is the weight of refinement revIndexes.get(i).
        int numVars = indexes.size() + 1;
        double[] objF = new double[numVars];
        objF[0] = 1;
        LinearObjectiveFunction f = new LinearObjectiveFunction(objF, 0);

        LinearConstraint[] constraintArr = new LinearConstraint[significantLowerBound.size() + nonSignificantUpperBound.size()];
        int i = 0;
        for(AndNode n: significantLowerBound) {
            double[] c = new double[numVars];
            c[0] = 1;
            for(Refinement ref: n.parents.keySet()) {
                c[indexes.get(ref) + 1] = ref.input.getSign();
            }
            constraintArr[i++] = new LinearConstraint(c, Relationship.GEQ, 0.5);
        }
        for(Node n: nonSignificantUpperBound) {
            double[] c = new double[numVars];
            c[0] = 1;

            if(n instanceof InputNode) {
                // 'indexes' is keyed by Refinement, so a Node cannot be used as a lookup key
                // (TreeMap.get would throw ClassCastException). Use every column whose
                // refinement introduces this input instead.
                for(Map.Entry<Refinement, Integer> me: indexes.entrySet()) {
                    if(me.getKey().input == n) {
                        c[me.getValue() + 1] = ((InputNode) n).getSign();
                    }
                }
            } else if(n instanceof AndNode) {
                for(Refinement ref: ((AndNode)n).parents.keySet()) {
                    Integer idx = indexes.get(ref);
                    // Refinements not occurring in the lower bound have no column;
                    // unboxing a null index would throw an NPE.
                    if(idx != null) {
                        c[idx + 1] = ref.input.getSign();
                    }
                }
            }
            constraintArr[i++] = new LinearConstraint(c, Relationship.LEQ, -0.5);
        }

        LinearConstraintSet constraints = new LinearConstraintSet(constraintArr);

        SimplexSolver solver = new SimplexSolver();
        PointValuePair solution = solver.optimize(f, constraints, GoalType.MAXIMIZE, new NonNegativeConstraint(false));

        // getKey() returns a copy of the point, so fetch it once.
        double[] point = solution.getKey();
        double bias = point[0];
        TreeSet<Synapse> synapses = new TreeSet<>();
        for(int j = 1; j < point.length; j++) {
            Refinement ref = revIndexes.get(j - 1);
            Synapse s = new Synapse(ref.input.inputNeuron, ref.rid, ref.input.rid == null);
            s.w = (float) point[j];
            synapses.add(s);
        }
        return SimpleNeuron.create(m, new SimpleNeuron(), bias, synapses, false, false);
    }


    /**
     * Collects the significant lower bound of a set of nodes.
     * <p>
     * The method first recurses into the parents of all nodes with {@code level > 2}. Then,
     * working back up, it adds each significant node of {@code currentLevelNodes} to
     * {@code significantLowerBound}, unless the node is already covered. A node counts as covered
     * if it is an {@code andChildren} entry of a node covered at the parent level.
     *
     * @param significantLowerBound output set of selected significant nodes
     * @param coveredByCurrentLevel output set of nodes covered at this level
     * @param currentLevelNodes     nodes to examine at this level
     */
    public static void collectSignificantLower(Set<AndNode> significantLowerBound, Set<AndNode> coveredByCurrentLevel, Set<AndNode> currentLevelNodes) {
        Set<AndNode> nextLevelNodes = new TreeSet<>();
        for(AndNode n: currentLevelNodes) {
            if(n.level > 2) {
                // Parents of a node above level 2 are expected to be AndNodes.
                for (Node pn : n.parents.values()) {
                    nextLevelNodes.add((AndNode) pn);
                }
            }
        }

        TreeSet<AndNode> coveredByNextLevel = new TreeSet<>();
        if(!nextLevelNodes.isEmpty()) {
            collectSignificantLower(significantLowerBound, coveredByNextLevel, nextLevelNodes);
        }

        for(AndNode cn: coveredByNextLevel) {
            coveredByCurrentLevel.addAll(cn.andChildren.values());
        }

        for(AndNode n: currentLevelNodes) {
            if(n.isSignificant && !coveredByCurrentLevel.contains(n)) {
                significantLowerBound.add(n);
                coveredByCurrentLevel.add(n);
            }
        }
    }


    /** Returns the union of the parent nodes of all nodes in {@code significantLowerBound}. */
    private static TreeSet<Node> computeNonSignificantUpperBound(TreeSet<AndNode> significantLowerBound) {
        TreeSet<Node> nonSignificantUpperBound = new TreeSet<>();
        for(AndNode n: significantLowerBound) {
            nonSignificantUpperBound.addAll(n.parents.values());
        }

        return nonSignificantUpperBound;
    }


    /**
     * Formats the ids of {@link #significantAncestors} as {@code "SA:{id1, id2}"}.
     * Returns {@code "SA:{}"} when the set is empty or null.
     */
    public String significantAncestorsToString() {
        StringBuilder sb = new StringBuilder();
        sb.append("SA:{");
        boolean first = true;
        if(significantAncestors != null) {
            for (AndNode sa : significantAncestors) {
                if (!first) {
                    sb.append(", ");
                }
                sb.append(sa.id);
                first = false;
            }
        }
        sb.append("}");
        return sb.toString();
    }


    /** Removes this node from the model unless it is already removed or is frequent/predefined. */
    @Override
    public void cleanup(Model m) {
        if(!isRemoved && !isFrequentOrPredefined()) {
            remove(m);
        }
    }


    /**
     * Expands {@code act} to the next lattice level.
     * <p>
     * For every parent, each refinement in the parent's {@code andChildrenWithinDocument} that
     * has the same inference mode becomes a candidate. Its rid is shifted by the parent
     * refinement's offset. The activation is also passed to {@code OrNode.processCandidate}.
     * Activations removed in the meantime are ignored.
     */
    @Override
    public void expandToNextLevel(Iteration t, Activation act, Range addedRange, Option conflict, boolean train) {

        // Check if the activation has been deleted in the meantime.
        if(act.isRemoved) {
            return;
        }

        for(Map.Entry<Refinement, LatticeNode> mea: parents.entrySet()) {
            LatticeNode pn = mea.getValue();

            for(Map.Entry<Refinement, AndNode> meb: pn.andChildrenWithinDocument.entrySet()) {
                if(meb.getKey().inferenceMode == inferenceMode) {
                    processCandidate(t, this, meb.getValue(), new Refinement(meb.getKey().rid - mea.getKey().getOffset(), inferenceMode, meb.getKey().input), act, addedRange, conflict, train);
                }
            }
        }

        OrNode.processCandidate(t, this, act, addedRange, conflict, train);
    }


    /**
     * Handles one next-level candidate formed from {@code firstNode} and {@code refinement}.
     * <p>
     * A node refined with itself at rid 0 is ignored. Otherwise:
     * <ul>
     *   <li>In training mode, a new pattern is created, but only for frequent/predefined first
     *       nodes and non-inference refinements.</li>
     *   <li>Outside training, activations are added to the existing next-level pattern.</li>
     * </ul>
     */
    public static void processCandidate(Iteration t, LatticeNode firstNode, Node secondNode, Refinement refinement, Activation act, Range addedRange, Option conflict, boolean train) {
        if(firstNode != secondNode || refinement.rid != 0) {
            if(train) {
                if (firstNode.isFrequentOrPredefined() && !refinement.inferenceMode) {
                    createNextLevelPattern(t, firstNode, refinement);
                }
            } else {
                addActivationsToNextLevelPattern(t, firstNode, secondNode, refinement, act, addedRange, conflict);
            }
        }
    }


    /**
     * Creates the {@code AndNode} obtained by applying {@code refinement} to {@code firstNode}.
     * <p>
     * Nothing happens if that child already exists, if any involved input (or its neuron) is
     * blocked or has no neuron, or if no parent set can be computed.
     */
    public static void createNextLevelPattern(Iteration t, LatticeNode firstNode, Refinement refinement) {
        if(firstNode.andChildren.containsKey(refinement)) {
            return;
        }

        Set<Refinement> inputs = new TreeSet<>();

        firstNode.collectNodeAndRefinements(refinement, inputs);

        for (Refinement ref : inputs) {
            if (ref.input.isBlocked || ref.input.inputNeuron == null || ref.input.inputNeuron.isBlocked) {
                return;
            }
        }
        SortedMap<Refinement, LatticeNode> nlParents = computeParents(inputs);

        if (nlParents != null) {
            prepareNextLevelPattern(t, firstNode.level + 1, nlParents);
        }
    }


    /**
     * Adds activations to the existing child of {@code firstNode} under {@code refinement}, if
     * any.
     * <p>
     * If {@code secondNode} is not negative, {@code act} is combined with every overlapping
     * activation of {@code secondNode} whose options can be merged. The merged option must also
     * contain {@code conflict} when one is given. Otherwise {@code act} is forwarded alone, and
     * only when there is no conflict.
     */
    public static void addActivationsToNextLevelPattern(Iteration t, LatticeNode firstNode, Node secondNode, Refinement refinement, Activation act, Range addedRange, Option conflict) {
        Key ak = act.key;
        AndNode nlp = firstNode.andChildren.get(refinement);
        if(nlp == null) {
            return;
        }

        if(!secondNode.isNegative()) {
            for (Activation secondAct : Activation.select(secondNode, act.key.rid + refinement.rid, addedRange, Range.Relation.OVERLAPS, null, null, null, null, null, false)) {
                Option o = Option.add(t.doc, true, ak.o, secondAct.key.o);
                if (o != null && (conflict == null || o.contains(conflict))) {
                    Set<Activation>[] iActs = prepareInputActs(act, secondAct);
                    nlp.addActivationWithNegative(
                            t,
                            nlp.inferenceMode ? Range.add(t.doc, ak.pos, secondAct.key.pos) : ak.pos.intersection(secondAct.key.pos),
                            act.key.rid + refinement.getOffset(),
                            o,
                            Math.max(ak.fired, secondAct.key.fired),
                            nlp.inferenceMode ? Range.add(t.doc, addedRange, secondAct.key.pos) : addedRange.intersection(secondAct.key.pos),
                            iActs[0],
                            iActs[1]
                    );
                }
            }
        } else {
            if (conflict == null) {
                nlp.addActivationWithNegative(t, ak.pos, act.key.rid + refinement.getOffset(), ak.o, ak.fired, addedRange, act.inputs, Collections.singleton(act));
            }
        }
    }


    /**
     * Adds an activation to this node, handling negated inputs.
     * <p>
     * When the node is public and contains negative inputs:
     * <ul>
     *   <li>{@code addedRange} is split into the negation segments reported by
     *       {@code NegativeInputNode.getNegationSegments}.</li>
     *   <li>Each segment reuses its initial option or gets a new primitive option. For a new
     *       option, conflicts are registered for the keys that no direct input allows.</li>
     *   <li>One activation is added per segment.</li>
     * </ul>
     * Otherwise a single activation over {@code pos} is added.
     */
    private void addActivationWithNegative(Iteration t, Range pos, int rid, Option o, Integer fired, Range addedRange, Set<Activation> inputActs, Set<Activation> directInputActs) {
        if(isPublic() && containsNegative()) {
            TreeSet<NegativeInputNode> negNodes = new TreeSet<>();
            for (Refinement ref : parents.keySet()) {
                if (ref.input instanceof NegativeInputNode) {
                    negNodes.add((NegativeInputNode) ref.input);
                }
            }
            TreeMap<Range, Set<Conflicts.Key>> conflicts = NegativeInputNode.getNegationSegments(t.doc, negNodes, addedRange);
            for (Map.Entry<Range, Set<Conflicts.Key>> me : conflicts.entrySet()) {
                Option no = retrieveInitialOption(me.getKey(), rid, o);
                if (no == null) {
                    no = Option.addPrimitive(t.doc, me.getKey().getBegin());

                    for (Conflicts.Key ck : me.getValue()) {
                        boolean isAllowed = false;
                        for (Activation pAct : directInputActs) {
                            // Fresh visit id per check so earlier traversals do not short-circuit this one.
                            if (pAct.key.n.isAllowedOption(ck.o, pAct, visitedCounter++)) isAllowed = true;
                        }
                        if (!isAllowed) {
                            Conflicts.add(t, ck.n, no, ck.o);
                        }
                    }
                }

                addActivation(t, new Key(this, me.getKey(), rid, o, fired), me.getKey(), no, inputActs, directInputActs);
            }
        } else {
            addActivation(t, new Key(this, pos, rid, o, fired), addedRange, null, inputActs, directInputActs);
        }
    }


    /**
     * Builds the input sets for an activation combined from two activations.
     * <p>
     * Index 0 holds the union of both activations' {@code inputs} (a null side is skipped).
     * Index 1 holds the two activations themselves.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // generic array creation is not expressible in Java
    private static Set<Activation>[] prepareInputActs(Activation firstAct, Activation secondAct) {
        TreeSet<Activation> inputActs = new TreeSet<>();
        if(firstAct.inputs != null) {
            inputActs.addAll(firstAct.inputs);
        }
        if(secondAct.inputs != null) {
            inputActs.addAll(secondAct.inputs);
        }

        TreeSet<Activation> directInputActs = new TreeSet<>();
        directInputActs.add(firstAct);
        directInputActs.add(secondAct);
        return new Set[] {inputActs, directInputActs};
    }



    /**
     * Computes the parent nodes of the pattern formed by {@code refinements}.
     * <p>
     * For each refinement, its input is asked to find the parent formed by the remaining
     * refinements. Returns {@code null} if any input reports failure.
     */
    public static SortedMap<Refinement, LatticeNode> computeParents(Set<Refinement> refinements) {
        HashSet<LatticeNode.RSKey> visited = new HashSet<>();
        SortedMap<Refinement, LatticeNode> parents = new TreeMap<>();

        for(Refinement ref: refinements) {
            SortedSet<Refinement> childInputs = new TreeSet<>(refinements);
            childInputs.remove(ref);
            if(!ref.input.computeAndParents(ref.getRelativePosition(), childInputs, parents, visited)) {
                return null;
            }
        }

        return parents;
    }


    /**
     * Creates the next-level {@code AndNode}, computes its initial activations, and records it
     * in {@code t.addedNodes}.
     * <p>
     * Skipped if any refinement's input neuron is blocked. All refinements are asserted to share
     * one inference mode.
     */
    private static void prepareNextLevelPattern(Iteration t, int level, SortedMap<Refinement, LatticeNode> parents) {
        assert level == parents.size();

        Boolean im = null;
        for(Refinement ref: parents.keySet()) {
            if(ref.input.inputNeuron != null && ref.input.inputNeuron.isBlocked) return;

            if(im == null) {
                im = ref.inferenceMode;
            } else {
                assert im == ref.inferenceMode;
            }
        }

        AndNode nlp = new AndNode(t.m, level, parents, im);
        nlp.computePatternActivations(t, parents.values());
        t.addedNodes.add(nlp);
    }


    /**
     * Returns the parent refinement with the smallest relative position, or {@code null} if
     * none. {@code exclude} is skipped by identity; on ties the first one in key order wins.
     */
    private Refinement getMinRefinement(Refinement exclude) {
        Refinement minRef = null;
        for(Refinement ref: parents.keySet()) {
            if(ref != exclude && (minRef == null || minRef.getRelativePosition() > ref.getRelativePosition())) {
                minRef = ref;
            }
        }
        return minRef;
    }


    /**
     * Adds to {@code inputs} this node's refinements with their rids rebased for the pattern
     * extended by {@code newRef}, followed by {@code newRef} itself.
     */
    @Override
    protected void collectNodeAndRefinements(Refinement newRef, Set<Refinement> inputs) {
        Refinement firstRef = null;
        Refinement secondRef = null;
        if(newRef.rid >= 0) {
            Refinement firstMinRef = getMinRefinement(null);
            firstRef = newRef.rid < firstMinRef.getRelativePosition() ? newRef : firstMinRef;

            Refinement secondMinRef = getMinRefinement(firstRef);
            secondRef = newRef != firstRef && newRef.rid < secondMinRef.getRelativePosition() ? newRef : secondMinRef;
        }

        // Since new rid is relative to the parent pattern, we need to figure out which refinement
        // has the lowest rid and compute the offset relative to the second lowest refinement rid.
        for(Refinement ref: parents.keySet()) {
            int nRid = ref == firstRef ? -secondRef.rid : ref.getRelativePosition() - Math.min(0, newRef.rid);
            inputs.add(new Refinement(nRid, ref.inferenceMode, ref.input));
        }
        inputs.add(newRef);
    }


    /**
     * Returns {@code n.bias} plus the absolute weights of the synapses of {@code n} that match
     * each refinement's input neuron.
     * NOTE(review): assumes {@code n} has a synapse for every such neuron; a missing one would
     * throw an NPE.
     */
    @Override
    public double computeSynapseWeightSum(Neuron n) {
        double sum = n.bias;
        for(Refinement ref: parents.keySet()) {
            Synapse s = n.inputSynapses.get(ref.input.inputNeuron);
            sum += Math.abs(s.w);
        }
        return sum;
    }


    /**
     * Seeds this node's activations from its first two parent nodes.
     * <p>
     * The first parent must not be a {@link NegativeInputNode}; if it is, the two are swapped.
     * Every activation of the first parent is joined with each overlapping activation of the
     * second parent whose options can be merged. The result covers the intersection of the two
     * ranges.
     */
    private void computePatternActivations(Iteration t, Collection<LatticeNode> parentNodes) {
        Iterator<LatticeNode> it = parentNodes.iterator();
        Node firstParentNode = it.next();
        Node secondParentNode = it.next();

        if(firstParentNode instanceof NegativeInputNode) {
            Node tmp = firstParentNode;
            firstParentNode = secondParentNode;
            secondParentNode = tmp;
        }
        assert !(firstParentNode instanceof NegativeInputNode);

        for(Activation firstAct: firstParentNode.activations.values()) {
            for(Activation secondAct: Activation.select(secondParentNode, 0, firstAct.key.pos, Range.Relation.OVERLAPS, null, null, null, null, null, true)) {
                Option o = Option.add(t.doc, true, firstAct.key.o, secondAct.key.o);

                if(o != null) {
                    Set<Activation>[] iActs = prepareInputActs(firstAct, secondAct);

                    Range pos = firstAct.key.pos.intersection(secondAct.key.pos);
                    addActivation(t, new Key(this, pos, Math.min(firstAct.key.rid, secondAct.key.rid), o, Math.max(firstAct.key.fired, secondAct.key.fired)), pos, null, iActs[0], iActs[1]);
                }
            }
        }
    }


    /**
     * When this node has no activations yet, registers it in each parent's
     * {@code andChildrenWithinDocument} under the corresponding refinement.
     * (Presumably called before the first activation is stored — TODO confirm.)
     */
    @Override
    public void initActivation(Iteration t, Activation act) {
        if(activations.isEmpty()) {
            for (Map.Entry<Refinement, LatticeNode> me : parents.entrySet()) {
                me.getValue().andChildrenWithinDocument.put(me.getKey(), this);
            }
        }
    }


    /**
     * When this node has no activations left, unregisters it from each parent's
     * {@code andChildrenWithinDocument}.
     */
    @Override
    public void deleteActivation(Iteration t, Activation act) {
        if(activations.isEmpty()) {
            for(Map.Entry<Refinement, LatticeNode> me: parents.entrySet()) {
                me.getValue().andChildrenWithinDocument.remove(me.getKey());
            }
        }
    }


    /** Clears all activations and unregisters this node from each parent's {@code andChildrenWithinDocument}. */
    @Override
    public void clearActivations() {
        super.clearActivations();

        for(Map.Entry<Refinement, LatticeNode> me: parents.entrySet()) {
            me.getValue().andChildrenWithinDocument.remove(me.getKey());
        }
    }



    /** Removes this node from the model and unregisters it from each parent's {@code andChildren}. */
    @Override
    public void remove(Model m) {
        super.remove(m);

        for(Map.Entry<Refinement, LatticeNode> me: parents.entrySet()) {
            me.getValue().andChildren.remove(me.getKey());
        }
    }


    /** Returns a textual form such as {@code AND[rid:input,rid:input]}. */
    public String logicToString() {
        StringBuilder sb = new StringBuilder();
        sb.append("AND[");
        boolean first = true;
        for(Refinement ref: parents.keySet()) {
            if(!first) {
                sb.append(",");
            }
            first = false;
            sb.append(ref.rid);
            sb.append(":");
            sb.append(ref.input.logicToString());
        }
        sb.append("]");
        return sb.toString();
    }


    /**
     * Single-input extension of a lattice node.
     * <p>
     * It holds the input node, a relative id ({@code rid}) and the inference-mode flag.
     * Instances are ordered by rid, then inference mode, then input. Equality is not overridden,
     * so some callers compare instances by identity.
     */
    public static class Refinement implements Comparable<Refinement> {
        public final int rid;
        public final boolean inferenceMode;
        public final InputNode input;

        public Refinement(int rid, boolean inferenceMode, InputNode input) {
            this.rid = rid;
            this.inferenceMode = inferenceMode;
            this.input = input;
        }


        /** Returns {@code rid} if it is negative, otherwise {@code 0}. */
        public int getOffset() {
            return Math.min(0, rid);
        }


        /** Returns {@code rid} if it is positive, otherwise {@code 0}. */
        public int getRelativePosition() {
            return Math.max(0, rid);
        }


        @Override
        public String toString() {
            return "[" + (inferenceMode ? "+" : "-") + rid + ":" + input.logicToString() + "]";
        }


        @Override
        public int compareTo(Refinement ref) {
            int r = Integer.compare(rid, ref.rid);
            if(r != 0) return r;
            r = Boolean.compare(inferenceMode, ref.inferenceMode);
            if(r != 0) return r;
            return input.compareTo(ref.input);
        }
    }
}
