package org.aika.network;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.aika.corpus.Document;
import org.aika.corpus.Option;
import org.aika.network.neuron.Activation;
import org.aika.network.neuron.Neuron;
import org.aika.network.neuron.Node;
import org.aika.network.neuron.Synapse;
import org.aika.network.neuron.recurrent.ClockTerminationNode;
import org.aika.network.neuron.recurrent.RecurrentNeuron;
import org.aika.network.neuron.recurrent.RecurrentNode;
import org.aika.network.neuron.simple.SimpleNeuron;
import org.aika.network.neuron.simple.SimpleNeuron.Input;
import org.aika.network.neuron.simple.lattice.AndNode;
import org.aika.network.neuron.simple.lattice.NegativeInputNode;
import org.aika.utils.StringUtils;

import java.util.*;


/**
 *
 * @author Lukas Molzberger
 */
public class Model {

    public Set<Integer> trainingInterval;
    public Map<String, SimpleNeuron> labeledNeurons = new LinkedHashMap<>();
    public List<NegativeInputNode> negationNodes = new ArrayList<>();
    public List<ClockTerminationNode> clockTerminationNodes = new ArrayList<>();
    public Set<Neuron> publishedNeurons = new TreeSet<>();

    public Set<Node> allNodes = new TreeSet<>();

    public int numberOfPositions;

    public static Comparator<Activation> ACTIVATIONS_OUTPUT_COMPARATOR = new Comparator<Activation>() {
        @Override
        public int compare(Activation act1, Activation act2) {
            int r = act1.key.pos.compareTo(act2.key.pos);
            if(r != 0) return r;
            r = act1.key.o.compareTo(act2.key.o);
            if(r != 0) return r;
            r = Integer.compare(act1.key.fired, act2.key.fired);
            if(r != 0) return r;
            r = Integer.compare(act1.key.n.id, act2.key.n.id);
            return r;
        }
    };

    public Model() {
        trainingInterval = new HashSet<>();
        for(int i = 0; i < 10000; i++) {
            double x = i;
            double y = Math.pow(x, 1.3);
            trainingInterval.add((int) Math.floor(y));

            if(i < 10) {
                trainingInterval.add(i);
            }
        }
    }


    public Iteration startIteration(Document doc) {
        Iteration t = new Iteration(doc, this);

        t.changeNumberOfPositions(1);
        ClockTerminationNode.addInitialActivations(t);

        return t;
    }


    public void dump() {
        System.out.println();
        System.out.println("Network Weights:");
        System.out.println(networkWeightsToString());

        System.out.println();
        System.out.println();
        System.out.println("Pattern Lattice:");
        System.out.println(patternLatticeToString());

        System.out.println();

        System.out.println("Publication Candidates:");
        List<AndNode> results = new ArrayList<>();
        for(Node n: allNodes) {
            if(n instanceof AndNode) {
                AndNode an = (AndNode) n;
                if (an.weight > 0.99) {
                    results.add(an);
                }
            }
        }

        for(int i = 0; i < results.size(); i++) {
            AndNode n = results.get(i);
            String syl = StringUtils.extractSyllable(n);
            System.out.println(i + " \"" + syl + "\"  " + n.toString() + " ShouldBePublished:" + n.shouldBePublished + (!n.shouldBePublished && n.directSignificantAncestor != null ? " (" + " \"" + StringUtils.extractSyllable(n.directSignificantAncestor) + "\"  " + n.directSignificantAncestor.toString() + ")" : ""));
        }

        System.out.println();
    }


    public static boolean hasSignificantChildren(AndNode n) {
        for(AndNode cn: n.andChildren.values()) {
            if(cn.weight > 0.99) return true;
            else {
                if(hasSignificantChildren(cn)) return true;
            }
        }

        return false;
    }


    public void reset() {
        labeledNeurons = new LinkedHashMap<>();
        negationNodes = new ArrayList<>();
        publishedNeurons = new HashSet<>();
    }


    public void resetFrequency() {
        for(Node n: allNodes) {
            n.frequency = 0;
        }
    }


    /**
     *
     * @param doc
     * @param allOrOnlySelectedOptions true: all, false: only selected options
     * @return
     */
    public String networkStateToString(Document doc, boolean allOrOnlySelectedOptions, boolean withWeights) {
        Set<Activation> acts = new TreeSet<>(ACTIVATIONS_OUTPUT_COMPARATOR);

        for(Neuron n: publishedNeurons) {
            acts.addAll(Activation.select(n.node, null, null, null, allOrOnlySelectedOptions ? null : doc.selectedOption, Option.Relation.CONTAINED_IN, null, null, null, false));
        }
        StringBuilder sb = new StringBuilder();
        double weightSum = 0.0;
        for(Activation act: acts) {
            if(!allOrOnlySelectedOptions && act.newOption != null && !doc.selectedOption.contains(act.newOption)) {
                continue;
            }

            if(act.key.n.neuron != null && "SPACE".equals(act.key.n.neuron.label)) continue;

            sb.append(act.key.pos);
            sb.append(" - ");
            if (allOrOnlySelectedOptions) {
                sb.append(act.key.o.toString());
                if(act.newOption != null) {
                    sb.append(act.newOption.toString());
                }
                sb.append(" - ");
            }
            sb.append(act.key.n.toString());
            sb.append(" - Rid:");
            sb.append(act.key.rid);
            sb.append(" - Fired:");
            sb.append(act.key.fired);
            if(withWeights) {
                sb.append(" - W:");
                sb.append(act.weight);
            }
            weightSum += act.weight;
            sb.append("\n");
        }
        sb.append("\nWeightSum:" + weightSum + "\n");
        return sb.toString();
    }


    public String networkWeightsToString() {
        StringBuilder sb = new StringBuilder();
        for(Neuron n: publishedNeurons) {
            if(n.node.frequency > 0) {
                sb.append(n.toStringWithSynapses());
                sb.append("\n");
            }
        }
        return sb.toString();
    }


    public String patternLatticeToString() {
        StringBuilder sb = new StringBuilder();

        for(Node p: allNodes) {
            sb.append(p.toString());
            sb.append("\n");
        }
        return sb.toString();
    }


    public SimpleNeuron createOrLookupInputSignal(String label) {
        return createOrLookupInputSignal(label, false);
    }


    public SimpleNeuron createOrLookupInputSignal(String label, boolean isBlocked) {
        SimpleNeuron n = labeledNeurons.get(label);
        if(n == null) {
            n = SimpleNeuron.create(this, new SimpleNeuron(label, isBlocked), -0.5, new TreeSet<>(), true, false);
            labeledNeurons.put(label, n);
        }
        return n;
    }


    public SimpleNeuron createAndNeuron(SimpleNeuron n, boolean inferenceMode, Input... inputs) {
        return createAndNeuron(n, inferenceMode, new TreeSet<Input>(Arrays.asList(inputs)));
    }


    public SimpleNeuron createAndNeuron(SimpleNeuron n, boolean inferenceMode, Set<Input> inputs) {
        n.m = this;
        Set<Synapse> is = new TreeSet<>();

        double bias = 0.5;
        for(Input ni: inputs) {
            Synapse s = new Synapse(ni.inputNeuron, ni.rid, ni.relative);
            if(!ni.isOptional) {
                s.w = ni.isNeg ? -1.0f : 1.0f;
                bias -= 1;
            }
            is.add(s);
        }

        return SimpleNeuron.create(this, n, bias, is, true, inferenceMode);
    }


    public SimpleNeuron createOrNeuron(SimpleNeuron n, Input... inputs) {
        return createOrNeuron(n, new TreeSet<>(Arrays.asList(inputs)));
    }


    public SimpleNeuron createOrNeuron(SimpleNeuron n, Set<Input> inputs) {
        n.m = this;
        n.isPredefined = true;

        Set<Synapse> is = new TreeSet<>();

        double bias = -0.5;
        for(Input ni: inputs) {
            Synapse s = new Synapse(ni.inputNeuron, ni.rid, ni.relative);
            s.w = ni.isNeg ? -1.0f : 1.0f;
            is.add(s);
        }

        SimpleNeuron.create(this, n, bias, is, true, false);
        return n;
    }


    public RecurrentNeuron createRecurrentNeuron(RecurrentNeuron n, Neuron inputSignal, Neuron clockSignal, Neuron terminationSignal, boolean direction, int maxLength) {
        n.m = this;
        n.isPredefined = true;


        Synapse is = null;
        if(inputSignal != null) {
            is = new Synapse(inputSignal, RecurrentNode.RecurrentType.INPUT_SIGNAL);
            is.w = 1.0f;
        }

        Synapse csSynapse = null;
        if(clockSignal != null) {
            csSynapse = new Synapse(clockSignal, RecurrentNode.RecurrentType.CLOCK_SIGNAL);
            csSynapse.w = 1.0f;
        }

        Synapse tsSynapse = null;
        if(terminationSignal != null) {
            tsSynapse = new Synapse(terminationSignal, RecurrentNode.RecurrentType.TERMINATION_SIGNAL);
            tsSynapse.w = 1.0f;
        }

        RecurrentNeuron.create(this, n, is, csSynapse, tsSynapse, direction, maxLength, true);
        return n;
    }

}
