package org.aika.network.neuron.simple.lattice;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.aika.corpus.Option;
import org.aika.corpus.Range;
import org.aika.network.Iteration;
import org.aika.network.Model;
import org.aika.network.neuron.Activation;
import org.aika.network.neuron.Input;
import org.aika.network.neuron.Neuron;
import org.aika.network.neuron.recurrent.RecurrentNeuron;

import java.util.Collection;
import java.util.Collections;
import java.util.Set;
import java.util.TreeSet;

/**
 * Lattice node that receives the activations of a single input neuron and turns each of them into an
 * activation of this node, which is then propagated and expanded further into the lattice.
 *
 * <p>Instances are obtained through {@link #add}, which caches one node per (polarity, rid) on the
 * input neuron's {@code outputNodes} map.
 *
 * @author Lukas Molzberger
 */
public abstract class InputNode extends LatticeNode implements Input {

    /**
     * Optional relational-id filter. When non-null, an input activation is only accepted if its
     * effective rid equals this value. The effective rid is the activation's own rid for recurrent
     * input neurons, otherwise 0.
     */
    public Integer rid;

    /** Neuron feeding this node; stays null when {@link #add} was called without an input neuron. */
    public Neuron inputNeuron;


    public InputNode(Model m, Integer rid) {
        super(m, 1);
        this.rid = rid;
    }


    /**
     * Returns the input node attached to {@code input} for the requested polarity and rid, creating
     * and registering a new one when none is cached yet.
     *
     * <p>A newly created negative node is additionally recorded in {@code m.negationNodes}. If
     * {@code input} is null, no cache lookup or registration happens and a fresh node is returned.
     *
     * @param m     model the node belongs to
     * @param rid   optional relational-id filter for the node
     * @param input neuron whose activations feed the node; may be null
     * @param isNeg {@code true} for a {@link NegativeInputNode}, {@code false} for a {@link PositiveInputNode}
     * @return the cached or newly created node
     */
    public static InputNode add(Model m, Integer rid, Neuron input, boolean isNeg) {
        InputKey key = new InputKey(isNeg ? 1 : 0, rid);
        if (input != null) {
            InputNode cached = (InputNode) input.outputNodes.get(key);
            if (cached != null) {
                return cached;
            }
        }

        InputNode created;
        if (isNeg) {
            NegativeInputNode negativeNode = new NegativeInputNode(m, rid);
            m.negationNodes.add(negativeNode);
            created = negativeNode;
        } else {
            created = new PositiveInputNode(m, rid);
        }

        if (input != null) {
            created.inputNeuron = input;
            input.outputNodes.put(key, created);
        }
        return created;
    }


    /** Intentionally empty: this node does nothing when an activation is initialised. */
    @Override
    public void initActivation(Iteration t, Activation act) {
    }


    /** Intentionally empty: this node does nothing when an activation is deleted. */
    @Override
    public void deleteActivation(Iteration t, Activation act) {
    }


    /**
     * Derives the key of the activation this node produces for {@code inputAct}.
     *
     * <p>The key's option is the input's option, extended via {@code Option.add} with the input's
     * {@code newOption} when one is present. Its {@code fired} value is the input's value plus one.
     *
     * @return the derived key, or null if the input fails the rid filter or the resulting option is null
     */
    private Activation.Key deriveKey(Iteration t, Activation inputAct) {
        int effectiveRid = 0;
        if (inputNeuron instanceof RecurrentNeuron) {
            effectiveRid = inputAct.key.rid;
        }
        if (this.rid != null && this.rid.intValue() != effectiveRid) {
            return null;
        }

        Option option = inputAct.key.o;
        if (inputAct.newOption != null) {
            option = Option.add(t.doc, true, inputAct.key.o, inputAct.newOption);
        }
        if (option == null) {
            return null;
        }
        return new Activation.Key(this, inputAct.key.pos, effectiveRid, option, inputAct.key.fired + 1);
    }


    /**
     * Creates and propagates the activation derived from {@code inputAct}. Does nothing if the input
     * is filtered out.
     */
    @Override
    public void addActivation(Iteration t, Activation inputAct, Range addedRange) {
        Activation.Key key = deriveKey(t, inputAct);
        if (key == null) {
            return;
        }
        addActivationAndPropagate(
                t,
                false,
                key,
                addedRange,
                null,
                Collections.singleton(inputAct),
                Collections.singleton(inputAct)
        );
    }


    /**
     * Looks up this node's activation matching the key derived from {@code inputAct} and removes it
     * together with its propagated effects. Does nothing if the input is filtered out.
     */
    @Override
    public void removeActivation(Iteration t, Activation inputAct, Range removedRange) {
        Activation.Key key = deriveKey(t, inputAct);
        if (key == null) {
            return;
        }
        removeActivationAndPropagate(
                t,
                false,
                Activation.get(this, null, removedRange, Range.Relation.CONTAINS, key.o, Option.Relation.EQUALS, key.fired, null, true),
                removedRange
        );
    }


    /**
     * Sign of this input, implemented by the concrete subclasses.
     * NOTE(review): presumably distinguishes positive from negative inputs; confirm in the subclasses.
     */
    public abstract int getSign();


    /** Input nodes never allow an option; this always returns {@code false}. */
    @Override
    public boolean isAllowedOption(Option n, Activation act, long v) {
        return false;
    }


    /**
     * Adds two refinements to {@code inputs}, in this order: one pointing back at this node with the
     * negated rid of {@code newRef} and the same inference mode, and then {@code newRef} itself.
     */
    @Override
    protected void collectNodeAndRefinements(AndNode.Refinement newRef, Set<AndNode.Refinement> inputs) {
        AndNode.Refinement backRef = new AndNode.Refinement(-newRef.rid, newRef.inferenceMode, this);
        inputs.add(backRef);
        inputs.add(newRef);
    }


    /**
     * Expands {@code act} into candidate and-nodes and or-nodes at the next lattice level.
     *
     * <p>Candidates are built from the other, non-negative input activations overlapping
     * {@code act}'s position. Existing and-children whose input is a {@link NegativeInputNode} are
     * only revisited when {@code removedConflict} is null. Finally, or-node candidates are processed.
     * Activations already marked as removed are skipped.
     *
     * @param t
     * @param act
     * @param addedRange
     * @param removedConflict This parameter contains a removed conflict if it is not null. In this case only expand activations that contain this removed conflict.
     * @param train
     */
    @Override
    public void expandToNextLevel(Iteration t, Activation act, Range addedRange, Option removedConflict, boolean train) {
        // The activation may have been deleted in the meantime.
        if (act.isRemoved) {
            return;
        }

        Set<AndNode.Refinement> candidates = refinementsFromOverlaps(act, act.key.pos.getOverlappingInputActivations());
        for (AndNode.Refinement candidate : candidates) {
            AndNode.processCandidate(t, this, candidate.input, candidate, act, addedRange, removedConflict, train);
        }

        if (removedConflict == null) {
            for (AndNode.Refinement child : andChildren.keySet()) {
                if (child.input instanceof NegativeInputNode) {
                    AndNode.processCandidate(t, this, child.input, child, act, addedRange, null, train);
                }
            }
        }

        OrNode.processCandidate(t, this, act, addedRange, removedConflict, train);
    }


    /**
     * Builds a sorted set of refinements for every activation in {@code acts} other than
     * {@code anchor} whose node is not a {@link NegativeInputNode}.
     *
     * <p>Each such activation contributes two refinements, with inference mode {@code false} and
     * {@code true}. Both carry the activation's rid relative to {@code anchor}.
     */
    private Set<AndNode.Refinement> refinementsFromOverlaps(Activation anchor, Collection<Activation> acts) {
        Set<AndNode.Refinement> refinements = new TreeSet<>();
        for (Activation other : acts) {
            if (other == anchor || other.key.n instanceof NegativeInputNode) {
                continue;
            }
            int relativeRid = other.key.rid - anchor.key.rid;
            InputNode otherNode = (InputNode) other.key.n;
            refinements.add(new AndNode.Refinement(relativeRid, false, otherNode));
            refinements.add(new AndNode.Refinement(relativeRid, true, otherNode));
        }
        return refinements;
    }


    /**
     * Returns the bias of {@code n} plus the absolute weight of its synapse from this node's input
     * neuron. NOTE(review): assumes such a synapse exists; a missing entry would throw.
     */
    @Override
    public double computeSynapseWeightSum(Neuron n) {
        return n.bias + Math.abs(n.inputSynapses.get(inputNeuron).w);
    }


    /**
     * Short textual form of this node. The prefix is {@code P} for positive nodes and {@code N}
     * otherwise; the input neuron's id and optional label follow in brackets, e.g. {@code P[12,foo]}.
     */
    public String logicToString() {
        StringBuilder out = new StringBuilder(this instanceof PositiveInputNode ? "P" : "N");
        out.append('[');
        if (inputNeuron != null) {
            out.append(inputNeuron.id);
            if (inputNeuron.label != null) {
                out.append(',').append(inputNeuron.label);
            }
        }
        return out.append(']').toString();
    }

}
