/*
 * Decompiled with CFR 0.152.
 */
package org.aika.neuron;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.aika.Activation;
import org.aika.Model;
import org.aika.ReadWriteLock;
import org.aika.Utils;
import org.aika.Writable;
import org.aika.corpus.Document;
import org.aika.corpus.InterprNode;
import org.aika.corpus.Range;
import org.aika.corpus.SearchNode;
import org.aika.lattice.InputNode;
import org.aika.lattice.Node;
import org.aika.lattice.OrNode;
import org.aika.neuron.InputNeuron;
import org.aika.neuron.Synapse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A neuron of the network.
 * <p>
 * Besides its {@link #bias}, a neuron keeps the aggregated sums {@link #negDirSum}, {@link #negRecSum} and
 * {@link #posRecSum}, which are combined with the bias whenever the bounds or the weight of one of its
 * activations are computed; its input and output {@link Synapse}s; and the lattice {@link #node} through
 * which its activations are accessed. Structural changes made in {@link #create} and {@link #remove} are
 * performed while holding the write {@link #lock} of the neuron being changed.
 */
public class Neuron implements Comparable<Neuron>, Writable {
    private static final Logger log = LoggerFactory.getLogger(Neuron.class);

    /** Factor applied to the error signal when synapse weights are adjusted in {@link #train(Document, Activation)}. */
    public static final double LEARN_RATE = 0.01;
    /** Absolute tolerance used by {@link NormWeight#equals(NormWeight)}. */
    public static final double WEIGHT_TOLERANCE = 0.001;
    /** Error signals below this magnitude, and input values not exceeding it, are ignored during training. */
    public static final double TOLERANCE = 1.0E-6;
    /** Recursion depth limit of {@link #checkSelfReferencing}. */
    public static final int MAX_SELF_REFERENCING_DEPTH = 5;

    /** The model this neuron belongs to; assigned in {@link #create}. */
    public Model m;

    /** Global id source; every new instance draws the next value. */
    public static AtomicInteger currentNeuronId = new AtomicInteger(0);
    public int id = currentNeuronId.incrementAndGet();
    public String label;

    public volatile double bias;
    public volatile double negDirSum;
    public volatile double negRecSum;
    public volatile double posRecSum;
    /** Used as the normalization term of the weight computed in {@link #computeWeight}. */
    public volatile double maxRecurrentSum = 0.0;

    public TreeSet<Synapse> outputSynapses = new TreeSet<>(Synapse.OUTPUT_SYNAPSE_COMP);
    public TreeSet<Synapse> inputSynapses = new TreeSet<>(Synapse.INPUT_SYNAPSE_COMP);
    public TreeSet<Synapse> inputSynapsesByWeight = new TreeSet<>(Synapse.INPUT_SYNAPSE_BY_WEIGHTS_COMP);
    /** Input nodes fed by this neuron, keyed by synapse key. */
    public TreeMap<Synapse.Key, InputNode> outputNodes = new TreeMap<>();
    /** Lattice node holding this neuron's activations; an {@link OrNode} when built via {@link #create}. */
    public Node node;
    /** Set to {@code true} at the end of a successful {@link #create}. */
    public boolean initialized = false;
    public boolean isBlocked;
    public boolean noTraining;

    /** Sum of the positive final activation values accumulated by {@link #count(Document)}. */
    public volatile double activationSum;
    /** Number of activations contributing to {@link #activationSum}. */
    public volatile int numberOfActivations;

    public ReadWriteLock lock = new ReadWriteLock();

    public Neuron() {
    }

    public Neuron(String label) {
        this.label = label;
    }

    public Neuron(String label, boolean isBlocked, boolean noTraining) {
        this.label = label;
        this.isBlocked = isBlocked;
        this.noTraining = noTraining;
    }

    /**
     * Returns the mean activation value recorded in {@link #activationSum} / {@link #numberOfActivations},
     * or {@code 1.0} when nothing has been recorded yet.
     */
    public double avgActivation() {
        return this.numberOfActivations > 0 ? this.activationSum / this.numberOfActivations : 1.0;
    }

    /**
     * Initializes {@code n} as a neuron of {@code doc}'s model and connects it to {@code inputs}.
     * <p>
     * Sets the neuron's model, bias and weight sums, creates its {@link OrNode}, links every input synapse
     * to {@code n}, lets {@link Node#adjust} build the network structure and finally publishes the neuron.
     * NOTE(review): when {@code Node.adjust} fails, the model's neuron statistic has already been
     * incremented and the input synapses stay linked.
     *
     * @param doc       the document whose model receives the neuron
     * @param n         the neuron instance to initialize
     * @param bias      the bias
     * @param negDirSum sum of negative direct weights
     * @param negRecSum sum of negative recurrent weights
     * @param posRecSum sum of positive recurrent weights
     * @param inputs    the input synapses; their {@code output} is set to {@code n}
     * @return {@code n}, or {@code null} if {@link Node#adjust} rejected it (the neuron is then not published)
     */
    public static Neuron create(Document doc, Neuron n, double bias, double negDirSum, double negRecSum, double posRecSum, Set<Synapse> inputs) {
        n.m = doc.m;
        ++n.m.stat.neurons;
        n.bias = bias;
        n.negDirSum = negDirSum;
        n.negRecSum = negRecSum;
        n.posRecSum = posRecSum;

        n.lock.acquireWriteLock(doc.threadId);
        try {
            n.node = new OrNode(doc);
            n.node.neuron = n;
        } finally {
            // Release even if node construction fails, so the neuron is not left locked.
            n.lock.releaseWriteLock();
        }

        double sum = 0.0;
        for (Synapse s : inputs) {
            // A synapse that outputs a range boundary must match that boundary with EQUALS or FIRST.
            assert (!s.key.startRangeOutput || s.key.startRangeMatch == Range.Operator.EQUALS || s.key.startRangeMatch == Range.Operator.FIRST);
            assert (!s.key.endRangeOutput || s.key.endRangeMatch == Range.Operator.EQUALS || s.key.endRangeMatch == Range.Operator.FIRST);
            s.output = n;
            s.link(doc);
            // MAX_VALUE marks an unset value: use the summed weights of the synapses iterated before this one.
            if (s.maxLowerWeightsSum == Double.MAX_VALUE) {
                s.maxLowerWeightsSum = sum;
            }
            sum += s.w;
        }
        if (!Node.adjust(doc, n, -1)) {
            return null;
        }
        n.publish(doc);
        n.initialized = true;
        return n;
    }

    /** Registers this neuron in its model's neuron map under its {@link #id}. */
    public void publish(Document doc) {
        this.m.neurons.put(this.id, this);
    }

    /**
     * Removes this neuron from its model's neuron map.
     * <p>
     * The map is keyed by {@link #id} (see {@link #publish}), so the entry has to be removed by id;
     * passing the neuron object itself would not match an id key.
     */
    public void unpublish(Document doc) {
        this.m.neurons.remove(this.id);
    }

    /**
     * Detaches this neuron from the network: unpublishes it and removes its synapses from the synapse
     * sets of the neurons on the other side, holding each of those neurons' write lock while doing so.
     * This neuron's own synapse sets are left unchanged.
     */
    public void remove(Document doc) {
        this.unpublish(doc);
        for (Synapse s : this.inputSynapses) {
            s.input.lock.acquireWriteLock(doc.threadId);
            try {
                s.input.outputSynapses.remove(s);
            } finally {
                s.input.lock.releaseWriteLock();
            }
        }
        for (Synapse s : this.outputSynapses) {
            s.output.lock.acquireWriteLock(doc.threadId);
            try {
                s.output.inputSynapses.remove(s);
                s.output.inputSynapsesByWeight.remove(s);
            } finally {
                s.output.lock.releaseWriteLock();
            }
        }
    }

    /** Enqueues a newly added activation of this neuron on {@code doc.ubQueue}. */
    public void propagateAddedActivation(Document doc, Activation act) {
        doc.ubQueue.add(act);
    }

    /** Removes {@code act} from every input node fed by this neuron. */
    public void propagateRemovedActivation(Document doc, Activation act) {
        for (InputNode out : this.outputNodes.values()) {
            out.removeActivation(doc, act);
        }
    }

    /**
     * Computes {@code act.upperBound} and {@code act.lowerBound} from the bounds of its input activations.
     * <p>
     * Both sums start at {@code bias + posRecSum - (negDirSum + negRecSum)}. A positive input adds its own
     * upper/lower bound times the synapse weight. A negative input adds its full synapse weight to the lower
     * bound (as if its value were 1.0), and adds its lower bound times the weight to the upper bound only if
     * it is not self-referencing and {@code act}'s interpretation contains the input's interpretation.
     * Inputs that are {@code act} itself or that are removed are skipped. Both sums are finally passed
     * through {@link #transferFunction(double)}.
     */
    public void computeBounds(Activation act) {
        double base = this.bias + this.posRecSum - (this.negDirSum + this.negRecSum);
        double ub = base;
        double lb = base;
        for (Activation.SynapseActivation sa : act.neuronInputs) {
            Synapse s = sa.s;
            Activation iAct = sa.input;
            if (iAct == act || iAct.isRemoved) {
                continue;
            }
            if (s.key.isNeg) {
                if (!checkSelfReferencing(act.key.o, iAct.key.o, null, 0) && act.key.o.contains(iAct.key.o, true)) {
                    ub += iAct.lowerBound * s.w;
                }
                lb += s.w;
            } else {
                ub += iAct.upperBound * s.w;
                lb += iAct.lowerBound * s.w;
            }
        }
        act.upperBound = transferFunction(ub);
        act.lowerBound = transferFunction(lb);
    }

    /**
     * Computes the state of {@code act} in the given search round.
     * <p>
     * For each run of consecutive inputs sharing the same synapse, only the input with the highest value is
     * used. Values are read from round {@code round - 1} for recurrent synapses and from {@code round}
     * otherwise. Direct and recurrent contributions are summed separately. The result holds:
     * <ul>
     *   <li>the transfer function of their sum;</li>
     *   <li>a fired index: one more than the fired index of the first positive direct input after which
     *       the running sum is non-negative, or {@code -1} if no such input exists;</li>
     *   <li>a {@link NormWeight}.</li>
     * </ul>
     * If {@code act}'s interpretation is not covered by {@code en}, the value is {@code 0.0}, the fired index
     * is {@code -1} and the weight part of the {@link NormWeight} is {@code 0.0}.
     *
     * @param round the search round
     * @param act   the activation to evaluate
     * @param en    the search node used for coverage checks
     * @param doc   the document; used to store debug output for {@code doc.debugActId}
     * @return the computed activation state
     */
    public Activation.State computeWeight(int round, Activation act, SearchNode en, Document doc) {
        double directSum = this.bias - (this.negDirSum + this.negRecSum);
        double recurrentSum = 0.0;
        if (round == 0) {
            recurrentSum += this.posRecSum;
        }

        // Keep, per run of consecutive inputs sharing one synapse, the input with the highest value.
        List<Activation.SynapseActivation> tmp = new ArrayList<>();
        Synapse lastSynapse = null;
        Activation.SynapseActivation maxSA = null;
        for (Activation.SynapseActivation sa : act.neuronInputs) {
            if (lastSynapse != null && lastSynapse != sa.s) {
                tmp.add(maxSA);
                maxSA = null;
            }
            int r = sa.s.key.isRecurrent ? round - 1 : round;
            if (maxSA == null || maxSA.input.rounds.get(r).value < sa.input.rounds.get(r).value) {
                maxSA = sa;
            }
            lastSynapse = sa.s;
        }
        if (maxSA != null) {
            tmp.add(maxSA);
        }

        int fired = -1;
        for (Activation.SynapseActivation sa : tmp) {
            Synapse s = sa.s;
            Activation iAct = sa.input;
            if (iAct == act || iAct.isRemoved) {
                continue;
            }
            if (s.key.isRecurrent) {
                if (s.key.isNeg) {
                    if (checkSelfReferencing(act.key.o, iAct.key.o, en, 0)) {
                        continue;
                    }
                    if (round == 0) {
                        // No previous round yet: count the full weight if the input's interpretation is covered.
                        if (en.isCovered(iAct.key.o.markedCovered)) {
                            recurrentSum += s.w;
                        }
                    } else {
                        recurrentSum += iAct.rounds.get(round - 1).value * s.w;
                    }
                } else {
                    recurrentSum += iAct.rounds.get(round - 1).value * s.w;
                }
            } else if (s.key.isNeg) {
                directSum += iAct.rounds.get(round).value * s.w;
            } else {
                Activation.State is = iAct.rounds.get(round);
                directSum += is.value * s.w;
                if (fired < 0 && directSum + recurrentSum >= 0.0) {
                    fired = is.fired + 1;
                }
            }
        }

        boolean covered = en.isCovered(act.key.o.markedCovered);
        double sum = directSum + recurrentSum;
        // directSum with the negative recurrent sum (subtracted up front) added back.
        boolean directNegative = directSum + this.negRecSum < 0.0;
        double weight = covered ? (directNegative ? Math.max(0.0, sum) : recurrentSum - this.negRecSum) : 0.0;
        double norm = directNegative ? Math.max(0.0, directSum + this.negRecSum + this.maxRecurrentSum) : this.maxRecurrentSum;
        NormWeight newWeight = NormWeight.create(weight, norm);

        if (doc.debugActId == act.id && doc.debugActWeight <= newWeight.w) {
            this.storeDebugOutput(doc, tmp, newWeight, sum, round);
        }
        return new Activation.State(covered ? transferFunction(sum) : 0.0, covered ? fired : -1, newWeight);
    }

    /** Writes a human-readable breakdown of a weight computation into {@code doc.debugOutput}. */
    private void storeDebugOutput(Document doc, List<Activation.SynapseActivation> inputs, NormWeight nw, double sum, int round) {
        StringBuilder sb = new StringBuilder();
        sb.append("Activation ID: " + doc.debugActId + "\n");
        sb.append("Neuron: " + this.label + "\n");
        sb.append("Sum: " + sum + "\n");
        sb.append("Bias: " + this.bias + "\n");
        sb.append("Round: " + round + "\n");
        sb.append("Positive Recurrent Sum: " + this.posRecSum + "\n");
        sb.append("Negative Recurrent Sum: " + this.negRecSum + "\n");
        sb.append("Negative Direct Sum: " + this.negDirSum + "\n");
        sb.append("Inputs:\n");
        for (Activation.SynapseActivation sa : inputs) {
            String actValue = "";
            if (sa.s.key.isRecurrent) {
                // Recurrent inputs have no value in round 0.
                if (round > 0) {
                    actValue = "" + sa.input.rounds.get(round - 1);
                }
            } else {
                actValue = "" + sa.input.rounds.get(round);
            }
            sb.append("    " + sa.input.key.n.neuron.label + "  SynWeight: " + sa.s.w + "  ActValue: " + actValue);
            sb.append("\n");
        }
        sb.append("Weight: " + nw.w + "\n");
        sb.append("Norm: " + nw.n + "\n");
        sb.append("\n");
        doc.debugOutput = sb.toString();
    }

    /**
     * Computes {@code act.errorSignal}: its initial error signal plus, for every output activation, the
     * synapse weight times that activation's error signal times {@code (1 - act.finalState.value)}.
     * Afterwards every input activation of {@code act} is enqueued on {@code doc.bQueue}.
     */
    public void computeErrorSignal(Document doc, Activation act) {
        act.errorSignal = act.initialErrorSignal;
        for (Activation.SynapseActivation sa : act.neuronOutputs) {
            Synapse s = sa.s;
            Activation oAct = sa.output;
            act.errorSignal += s.w * oAct.errorSignal * (1.0 - act.finalState.value);
        }
        for (Activation.SynapseActivation sa : act.neuronInputs) {
            doc.bQueue.add(sa.input);
        }
    }

    /**
     * Adjusts this neuron's input synapse weights according to {@code act.errorSignal}, then calls
     * {@link Node#adjust} in the direction of the error's sign.
     * <p>
     * The candidate inputs are:
     * <ul>
     *   <li>the document's input-node activations overlapping a target range (the activation's range
     *       widened by half its length on each side and clipped to the document);</li>
     *   <li>the activation's current neuron inputs.</li>
     * </ul>
     * Does nothing if the error signal's magnitude is below {@link #TOLERANCE}.
     */
    public void train(Document doc, Activation act) {
        if (Math.abs(act.errorSignal) < TOLERANCE) {
            return;
        }
        // Visit marker so each input activation is trained at most once in this call.
        long v = Activation.visitedCounter++;
        Range targetRange = null;
        if (act.key.r != null) {
            int s = act.key.r.end - act.key.r.begin;
            targetRange = new Range(Math.max(0, act.key.r.begin - s / 2), Math.min(doc.length(), act.key.r.end + s / 2));
        }
        ArrayList<Activation> inputActs = new ArrayList<>();
        for (Activation iAct : doc.inputNodeActivations) {
            if (Range.overlaps(iAct.key.r, targetRange)) {
                inputActs.add(iAct);
            }
        }
        if (Document.TRAIN_DEBUG_OUTPUT) {
            log.info("Debug discover:");
            log.info("Old Synapses:");
            for (Synapse s : this.inputSynapsesByWeight) {
                log.info("S:" + s.input + " RID:" + s.key.relativeRid + " W:" + s.w);
            }
            log.info("");
        }
        for (Activation.SynapseActivation sa : act.neuronInputs) {
            inputActs.add(sa.input);
        }
        for (Activation iAct : inputActs) {
            Integer rid = Utils.nullSafeSub(iAct.key.rid, false, act.key.rid, false);
            this.train(doc, iAct, rid, LEARN_RATE * act.errorSignal, v);
        }
        if (Document.TRAIN_DEBUG_OUTPUT) {
            log.info("");
        }
        Node.adjust(doc, this, act.errorSignal > 0.0 ? 1 : -1);
    }

    /**
     * Decreases the weight of the synapse from the neuron behind {@code iAct}'s first input activation to
     * this neuron by {@code x} times that input's final value.
     * <p>
     * The synapse is created and linked if the input node does not have one for this neuron and relative id
     * yet. Nothing happens if:
     * <ul>
     *   <li>{@code iAct} was already visited with marker {@code v};</li>
     *   <li>the input neuron is this neuron;</li>
     *   <li>the input has no final state, or its value does not exceed {@link #TOLERANCE}.</li>
     * </ul>
     * NOTE(review): {@code iAct.key.n} is cast to {@link InputNode}, and {@code iAct.inputs} is assumed to be
     * non-empty.
     *
     * @param doc  the document
     * @param iAct the input-node activation to train on
     * @param rid  relative id of the input with respect to the trained activation (may be {@code null})
     * @param x    the scaled error signal ({@link #LEARN_RATE} times the error)
     * @param v    the visit marker of the current training pass
     */
    public void train(Document doc, Activation iAct, Integer rid, double x, long v) {
        if (iAct.visitedNeuronTrain == v) {
            return;
        }
        iAct.visitedNeuronTrain = v;
        Activation iiAct = iAct.inputs.firstEntry().getValue();
        if (iiAct.key.n.neuron != this && iiAct.finalState != null && iiAct.finalState.value > TOLERANCE) {
            InputNode in = (InputNode) iAct.key.n;
            double deltaW = x * iiAct.finalState.value;
            InputNode.SynapseKey sk = new InputNode.SynapseKey(rid, this);
            Synapse s = in.getSynapse(sk);
            if (s == null) {
                s = new Synapse(iiAct.key.n.neuron, new Synapse.Key(in.key.isNeg, in.key.isRecurrent, rid, null, in.key.startRangeMatch, in.key.startRangeMapping, in.key.startRangeOutput, in.key.endRangeMatch, in.key.endRangeMapping, in.key.endRangeOutput));
                s.output = this;
                in.setSynapse(doc, sk, s);
                s.link(doc);
            }
            // Remove before and re-add after the weight change so the sorted sets stay consistent;
            // presumably the weight is part of their ordering (at least for inputSynapsesByWeight).
            this.inputSynapses.remove(s);
            this.inputSynapsesByWeight.remove(s);
            double oldW = s.w;
            s.w -= deltaW;
            if (Document.TRAIN_DEBUG_OUTPUT) {
                log.info("S:" + s.input + " RID:" + s.key.relativeRid + " OldW:" + oldW + " NewW:" + s.w);
            }
            this.inputSynapses.add(s);
            this.inputSynapsesByWeight.add(s);
        }
    }

    /**
     * Returns {@code true} if {@code ny} is {@code nx} (and, when {@code en} is given, {@code ny} is covered
     * by {@code en}), or if this holds recursively for one of {@code ny}'s {@code orInterprNodes}. Recursion
     * stops beyond {@link #MAX_SELF_REFERENCING_DEPTH}.
     */
    private static boolean checkSelfReferencing(InterprNode nx, InterprNode ny, SearchNode en, int depth) {
        if (nx == ny && (en == null || en.isCovered(ny.markedCovered))) {
            return true;
        }
        if (depth > MAX_SELF_REFERENCING_DEPTH) {
            return false;
        }
        if (ny.orInterprNodes != null) {
            for (InterprNode n : ny.orInterprNodes.values()) {
                if (checkSelfReferencing(nx, n, en, depth + 1)) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Returns {@code 2 * sigmoid(x) - 1} for positive {@code x}, and {@code 0.0} otherwise. */
    public static double transferFunction(double x) {
        return x > 0.0 ? 2.0 * sigmoid(x) - 1.0 : 0.0;
    }

    /** Logistic function {@code 1 / (1 + e^-x)}. */
    public static double sigmoid(double x) {
        // Math.exp is the direct form; pow(Math.E, ..) adds an extra rounding through the approximate Math.E.
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Adds the positive final values of this neuron's activations in {@code doc} to {@link #activationSum}
     * and increments {@link #numberOfActivations} for each of them. Does nothing if the node has no thread
     * state for {@code doc}.
     */
    public void count(Document doc) {
        Node.ThreadState th = this.node.getThreadState(doc, false);
        if (th == null) {
            return;
        }
        for (Activation act : th.activations.values()) {
            if (act.finalState != null && act.finalState.value > 0.0) {
                this.activationSum += act.finalState.value;
                ++this.numberOfActivations;
            }
        }
    }

    /**
     * Serializes this neuron.
     * <p>
     * The first value written is an {@link InputNeuron} flag, which {@link #read(DataInput, Document)}
     * consumes. Output nodes and the lattice node are written by id only.
     * NOTE(review): {@link DataOutput#writeUTF(String)} throws {@link NullPointerException} for a
     * {@code null} label (e.g. a neuron built with the no-arg constructor).
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeBoolean(this instanceof InputNeuron);
        out.writeInt(this.id);
        out.writeUTF(this.label);
        out.writeDouble(this.bias);
        out.writeDouble(this.negDirSum);
        out.writeDouble(this.negRecSum);
        out.writeDouble(this.posRecSum);
        out.writeInt(this.outputNodes.size());
        for (Map.Entry<Synapse.Key, InputNode> me : this.outputNodes.entrySet()) {
            me.getKey().write(out);
            out.writeInt(me.getValue().id);
        }
        out.writeBoolean(this.node != null);
        if (this.node != null) {
            out.writeInt(this.node.id);
        }
        out.writeBoolean(this.isBlocked);
        out.writeBoolean(this.noTraining);
        out.writeDouble(this.activationSum);
        out.writeInt(this.numberOfActivations);
    }

    /**
     * Reads the fields written by {@link #write(DataOutput)}, after the leading {@link InputNeuron} flag.
     * <p>
     * Output nodes and the lattice node are looked up by id in {@code doc.m.initialNodes}, and their
     * back-references ({@code inputNeuron} / {@code neuron}) are pointed at this neuron.
     */
    @Override
    public void readFields(DataInput in, Document doc) throws IOException {
        this.id = in.readInt();
        this.label = in.readUTF();
        this.bias = in.readDouble();
        this.negDirSum = in.readDouble();
        this.negRecSum = in.readDouble();
        this.posRecSum = in.readDouble();
        int s = in.readInt();
        for (int i = 0; i < s; ++i) {
            Synapse.Key k = Synapse.Key.read(in, doc);
            InputNode n = (InputNode) doc.m.initialNodes.get(in.readInt());
            this.outputNodes.put(k, n);
            n.inputNeuron = this;
        }
        if (in.readBoolean()) {
            this.node = doc.m.initialNodes.get(in.readInt());
            this.node.neuron = this;
        }
        this.isBlocked = in.readBoolean();
        this.noTraining = in.readBoolean();
        this.activationSum = in.readDouble();
        this.numberOfActivations = in.readInt();
    }

    /** Reads a neuron, instantiating an {@link InputNeuron} or a plain {@link Neuron} according to the leading flag. */
    public static Neuron read(DataInput in, Document doc) throws IOException {
        Neuron n = in.readBoolean() ? new InputNeuron() : new Neuron();
        n.readFields(in, doc);
        return n;
    }

    /** Orders neurons by {@link #id}. */
    @Override
    public int compareTo(Neuron n) {
        return Integer.compare(this.id, n.id);
    }

    /** Returns {@code n(id)} or {@code n(id,label)}. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("n(");
        sb.append(this.id);
        if (this.label != null) {
            sb.append(",");
            sb.append(this.label);
        }
        sb.append(")");
        return sb.toString();
    }

    /**
     * Returns {@link #toString()} followed by the bias and all input synapses, sorted by descending weight
     * and then by input neuron id, each shown as {@code weight:relativeRid:input}.
     */
    public String toStringWithSynapses() {
        TreeSet<Synapse> is = new TreeSet<Synapse>((s1, s2) -> {
            int r = Double.compare(s2.w, s1.w);
            return r != 0 ? r : Integer.compare(s1.input.id, s2.input.id);
        });
        is.addAll(this.inputSynapsesByWeight);
        StringBuilder sb = new StringBuilder();
        sb.append(this.toString());
        sb.append("<");
        sb.append("B:");
        sb.append(Utils.round(this.bias));
        for (Synapse s : is) {
            sb.append(", ");
            sb.append(Utils.round(s.w));
            sb.append(":");
            sb.append(s.key.relativeRid);
            sb.append(":");
            sb.append(s.input.toString());
        }
        sb.append(">");
        return sb.toString();
    }

    /** Returns this neuron's activations in {@code doc} that have a positive final value. */
    public Collection<Activation> getFinalActivations(Document doc) {
        return Activation.select(doc, this.node, null, null, null, null, null, null).filter(act -> act.finalState != null && act.finalState.value > 0.0).collect(Collectors.toList());
    }

    /** Immutable pair of a weight {@code w} and a normalization value {@code n}. */
    public static class NormWeight {
        public static final NormWeight ZERO_WEIGHT = new NormWeight(0.0, 0.0);
        public final double w;
        public final double n;

        private NormWeight(double w, double n) {
            this.w = w;
            this.n = n;
        }

        /** Returns {@link #ZERO_WEIGHT} for {@code (0, 0)}, otherwise a new instance. */
        public static NormWeight create(double w, double n) {
            if (w == 0.0 && n == 0.0) {
                return ZERO_WEIGHT;
            }
            return new NormWeight(w, n);
        }

        /** Component-wise sum; returns {@code this} if {@code nw} is {@code null} or {@link #ZERO_WEIGHT}. */
        public NormWeight add(NormWeight nw) {
            if (nw == null || nw == ZERO_WEIGHT) {
                return this;
            }
            return new NormWeight(this.w + nw.w, this.n + nw.n);
        }

        /** Component-wise difference; returns {@code this} if {@code nw} is {@code null} or {@link #ZERO_WEIGHT}. */
        public NormWeight sub(NormWeight nw) {
            if (nw == null || nw == ZERO_WEIGHT) {
                return this;
            }
            return new NormWeight(this.w - nw.w, this.n - nw.n);
        }

        /** Returns {@code w / n}, or {@code 0.0} if {@code n} is not positive. */
        public double getNormWeight() {
            return this.n > 0.0 ? this.w / this.n : 0.0;
        }

        /**
         * Tolerance-based comparison of both components, using {@link #WEIGHT_TOLERANCE}.
         * Note: this overloads rather than overrides {@link Object#equals(Object)}, so hash-based
         * collections do not use it.
         */
        public boolean equals(NormWeight nw) {
            return Math.abs(this.w - nw.w) <= WEIGHT_TOLERANCE && Math.abs(this.n - nw.n) <= WEIGHT_TOLERANCE;
        }

        @Override
        public String toString() {
            return "W:" + this.w + " N:" + this.n + " NW:" + this.getNormWeight();
        }
    }
}

