/*
 * Decompiled with CFR 0.152.
 */
package network.aika.lattice;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;
import network.aika.AbstractNode;
import network.aika.Document;
import network.aika.Model;
import network.aika.lattice.Node;
import network.aika.lattice.OrNode;
import network.aika.lattice.refinement.RefValue;
import network.aika.lattice.refinement.Refinement;
import network.aika.lattice.refinement.RelationsMap;
import network.aika.neuron.INeuron;
import network.aika.neuron.Synapse;
import network.aika.neuron.relation.Relation;

public class Converter {
    public static int MAX_AND_NODE_SIZE = 10;
    public static Comparator<Synapse> SYNAPSE_COMP = (s1, s2) -> {
        int r = Boolean.compare(s2.linksAnyOutput() || s2.isIdentity(), s1.linksAnyOutput() || s1.isIdentity());
        if (r != 0) {
            return r;
        }
        r = Double.compare(s2.getWeight(), s1.getWeight());
        if (r != 0) {
            return r;
        }
        return Integer.compare(s1.getId(), s2.getId());
    };
    private int threadId;
    private INeuron neuron;
    private Model model;
    private Document doc;
    private OrNode outputNode;
    private Collection<Synapse> modifiedSynapses;

    public static boolean convert(int threadId, Document doc, INeuron neuron, Collection<Synapse> modifiedSynapses) {
        return new Converter(threadId, doc, neuron, modifiedSynapses).convert();
    }

    private Converter(int threadId, Document doc, INeuron neuron, Collection<Synapse> modifiedSynapses) {
        this.doc = doc;
        this.neuron = neuron;
        this.model = neuron.getModel();
        this.threadId = threadId;
        this.modifiedSynapses = modifiedSynapses;
    }

    private boolean convert() {
        this.setModelLabel(this.neuron);
        this.outputNode = this.neuron.getInputNode().get();
        this.setModelLabel(this.outputNode);
        INeuron.SynapseSummary ss = this.neuron.getSynapseSummary();
        if (this.neuron.getTotalBias(Synapse.State.CURRENT) + ss.getPosDirSum() + ss.getPosRecSum() <= 0.0) {
            this.outputNode.removeParents(this.threadId);
            return false;
        }
        switch (this.neuron.getType()) {
            case EXCITATORY: {
                if (this.hasOnlyWeakSynapses()) {
                    this.convertWeakSynapses();
                    break;
                }
                this.convertConjunction();
                break;
            }
            case INHIBITORY: {
                this.convertDisjunction();
                break;
            }
        }
        return true;
    }

    private void setModelLabel(AbstractNode n) {
        if (this.model.getModelLabelCallback() != null) {
            n.addModelLabel(this.model.getModelLabelCallback().getCurrentModelLabel());
        }
    }

    private void convertConjunction() {
        INeuron.SynapseSummary ss = this.neuron.getSynapseSummary();
        this.outputNode.removeParents(this.threadId);
        List<Synapse> candidates = this.prepareCandidates();
        double sum = 0.0;
        NodeContext nodeContext = null;
        double remainingSum = ss.getPosDirSum();
        int i = 0;
        boolean optionalInputMode = false;
        for (Synapse s : candidates) {
            boolean maxAndNodesReached;
            NodeContext nlNodeContext;
            boolean belowThreshold;
            double v = s.getMaxInputValue();
            boolean bl = belowThreshold = sum + v + remainingSum + ss.getPosRecSum() + ss.getPosPassiveSum() + this.neuron.getTotalBias(Synapse.State.CURRENT) <= 0.0;
            if (belowThreshold) {
                return;
            }
            if (sum + remainingSum - v + ss.getPosRecSum() + ss.getPosPassiveSum() + this.neuron.getTotalBias(Synapse.State.CURRENT) > 0.0) {
                optionalInputMode = true;
            }
            if (!optionalInputMode) {
                nlNodeContext = this.expandNode(nodeContext, s);
                if (nlNodeContext == null) {
                    return;
                }
                nodeContext = nlNodeContext;
                remainingSum -= v;
                sum += v;
                ++i;
            } else {
                nlNodeContext = this.expandNode(nodeContext, s);
                if (nlNodeContext != null) {
                    this.outputNode.addInput(nlNodeContext.getSynapseIds(), this.threadId, nlNodeContext.node, true);
                    remainingSum -= v;
                }
            }
            boolean sumOfSynapseWeightsAboveThreshold = sum + ss.getPosRecSum() + ss.getPosPassiveSum() + this.neuron.getTotalBias(Synapse.State.CURRENT) > 0.0;
            boolean bl2 = maxAndNodesReached = i >= MAX_AND_NODE_SIZE;
            if (!sumOfSynapseWeightsAboveThreshold && !maxAndNodesReached) continue;
            break;
        }
        if (nodeContext != null && !optionalInputMode) {
            this.outputNode.addInput(nodeContext.getSynapseIds(), this.threadId, nodeContext.node, true);
        }
    }

    private boolean hasOnlyWeakSynapses() {
        for (Synapse s : this.neuron.getInputSynapses()) {
            if (s.isWeak(Synapse.State.CURRENT)) continue;
            return false;
        }
        return true;
    }

    private void convertWeakSynapses() {
        TreeSet<Synapse> synapsesSortedByWeight = new TreeSet<Synapse>((s1, s2) -> {
            int r = Double.compare(s2.getWeight(), s1.getWeight());
            if (r != 0) {
                return r;
            }
            return SYNAPSE_COMP.compare((Synapse)s1, (Synapse)s2);
        });
        synapsesSortedByWeight.addAll(this.neuron.getInputSynapses());
        double sum = 0.0;
        for (Synapse s : synapsesSortedByWeight) {
            if (s.isRecurrent()) continue;
            sum += s.getWeight();
            NodeContext nlNodeContext = this.expandNode(null, s);
            this.outputNode.addInput(nlNodeContext.getSynapseIds(), this.threadId, nlNodeContext.node, true);
            if (!(sum > this.neuron.getBias())) continue;
        }
    }

    private void convertDisjunction() {
        for (Synapse s : this.modifiedSynapses) {
            if (s.isRecurrent() || s.isWeak(Synapse.State.CURRENT)) continue;
            NodeContext nlNodeContext = this.expandNode(null, s);
            this.outputNode.addInput(nlNodeContext.getSynapseIds(), this.threadId, nlNodeContext.node, false);
        }
    }

    private List<Synapse> prepareCandidates() {
        Synapse syn = this.getStrongestSynapse(this.neuron.getInputSynapses());
        if (syn == null) {
            return Collections.EMPTY_LIST;
        }
        TreeSet<Integer> alreadyCollected = new TreeSet<Integer>();
        ArrayList<Synapse> selectedCandidates = new ArrayList<Synapse>();
        TreeMap<Integer, Synapse> relatedSyns = new TreeMap<Integer, Synapse>();
        while (syn != null && selectedCandidates.size() < MAX_AND_NODE_SIZE) {
            relatedSyns.remove(syn.getId());
            selectedCandidates.add(syn);
            alreadyCollected.add(syn.getId());
            for (Map.Entry<Integer, Relation> me : syn.getRelations().entrySet()) {
                Synapse rs;
                Integer relId = me.getKey();
                Relation rel = me.getValue();
                if (alreadyCollected.contains(relId) || (rs = syn.getOutput().getSynapseById(relId)) == null) continue;
                relatedSyns.put(relId, rs);
            }
            syn = this.getStrongestSynapse(relatedSyns.values());
        }
        return selectedCandidates;
    }

    private Synapse getStrongestSynapse(Collection<Synapse> synapses) {
        Synapse maxSyn = null;
        for (Synapse s : synapses) {
            if (s.isNegative(Synapse.State.CURRENT) || s.isRecurrent() || s.isInactive() || ((INeuron)s.getInput().get()).isPassiveInputNeuron() || maxSyn != null && SYNAPSE_COMP.compare(maxSyn, s) <= 0) continue;
            maxSyn = s;
        }
        return maxSyn;
    }

    private NodeContext expandNode(NodeContext nc, Synapse s) {
        NodeContext nln = this.expandNodeInternal(nc, s);
        if (nln != null) {
            this.setModelLabel(nln.node);
        }
        return nln;
    }

    private NodeContext expandNodeInternal(NodeContext nc, Synapse s) {
        int i;
        if (nc == null) {
            NodeContext nln = new NodeContext();
            nln.node = ((INeuron)s.getInput().get()).getOutputNode().get();
            nln.offsets = new Synapse[]{s};
            return nln;
        }
        Relation[] relations = new Relation[nc.offsets.length];
        for (int i2 = 0; i2 < nc.offsets.length; ++i2) {
            Synapse linkedSynapse = nc.offsets[i2];
            relations[i2] = s.getRelationById(linkedSynapse.getId());
        }
        NodeContext nln = new NodeContext();
        nln.offsets = new Synapse[nc.offsets.length + 1];
        Refinement ref = new Refinement(new RelationsMap(relations), ((INeuron)s.getInput().get()).getOutputNode());
        RefValue rv = nc.node.expand(this.threadId, this.doc, ref);
        if (rv == null) {
            return null;
        }
        nln.node = rv.child.get(this.doc);
        for (i = 0; i < nc.offsets.length; ++i) {
            nln.offsets[rv.offsets[i].intValue()] = nc.offsets[i];
        }
        for (i = 0; i < nln.offsets.length; ++i) {
            if (nln.offsets[i] != null) continue;
            nln.offsets[i] = s;
        }
        return nln;
    }

    private class NodeContext {
        Node node;
        Synapse[] offsets;

        private NodeContext() {
        }

        int[] getSynapseIds() {
            int[] result = new int[this.offsets.length];
            for (int i = 0; i < result.length; ++i) {
                result[i] = this.offsets[i].getId();
            }
            return result;
        }
    }
}

