/*
 * Decompiled with CFR 0.152.
 */
package org.aika;

import java.util.Collection;
import java.util.Comparator;
import java.util.TreeSet;
import org.aika.Model;
import org.aika.Utils;
import org.aika.lattice.AndNode;
import org.aika.lattice.InputNode;
import org.aika.lattice.Node;
import org.aika.lattice.OrNode;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;

public class Converter {
    public static int MAX_AND_NODE_SIZE = 4;
    public static Comparator<Synapse> SYNAPSE_COMP = (s1, s2) -> {
        int r = Double.compare(s2.weight, s1.weight);
        if (r != 0) {
            return r;
        }
        return Synapse.INPUT_SYNAPSE_COMP.compare((Synapse)s1, (Synapse)s2);
    };
    private Model model;
    private int threadId;
    private INeuron neuron;
    private OrNode outputNode;
    private Collection<Synapse> modifiedSynapses;
    public static final int DIRECT = 0;
    public static final int RECURRENT = 1;
    public static final int POSITIVE = 0;
    public static final int NEGATIVE = 1;

    public static boolean convert(Model m, int threadId, INeuron neuron, Collection<Synapse> modifiedSynapses) {
        return new Converter(m, threadId, neuron, modifiedSynapses).convert();
    }

    private Converter(Model model, int threadId, INeuron neuron, Collection<Synapse> modifiedSynapses) {
        this.model = model;
        this.neuron = neuron;
        this.threadId = threadId;
        this.modifiedSynapses = modifiedSynapses;
    }

    private boolean convert() {
        this.outputNode = this.neuron.node.get();
        this.initInputNodesAndComputeWeightSums();
        double remainingSum = 0.0;
        TreeSet<Synapse> tmp = new TreeSet<Synapse>(SYNAPSE_COMP);
        for (Synapse s : this.neuron.inputSynapses.values()) {
            if (s.isNegative() || s.key.isRecurrent) continue;
            remainingSum += s.weight;
            tmp.add(s);
        }
        Integer offset = null;
        Node requiredNode = null;
        boolean noFurtherRefinement = false;
        TreeSet<Synapse> reqSyns = new TreeSet<Synapse>(Synapse.INPUT_SYNAPSE_COMP);
        double sum = 0.0;
        if (remainingSum + this.neuron.posRecSum + this.neuron.biasSum > 0.0) {
            int i = 0;
            for (Synapse s : tmp) {
                boolean maxAndNodesReached;
                boolean isOptionalInput = sum + remainingSum - s.weight + this.neuron.posRecSum + this.neuron.biasSum > 0.0;
                boolean bl = maxAndNodesReached = i >= MAX_AND_NODE_SIZE;
                if (isOptionalInput || maxAndNodesReached) break;
                remainingSum -= s.weight;
                reqSyns.add(s);
                requiredNode = this.getNextLevelNode(offset, requiredNode, s);
                offset = Utils.nullSafeMin(s.key.relativeRid, offset);
                ++i;
                boolean sumOfSynapseWeightsAboveThreshold = (sum += s.weight) + this.neuron.posRecSum + this.neuron.biasSum > 0.0;
                if (!sumOfSynapseWeightsAboveThreshold) continue;
                noFurtherRefinement = true;
                break;
            }
            this.outputNode.removeParents(this.threadId, false);
            if (requiredNode != this.outputNode.requiredNode) {
                this.outputNode.requiredNode = requiredNode;
            }
            if (noFurtherRefinement || i == MAX_AND_NODE_SIZE) {
                this.outputNode.addInput(offset, this.threadId, requiredNode, false);
            } else {
                for (Synapse s : tmp) {
                    boolean belowThreshold;
                    boolean bl = belowThreshold = sum + s.weight + remainingSum + this.neuron.posRecSum + this.neuron.biasSum <= 0.0;
                    if (belowThreshold) break;
                    if (reqSyns.contains(s)) continue;
                    Node nln = this.getNextLevelNode(offset, requiredNode, s);
                    Integer nOffset = Utils.nullSafeMin(s.key.relativeRid, offset);
                    this.outputNode.addInput(nOffset, this.threadId, nln, false);
                    remainingSum -= s.weight;
                }
            }
        }
        for (Synapse s : this.modifiedSynapses) {
            if (!(s.weight + this.neuron.posRecSum + this.neuron.biasSum > 0.0)) continue;
            Node nln = s.inputNode.get();
            offset = s.key.relativeRid;
            this.outputNode.addInput(offset, this.threadId, nln, false);
        }
        return true;
    }

    private void initInputNodesAndComputeWeightSums() {
        double[][] sumDelta = new double[2][2];
        double maxRecurrentSumDelta = 0.0;
        this.neuron.biasSum = 0.0;
        for (Synapse s : this.modifiedSynapses) {
            if (s.toBeDeleted) {
                s.weightDelta = -s.weight;
                s.biasDelta = -s.bias;
            }
            INeuron in = (INeuron)s.input.get();
            in.lock.acquireWriteLock();
            if (s.inputNode == null) {
                InputNode iNode = InputNode.add(this.model, s.key.createInputNodeKey(), (INeuron)s.input.get());
                iNode.setModified();
                iNode.setSynapse(s);
                s.inputNode = iNode.provider;
            }
            if (s.key.isRecurrent) {
                maxRecurrentSumDelta += Math.abs(s.weight + s.weightDelta) - Math.abs(s.weight);
            }
            double[] dArray = sumDelta[s.key.isRecurrent ? 1 : 0];
            int n = s.isNegative() ? 1 : 0;
            dArray[n] = dArray[n] - s.weight;
            double[] dArray2 = sumDelta[s.key.isRecurrent ? 1 : 0];
            int n2 = s.weight + s.weightDelta <= 0.0 ? 1 : 0;
            dArray2[n2] = dArray2[n2] + (s.weight + s.weightDelta);
            s.weight += s.weightDelta;
            s.weightDelta = 0.0;
            s.bias += s.biasDelta;
            s.biasDelta = 0.0;
            this.neuron.biasSum += s.bias;
            in.lock.releaseWriteLock();
            if (!s.toBeDeleted) continue;
            s.unlink();
        }
        this.neuron.bias += this.neuron.biasDelta;
        this.neuron.biasDelta = 0.0;
        this.neuron.biasSum += this.neuron.bias;
        this.neuron.biasSum = Math.min(this.neuron.biasSum, 0.0);
        assert (Double.isFinite(this.neuron.biasSum));
        this.neuron.maxRecurrentSum += maxRecurrentSumDelta;
        this.neuron.posDirSum += sumDelta[0][0];
        this.neuron.negDirSum += sumDelta[0][1];
        this.neuron.negRecSum += sumDelta[1][1];
        this.neuron.posRecSum += sumDelta[1][0];
        this.neuron.setModified();
    }

    private Node getNextLevelNode(Integer offset, Node requiredNode, Synapse s) {
        Node nln = requiredNode == null ? (Node)s.inputNode.get() : AndNode.createNextLevelNode(this.model, this.threadId, requiredNode, new AndNode.Refinement(s.key.relativeRid, offset, s.inputNode), null);
        return nln;
    }
}

