/*
 * Decompiled with CFR 0.152.
 */
package org.aika.neuron;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.aika.AbstractNode;
import org.aika.ActivationFunction;
import org.aika.Converter;
import org.aika.Model;
import org.aika.Provider;
import org.aika.ReadWriteLock;
import org.aika.Utils;
import org.aika.Writable;
import org.aika.corpus.Conflicts;
import org.aika.corpus.Document;
import org.aika.corpus.InterpretationNode;
import org.aika.corpus.SearchNode;
import org.aika.lattice.InputNode;
import org.aika.lattice.Node;
import org.aika.lattice.NodeActivation;
import org.aika.lattice.OrNode;
import org.aika.neuron.Activation;
import org.aika.neuron.Neuron;
import org.aika.neuron.Synapse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Internal, in-memory representation of a neuron.
 *
 * <p>Holds the neuron's bias bookkeeping, its activation function, its input/output synapses,
 * the lattice {@link OrNode} it is bound to, and one {@link ThreadState} slot per model thread
 * containing the activations produced for the document that thread is processing. Every instance
 * is wrapped by a {@link Neuron} provider (stored in {@code this.provider}).
 */
public class INeuron
extends AbstractNode<Neuron, Activation>
implements Comparable<INeuron> {
    public static boolean ALLOW_WEAK_NEGATIVE_WEIGHTS = false;
    private static final Logger log = LoggerFactory.getLogger(INeuron.class);
    public static double WEIGHT_TOLERANCE = 0.001;
    public static double TOLERANCE = 1.0E-6;
    public String label;
    public Type type;
    public String outputText;
    public volatile double bias;
    // Pending bias change relative to 'bias' (see setBias / changeBias).
    public volatile double biasDelta;
    public volatile double biasSum;
    // Pending change of 'biasSum' (see getNewBiasSum).
    public volatile double biasSumDelta;
    public volatile double metaBias = 0.0;
    public volatile double posDirSum;
    public volatile double negDirSum;
    public volatile double negRecSum;
    public volatile double posRecSum;
    public volatile double maxRecurrentSum = 0.0;
    public Writable statistic;
    public ActivationFunction activationFunction = ActivationFunction.RECTIFIED_SCALED_LOGISTIC_SIGMOID;
    // Key under which 'activationFunction' is looked up in Model#activationFunctions on read.
    public String activationFunctionKey = "ReSLS";
    public TreeMap<Synapse, Synapse> inputSynapses = new TreeMap(Synapse.INPUT_SYNAPSE_COMP);
    public TreeMap<Synapse, Synapse> outputSynapses = new TreeMap(Synapse.OUTPUT_SYNAPSE_COMP);
    public TreeMap<Synapse.Key, Provider<InputNode>> outputNodes = new TreeMap();
    public Provider<OrNode> node;
    public ReadWriteLock lock = new ReadWriteLock();
    // Per-thread activation state, indexed by Document#threadId; sized to Model#numberOfThreads.
    public ThreadState[] threads;

    /**
     * Returns the activation state of the given thread, optionally creating it.
     *
     * <p>Whenever a state is returned its {@code lastUsed} field is refreshed from the model's
     * document id counter.
     *
     * @param threadId index into {@link #threads}
     * @param create   whether to allocate a fresh {@link ThreadState} when none exists yet
     * @return the thread state, or {@code null} if none exists and {@code create} is false
     */
    public ThreadState getThreadState(int threadId, boolean create) {
        ThreadState th = this.threads[threadId];
        if (th == null) {
            if (!create) {
                return null;
            }
            th = new ThreadState();
            this.threads[threadId] = th;
        }
        th.lastUsed = ((Neuron)this.provider).model.docIdCounter.get();
        return th;
    }

    /** Used only by {@link #readNeuron}, which populates the fields from a stream. */
    private INeuron() {
    }

    public INeuron(Model m) {
        this(m, null);
    }

    public INeuron(Model m, String label) {
        this(m, label, null);
    }

    /**
     * Creates a neuron, its {@link Neuron} provider and the {@link OrNode} bound to it, and marks
     * it as modified.
     *
     * @param m          owning model; supplies the thread count and optional statistic factory
     * @param label      human-readable label (may be {@code null})
     * @param outputText optional output text (may be {@code null})
     */
    public INeuron(Model m, String label, String outputText) {
        this.label = label;
        this.outputText = outputText;
        if (m.neuronStatisticFactory != null) {
            this.statistic = m.neuronStatisticFactory.createStatisticObject();
        }
        this.threads = new ThreadState[m.numberOfThreads];
        this.provider = new Neuron(m, this);
        OrNode node = new OrNode(m);
        node.neuron = (Neuron)this.provider;
        this.node = node.provider;
        this.setModified();
    }

    /**
     * Adds an externally supplied input activation to this neuron and propagates it.
     *
     * <p>The activation's value, upper bound and lower bound are all set to {@code input.value}.
     * It is recorded in {@code doc.inputNeuronActivations}, and this neuron is added to
     * {@code doc.finallyActivatedNeurons}.
     *
     * @param doc   document the activation belongs to
     * @param input activation description; {@code doc.bottom} is used when it has no interpretation
     * @return the created activation
     */
    public Activation addInput(Document doc, Activation.Builder input) {
        InterpretationNode interpr = input.interpretation != null ? input.interpretation : doc.bottom;
        NodeActivation.Key<Node> ak = new NodeActivation.Key<Node>(this.node.get(doc), input.range, input.rid, interpr);
        Activation act = this.node.get(doc).createActivation(doc, ak);
        this.register(act);
        Activation.State s = new Activation.State(input.value, input.fired, SearchNode.Weight.ZERO);
        act.rounds.set(0, s);
        act.inputValue = input.value;
        act.upperBound = input.value;
        act.lowerBound = input.value;
        act.setTargetValue(input.targetValue);
        doc.inputNeuronActivations.add(act);
        doc.finallyActivatedNeurons.add(act.getINeuron());
        this.propagate(act);
        doc.propagate();
        return act;
    }

    /**
     * Removes this neuron from the network.
     *
     * <p>Clears its activations in all threads, removes each input synapse from its input
     * neuron's in-memory output synapses, and removes each in-memory output synapse from the
     * target neuron's input synapses. Every lock is released even if a removal throws.
     */
    public void remove() {
        this.clearActivations();
        for (Synapse s : this.inputSynapses.values()) {
            INeuron in = (INeuron)s.input.get();
            Neuron inProvider = (Neuron)in.provider;
            inProvider.lock.acquireWriteLock();
            try {
                inProvider.inMemoryOutputSynapses.remove(s);
            } finally {
                inProvider.lock.releaseWriteLock();
            }
        }
        Neuron self = (Neuron)this.provider;
        self.lock.acquireReadLock();
        try {
            for (Synapse s : self.inMemoryOutputSynapses.values()) {
                INeuron out = (INeuron)s.output.get();
                out.lock.acquireWriteLock();
                try {
                    out.inputSynapses.remove(s);
                } finally {
                    out.lock.releaseWriteLock();
                }
            }
        } finally {
            self.lock.releaseReadLock();
        }
    }

    /** Forwards {@code act} to every lattice input node registered in {@link #outputNodes}. */
    @Override
    public void propagate(Activation act) {
        Document doc = act.doc;
        for (Provider<InputNode> out : this.outputNodes.values()) {
            out.get(doc).addActivation(doc, act);
        }
    }

    /**
     * Links {@code act} with the activations of connected neurons, first along the input
     * synapses and then along the output synapses, under this neuron's read lock.
     *
     * @param act activation to link
     */
    public void linkActivation(Activation act) {
        // A fresh visit marker for this linking pass, used to tag interpretation nodes.
        long v = act.doc.visitedCounter++;
        this.lock.acquireReadLock();
        try {
            this.linkActivation(act, v, 0);
            this.linkActivation(act, v, 1);
        } finally {
            this.lock.releaseReadLock();
        }
    }

    /**
     * Links {@code act} along one synapse direction, then registers conflicts for every partner
     * connected through a recurrent negative synapse.
     *
     * @param act activation to link
     * @param v   visit marker of the current linking pass
     * @param dir 0 to follow in-memory input synapses, 1 to follow in-memory output synapses
     */
    private void linkActivation(Activation act, long v, int dir) {
        ArrayList<Activation> recNegTmp = new ArrayList<Activation>();
        Neuron self = (Neuron)this.provider;
        Document doc = act.doc;
        self.lock.acquireReadLock();
        try {
            NavigableMap<Synapse, Synapse> syns = dir == 0 ? self.inMemoryInputSynapses : self.inMemoryOutputSynapses;
            for (Synapse s : INeuron.getActiveSynapses(self.model, doc, dir, syns)) {
                Neuron p = dir == 0 ? s.input : s.output;
                INeuron an = (INeuron)p.getIfNotSuspended();
                if (an == null) {
                    continue;
                }
                // Skip neurons that have no activations in this document's thread.
                ThreadState th = an.getThreadState(doc.threadId, false);
                if (th == null || th.activations.isEmpty()) {
                    continue;
                }
                INeuron.linkActSyn(an, act, dir, recNegTmp, s);
            }
        } finally {
            self.lock.releaseReadLock();
        }
        for (Activation rAct : recNegTmp) {
            Activation oAct = dir == 0 ? act : rAct;
            Activation iAct = dir == 0 ? rAct : act;
            INeuron.markConflicts(iAct, oAct, v);
            INeuron.addConflict(oAct.key.interpretation, iAct.key.interpretation, iAct, Collections.singleton(act), v);
        }
    }

    /**
     * Registers a conflict between {@code io} and {@code o}, or recurses into the or-nodes of
     * {@code o}.
     *
     * <p>A conflict is added when {@code o} was marked in pass {@code v} or is selected, and is
     * not self-referencing with {@code io}. Otherwise the method recurses into
     * {@code o.orInterpretationNodes}, if any. {@code inputActs} is passed along but not read here.
     */
    private static void addConflict(InterpretationNode io, InterpretationNode o, NodeActivation act, Collection<NodeActivation> inputActs, long v) {
        if (o.markedConflict == v || o.state == InterpretationNode.State.SELECTED) {
            if (!InterpretationNode.checkSelfReferencing(o, io, false, 0)) {
                Conflicts.add(act, io, o);
            }
        } else if (o.orInterpretationNodes != null) {
            for (InterpretationNode no : o.orInterpretationNodes) {
                INeuron.addConflict(io, no, act, inputActs, v);
            }
        }
    }

    /**
     * Marks interpretations with visit marker {@code v}.
     *
     * <p>Marks the interpretation of {@code oAct}, plus the interpretation of every output that
     * {@code iAct} reaches through a recurrent negative synapse.
     */
    private static void markConflicts(Activation iAct, Activation oAct, long v) {
        oAct.key.interpretation.markedConflict = v;
        for (Activation.SynapseActivation sa : iAct.neuronOutputs) {
            if (sa.synapse.key.isRecurrent && sa.synapse.isNegative()) {
                sa.output.key.interpretation.markedConflict = v;
            }
        }
    }

    /**
     * Connects {@code act} with every matching activation of neuron {@code n} through synapse
     * {@code s}.
     *
     * <p>Candidates are selected by the rid derived from the synapse key and by the key's range
     * match (inverted when {@code dir == 0}). For each match a {@link Activation.SynapseActivation}
     * is created and added to both endpoints. When {@code s} is negative and recurrent, the
     * matched activation is also appended to {@code recNegTmp}.
     *
     * @param n         neuron on the other end of {@code s}
     * @param act       activation being linked
     * @param dir       0 if {@code act} is on the output side of {@code s}, 1 if on the input side
     * @param recNegTmp collector for partners linked via recurrent negative synapses
     * @param s         synapse to link through
     */
    private static void linkActSyn(INeuron n, Activation act, int dir, ArrayList<Activation> recNegTmp, Synapse s) {
        Synapse.Key sk = s.key;
        Integer rid = dir == 0
                ? (sk.absoluteRid != null ? sk.absoluteRid : Utils.nullSafeAdd(act.key.rid, false, sk.relativeRid, false))
                : Utils.nullSafeSub(act.key.rid, false, sk.relativeRid, false);
        Stream<Activation> candidates = Activation.select(act.doc, n, rid, act.key.range, dir == 0 ? sk.rangeMatch.invert() : sk.rangeMatch, null, null);
        candidates.forEach(rAct -> {
            Activation oAct = dir == 0 ? act : rAct;
            Activation iAct = dir == 0 ? rAct : act;
            Activation.SynapseActivation sa = new Activation.SynapseActivation(s, iAct, oAct);
            iAct.addSynapseActivation(0, sa);
            oAct.addSynapseActivation(1, sa);
            // The decompiled code read an unresolved 'key' here; the synapse's own key is meant.
            if (s.isNegative() && sk.isRecurrent) {
                recNegTmp.add(rAct);
            }
        });
    }

    /**
     * Returns the synapses of {@code syns} worth inspecting for the current document.
     *
     * <p>When the map is small (fewer than 10 entries), or the document has many activated
     * neurons relative to its size, all synapses are returned. Otherwise only the synapses whose
     * input ({@code dir == 0}) or output ({@code dir == 1}) is an activated neuron are collected,
     * using a sub-map between {@code Synapse.Key.MIN_KEY} and {@code MAX_KEY} per neuron.
     * NOTE(review): this relies on the map's comparator ordering by that neuron first — confirm.
     */
    private static Collection<Synapse> getActiveSynapses(Model m, Document doc, int dir, NavigableMap<Synapse, Synapse> syns) {
        if (syns.size() < 10 || doc.activatedNeurons.size() * 20 > syns.size()) {
            return syns.values();
        }
        ArrayList<Synapse> newSyns = new ArrayList<Synapse>();
        Synapse lk = new Synapse(null, null, Synapse.Key.MIN_KEY);
        Synapse uk = new Synapse(null, null, Synapse.Key.MAX_KEY);
        for (INeuron n : doc.activatedNeurons) {
            if (dir == 0) {
                lk.input = (Neuron)n.provider;
                uk.input = (Neuron)n.provider;
            } else {
                lk.output = (Neuron)n.provider;
                uk.output = (Neuron)n.provider;
            }
            newSyns.addAll(syns.subMap(lk, true, uk, true).values());
        }
        return newSyns;
    }

    /**
     * Returns a live view of this neuron's activations for {@code doc}'s thread, or an empty
     * list if that thread has no state.
     */
    public Collection<Activation> getActivations(Document doc) {
        ThreadState th = this.getThreadState(doc.threadId, false);
        if (th == null) {
            return Collections.emptyList();
        }
        return th.activations.values();
    }

    /**
     * Returns the first activation (by the activation map's ordering) for {@code doc}'s thread,
     * or {@code null} if there is none.
     */
    public synchronized Activation getFirstActivation(Document doc) {
        ThreadState th = this.getThreadState(doc.threadId, false);
        if (th == null || th.activations.isEmpty()) {
            return null;
        }
        return th.activations.firstEntry().getValue();
    }

    /** Clears the activations of every thread slot. */
    public void clearActivations() {
        for (int i = 0; i < ((Neuron)this.provider).model.numberOfThreads; ++i) {
            this.clearActivations(i);
        }
    }

    /** Clears the activations of {@code doc}'s thread. */
    public void clearActivations(Document doc) {
        this.clearActivations(doc.threadId);
    }

    /** Clears all activation maps of the given thread; does nothing if it has no state. */
    public void clearActivations(int threadId) {
        ThreadState th = this.getThreadState(threadId, false);
        if (th == null) {
            return;
        }
        th.activations.clear();
        if (th.activationsEnd != null) {
            th.activationsEnd.clear();
        }
        if (th.activationsRid != null) {
            th.activationsRid.clear();
        }
    }

    /** Orders neurons by the id of their {@link Neuron} provider. */
    @Override
    public int compareTo(INeuron n) {
        return Integer.compare(((Neuron)this.provider).id, ((Neuron)n.provider).id);
    }

    /**
     * Serializes this neuron.
     *
     * <p>Synapses whose input (for input synapses) or output (for output synapses) is
     * {@code null} are skipped. Each synapse list is terminated by {@code false}.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        // Leading marker. readFields does not consume it, so presumably the caller that
        // dispatches on record type does — TODO confirm.
        out.writeBoolean(true);
        out.writeBoolean(this.label != null);
        if (this.label != null) {
            out.writeUTF(this.label);
        }
        out.writeBoolean(this.type != null);
        if (this.type != null) {
            out.writeUTF(this.type.name());
        }
        out.writeBoolean(this.outputText != null);
        if (this.outputText != null) {
            out.writeUTF(this.outputText);
        }
        out.writeBoolean(this.statistic != null);
        if (this.statistic != null) {
            this.statistic.write(out);
        }
        out.writeDouble(this.bias);
        out.writeDouble(this.biasSum);
        out.writeDouble(this.posDirSum);
        out.writeDouble(this.negDirSum);
        out.writeDouble(this.negRecSum);
        out.writeDouble(this.posRecSum);
        out.writeDouble(this.maxRecurrentSum);
        out.writeUTF(this.activationFunctionKey);
        out.writeInt(this.outputNodes.size());
        for (Map.Entry<Synapse.Key, Provider<InputNode>> me : this.outputNodes.entrySet()) {
            me.getKey().write(out);
            out.writeInt(me.getValue().id);
        }
        out.writeBoolean(this.node != null);
        if (this.node != null) {
            out.writeInt(this.node.id);
        }
        for (Synapse s : this.inputSynapses.values()) {
            if (s.input == null) continue;
            out.writeBoolean(true);
            s.write(out);
        }
        out.writeBoolean(false);
        for (Synapse s : this.outputSynapses.values()) {
            if (s.output == null) continue;
            out.writeBoolean(true);
            s.write(out);
        }
        out.writeBoolean(false);
    }

    /**
     * Reads the fields written by {@link #write}, starting after its leading marker boolean.
     *
     * <p>The activation function is resolved from {@code m.activationFunctions} by key.
     * NOTE(review): a serialized statistic requires {@code m.neuronStatisticFactory} to be non-null.
     */
    @Override
    public void readFields(DataInput in, Model m) throws IOException {
        if (in.readBoolean()) {
            this.label = in.readUTF();
        }
        if (in.readBoolean()) {
            this.type = Type.valueOf(in.readUTF());
        }
        if (in.readBoolean()) {
            this.outputText = in.readUTF();
        }
        if (in.readBoolean()) {
            this.statistic = m.neuronStatisticFactory.createStatisticObject();
            this.statistic.readFields(in, m);
        }
        this.bias = in.readDouble();
        this.biasSum = in.readDouble();
        this.posDirSum = in.readDouble();
        this.negDirSum = in.readDouble();
        this.negRecSum = in.readDouble();
        this.posRecSum = in.readDouble();
        this.maxRecurrentSum = in.readDouble();
        this.activationFunctionKey = in.readUTF();
        this.activationFunction = m.activationFunctions.get(this.activationFunctionKey);
        int s = in.readInt();
        for (int i = 0; i < s; ++i) {
            Synapse.Key k = Synapse.Key.read(in, m);
            Object n = m.lookupNodeProvider(in.readInt());
            this.outputNodes.put(k, (Provider<InputNode>)n);
        }
        if (in.readBoolean()) {
            int nId = in.readInt();
            this.node = m.lookupNodeProvider(nId);
        }
        while (in.readBoolean()) {
            Synapse syn = Synapse.read(in, m);
            this.inputSynapses.put(syn, syn);
        }
        while (in.readBoolean()) {
            Synapse syn = Synapse.read(in, m);
            this.outputSynapses.put(syn, syn);
        }
    }

    /**
     * Detaches this neuron's synapses from the in-memory synapse maps of its neighbours.
     *
     * <p>Also detaches the non-conjunctive in-memory input synapses and the conjunctive in-memory
     * output synapses held by the provider.
     */
    @Override
    public void suspend() {
        for (Synapse s : this.inputSynapses.values()) {
            s.input.removeInMemoryOutputSynapse(s);
        }
        for (Synapse s : this.outputSynapses.values()) {
            s.output.removeInMemoryInputSynapse(s);
        }
        Neuron self = (Neuron)this.provider;
        self.lock.acquireReadLock();
        try {
            for (Synapse s : self.inMemoryInputSynapses.values()) {
                if (s.isConjunction) continue;
                s.input.removeInMemoryOutputSynapse(s);
            }
            for (Synapse s : self.inMemoryOutputSynapses.values()) {
                if (!s.isConjunction) continue;
                s.output.removeInMemoryInputSynapse(s);
            }
        } finally {
            self.lock.releaseReadLock();
        }
    }

    /**
     * Re-attaches synapses to neighbouring neurons' in-memory synapse maps, mirroring
     * {@link #suspend}.
     *
     * <p>For this neuron's own synapses, the far side is registered on this side only when the far
     * neuron is not suspended.
     */
    @Override
    public void reactivate() {
        Neuron self = (Neuron)this.provider;
        self.lock.acquireReadLock();
        try {
            for (Synapse s : self.inMemoryInputSynapses.values()) {
                if (s.isConjunction) continue;
                s.input.addInMemoryOutputSynapse(s);
            }
            for (Synapse s : self.inMemoryOutputSynapses.values()) {
                if (!s.isConjunction) continue;
                s.output.addInMemoryInputSynapse(s);
            }
        } finally {
            self.lock.releaseReadLock();
        }
        for (Synapse s : this.inputSynapses.values()) {
            s.input.addInMemoryOutputSynapse(s);
            if (s.input.isSuspended()) continue;
            s.output.addInMemoryInputSynapse(s);
        }
        for (Synapse s : this.outputSynapses.values()) {
            s.output.addInMemoryInputSynapse(s);
            if (s.output.isSuspended()) continue;
            s.input.addInMemoryOutputSynapse(s);
        }
    }

    /**
     * Sets the pending bias so that {@code bias + biasDelta == b}.
     *
     * <p>{@code biasSumDelta} is adjusted by the same amount the pending delta changed.
     */
    public void setBias(double b) {
        double newBiasDelta = b - this.bias;
        this.biasSumDelta += newBiasDelta - this.biasDelta;
        this.biasDelta = newBiasDelta;
    }

    /** Adds {@code bd} to both the pending bias delta and the pending bias-sum delta. */
    public void changeBias(double bd) {
        this.biasDelta += bd;
        this.biasSumDelta += bd;
    }

    /** Returns {@code biasSum} including the pending {@code biasSumDelta}. */
    public double getNewBiasSum() {
        return this.biasSum + this.biasSumDelta;
    }

    /**
     * Registers {@code act} in the thread state of the neuron that owns its node, then links it.
     *
     * <p>If the key is new it is stored in all activation maps. The owning neuron is added to
     * {@code doc.activatedNeurons} on its first activation, keys with a rid are indexed in
     * {@code doc.activationsByRid}, and the activation is recorded in {@code doc.addedActivations}.
     * Linking happens in every case.
     */
    public void register(Activation act) {
        NodeActivation.Key ak = act.key;
        Document doc = act.doc;
        INeuron owner = (INeuron)((OrNode)ak.node).neuron.get();
        ThreadState th = owner.getThreadState(doc.threadId, true);
        if (!th.activations.containsKey(ak)) {
            if (th.activations.isEmpty()) {
                doc.activatedNeurons.add(owner);
            }
            th.activations.put(ak, act);
            TreeMap<NodeActivation.Key, Activation> actEnd = th.activationsEnd;
            if (actEnd != null) {
                actEnd.put(ak, act);
            }
            TreeMap<NodeActivation.Key, Activation> actRid = th.activationsRid;
            if (actRid != null) {
                actRid.put(ak, act);
            }
            if (ak.rid != null) {
                doc.activationsByRid.put(ak, act);
            }
            doc.addedActivations.add(act);
        }
        this.linkActivation(act);
    }

    /**
     * Applies a bias change and modified synapses to neuron {@code pn}, then runs the converter.
     *
     * @return the result of {@link Converter#convert}
     */
    public static boolean update(Model m, int threadId, Document doc, Neuron pn, double biasDelta, Collection<Synapse> modifiedSynapses) {
        INeuron n = (INeuron)pn.get();
        n.changeBias(biasDelta);
        modifiedSynapses.forEach(s -> s.link());
        return Converter.convert(m, threadId, doc, n, modifiedSynapses);
    }

    /** Deserializes a neuron for provider {@code p}, using {@code p.model} for lookups. */
    public static INeuron readNeuron(DataInput in, Neuron p) throws IOException {
        INeuron n = new INeuron();
        n.provider = p;
        n.threads = new ThreadState[p.model.numberOfThreads];
        n.readFields(in, p.model);
        return n;
    }

    /** Returns the label, which may be {@code null}. */
    public String toString() {
        return this.label;
    }

    /**
     * Renders the label, bias sum and all input synapses.
     *
     * <p>Synapses are ordered by descending weight, then ascending input id. A stable list sort is
     * used rather than a sorted set: the set's comparator treated synapses from the same input
     * with equal weight as duplicates and silently dropped them.
     */
    public String toStringWithSynapses() {
        ArrayList<Synapse> is = new ArrayList<Synapse>(this.inputSynapses.values());
        is.sort((s1, s2) -> {
            int r = Double.compare(s2.weight, s1.weight);
            return r != 0 ? r : Integer.compare(s1.input.id, s2.input.id);
        });
        StringBuilder sb = new StringBuilder();
        sb.append(this.toString());
        sb.append("<");
        sb.append("B:");
        sb.append(Utils.round(this.biasSum));
        for (Synapse s : is) {
            sb.append(", ");
            sb.append(Utils.round(s.weight));
            sb.append(":");
            sb.append(s.key.relativeRid);
            sb.append(":");
            sb.append(s.input.toString());
        }
        sb.append(">");
        return sb.toString();
    }

    /** Streams this neuron's activations in {@code doc} that report {@code isFinalActivation()}. */
    public Stream<Activation> getFinalActivationsStream(Document doc) {
        return this.getActivationsStream(doc).filter(act -> act.isFinalActivation());
    }

    /** Streams all of this neuron's activations in {@code doc} (unfiltered select). */
    public Stream<Activation> getActivationsStream(Document doc) {
        return Activation.select(doc, this, null, null, null, null, null);
    }

    /** Collects {@link #getFinalActivationsStream} into a list. */
    public Collection<Activation> getFinalActivations(Document doc) {
        return this.getFinalActivationsStream(doc).collect(Collectors.toList());
    }

    /** Collects {@link #getActivationsStream} into a list. */
    public Collection<Activation> getAllActivations(Document doc) {
        return this.getActivationsStream(doc).collect(Collectors.toList());
    }

    /**
     * Per-thread activation state. It keeps the same activations in three maps, ordered by
     * {@code Node.BEGIN_COMP}, {@code Node.END_COMP} and {@code Node.RID_COMP}.
     */
    public static class ThreadState {
        // Model#docIdCounter value at the time this state was last accessed.
        public long lastUsed;
        public TreeMap<NodeActivation.Key, Activation> activations = new TreeMap(Node.BEGIN_COMP);
        public TreeMap<NodeActivation.Key, Activation> activationsEnd = new TreeMap(Node.END_COMP);
        public TreeMap<NodeActivation.Key, Activation> activationsRid = new TreeMap(Node.RID_COMP);
    }

    /** Neuron category, serialized by name. */
    public enum Type {
        EXCITATORY,
        INHIBITORY,
        META
    }
}

