/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.structure;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.flat.FlatLayer;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.flat.FlatNetworkRBF;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ObjectPair;
import org.encog.mathutil.matrices.Matrix;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.layers.ContextLayer;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.layers.RadialBasisFunctionLayer;
import org.encog.neural.networks.logic.FeedforwardLogic;
import org.encog.neural.networks.logic.SimpleRecurrentLogic;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.structure.LayerComparator;
import org.encog.neural.networks.structure.NetworkCODEC;
import org.encog.neural.networks.structure.SynapseComparator;
import org.encog.neural.networks.structure.ValidateForFlat;
import org.encog.neural.networks.synapse.Synapse;
import org.encog.util.obj.ReflectionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Structural description of a {@link BasicNetwork}.
 *
 * <p>This class holds:
 * <ul>
 *   <li>the layers and synapses reachable from the network's tagged layers;</li>
 *   <li>the optional connection limit read from the {@code CONNECTION_LIMIT}
 *       network property;</li>
 *   <li>a transient {@link FlatNetwork} copy of the network. That copy is built
 *       only when the network passes {@link ValidateForFlat}.</li>
 * </ul>
 */
public class NeuralStructure implements Serializable {
    private static final long serialVersionUID = -2929683885395737817L;

    /** Class logger. Static fields are never serialized, so no {@code transient} modifier is needed. */
    private static final Logger LOGGER = LoggerFactory.getLogger(NeuralStructure.class);

    /** Layers reachable from the network's tagged layers; rebuilt by finalizeLayers and ordered by {@link #sort()}. */
    private final List<Layer> layers = new ArrayList<Layer>();

    /** Distinct outbound synapses of every layer in {@link #layers}; rebuilt by finalizeSynapses. */
    private final List<Synapse> synapses = new ArrayList<Synapse>();

    /** The network this structure describes. */
    private final BasicNetwork network;

    /** When {@link #connectionLimited}, weights whose absolute value is below this are zeroed by {@link #enforceLimit()}. */
    private double connectionLimit;

    /** True when the network defines a parseable {@code CONNECTION_LIMIT} property. */
    private boolean connectionLimited;

    /** Next layer ID handed out by {@link #getNextID()}. */
    private int nextID = 1;

    /** Flattened copy of the network, or {@code null} if it was not built or cannot be built. */
    private transient FlatNetwork flat;

    /**
     * Which direction, if any, weights must be copied between {@link #network} and {@link #flat}.
     * Being transient, this is {@code null} after deserialization; {@link #updateFlatNetwork()}
     * treats {@code null} as "flatten now".
     */
    private transient FlatUpdateNeeded flatUpdate;

    /**
     * Creates an empty structure for the given network.
     *
     * @param network the network this structure belongs to
     */
    public NeuralStructure(BasicNetwork network) {
        this.network = network;
        this.flatUpdate = FlatUpdateNeeded.None;
    }

    /**
     * Gives every layer that has no ID yet ({@code -1}) a fresh ID, then re-sorts
     * the layers and synapses.
     */
    public void assignID() {
        for (Layer layer : this.layers) {
            this.assignID(layer);
        }
        this.sort();
    }

    /**
     * Assigns the next free ID to {@code layer} if its ID is still {@code -1}.
     *
     * @param layer the layer to assign an ID to
     */
    public void assignID(Layer layer) {
        if (layer.getID() == -1) {
            layer.setID(this.getNextID());
        }
    }

    /**
     * @return the size reported by {@link NetworkCODEC#networkSize} for this network
     */
    public int calculateSize() {
        return NetworkCODEC.networkSize(this.network);
    }

    /**
     * Checks whether any layer's class satisfies {@link ReflectionUtil#isInstanceOf} against {@code type}.
     *
     * @param type the layer type to look for
     * @return true if at least one layer matches
     */
    public boolean containsLayerType(Class<?> type) {
        for (Layer layer : this.layers) {
            if (ReflectionUtil.isInstanceOf(layer.getClass(), type)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Counts the layers that are not context layers.
     *
     * <p>The test must match the one {@link #flatten()} uses to skip context layers
     * ({@code instanceof}). Otherwise the flat-layer array would be sized differently
     * from how it is filled.
     *
     * @return number of non-context layers
     */
    private int countNonContext() {
        int result = 0;
        for (Layer layer : this.getLayers()) {
            if (!(layer instanceof ContextLayer)) {
                ++result;
            }
        }
        return result;
    }

    /**
     * If a connection limit is active, zeroes every synapse weight whose absolute
     * value is below {@link #connectionLimit}. Synapses without a matrix are skipped.
     */
    public void enforceLimit() {
        if (!this.connectionLimited) {
            return;
        }
        for (Synapse synapse : this.synapses) {
            Matrix matrix = synapse.getMatrix();
            if (matrix == null) {
                continue;
            }
            for (int row = 0; row < matrix.getRows(); ++row) {
                for (int col = 0; col < matrix.getCols(); ++col) {
                    if (Math.abs(matrix.get(row, col)) < this.connectionLimit) {
                        matrix.set(row, col, 0.0);
                    }
                }
            }
        }
    }

    /**
     * Rebuilds {@link #layers} and then sorts.
     *
     * <p>The steps are:
     * <ol>
     *   <li>For {@link FeedforwardLogic} or {@link SimpleRecurrentLogic} (exact
     *       class match), clear the bias weights of the {@code INPUT} layer.</li>
     *   <li>Collect every layer reachable from the network's tagged layers by
     *       following outbound synapses.</li>
     *   <li>Advance {@link #nextID} past the highest existing layer ID.</li>
     * </ol>
     */
    private void finalizeLayers() {
        Class<?> logicClass = this.network.getLogic().getClass();
        if (logicClass == FeedforwardLogic.class || logicClass == SimpleRecurrentLogic.class) {
            // NOTE(review): throws NullPointerException if no layer is tagged "INPUT".
            Layer inputLayer = this.network.getLayer("INPUT");
            inputLayer.setBiasWeights(null);
        }
        ArrayList<Layer> result = new ArrayList<Layer>();
        this.layers.clear();
        for (Layer layer : this.network.getLayerTags().values()) {
            this.getLayers(result, layer);
        }
        this.layers.addAll(result);
        for (Layer layer : this.layers) {
            if (layer.getID() >= this.nextID) {
                this.nextID = layer.getID() + 1;
            }
        }
        this.sort();
    }

    /**
     * Reads the {@code CONNECTION_LIMIT} network property into {@link #connectionLimit}
     * and {@link #connectionLimited}.
     *
     * <p>The value is parsed before any field is changed. A malformed property
     * therefore leaves the previous limit settings untouched instead of marking
     * the structure as limited with a stale value.
     *
     * @throws NeuralNetworkError if the property is present but is not a valid double
     */
    private void finalizeLimit() {
        String limit = this.network.getPropertyString("CONNECTION_LIMIT");
        if (limit == null) {
            this.connectionLimited = false;
            this.connectionLimit = 0.0;
            return;
        }
        double parsed;
        try {
            parsed = Double.parseDouble(limit);
        } catch (NumberFormatException e) {
            throw new NeuralNetworkError("Invalid property(CONNECTION_LIMIT):" + limit);
        }
        this.connectionLimited = true;
        this.connectionLimit = parsed;
    }

    /**
     * Finalizes the structure after the network has been assembled.
     *
     * <p>The steps run in this order:
     * <ol>
     *   <li>Rebuild the layers and synapses.</li>
     *   <li>Read the connection limit.</li>
     *   <li>Sort the layers and assign missing IDs.</li>
     *   <li>Initialize the network logic.</li>
     *   <li>Apply the connection limit.</li>
     *   <li>Build the flat network.</li>
     * </ol>
     */
    public void finalizeStructure() {
        this.finalizeLayers();
        this.finalizeSynapses();
        this.finalizeLimit();
        Collections.sort(this.layers);
        this.assignID();
        this.network.getLogic().init(this.network);
        this.enforceLimit();
        this.flatten();
    }

    /**
     * Rebuilds {@link #synapses} from the outbound synapses of every layer, without
     * duplicates. The order comes from a {@link HashSet}; {@link #sort()} orders it later.
     */
    private void finalizeSynapses() {
        HashSet<Synapse> result = new HashSet<Synapse>();
        for (Layer layer : this.getLayers()) {
            for (Synapse synapse : layer.getNext()) {
                result.add(synapse);
            }
        }
        this.synapses.clear();
        this.synapses.addAll(result);
    }

    /**
     * Finds the bias activation to use for the layer that follows {@code layer}.
     *
     * @param layer the layer whose successor is inspected
     * @return the bias activation of the first {@link BasicLayer} that {@code layer}
     *         connects to, if that layer has a bias; otherwise {@code 0.0}
     */
    private double findNextBias(Layer layer) {
        if (layer.getNext().size() == 0) {
            return 0.0;
        }
        Synapse synapse = this.network.getStructure().findNextSynapseByLayerType(layer, BasicLayer.class);
        if (synapse == null) {
            return 0.0;
        }
        Layer nextLayer = synapse.getToLayer();
        return nextLayer.hasBias() ? nextLayer.getBiasActivation() : 0.0;
    }

    /**
     * Returns the first outbound synapse of {@code layer} whose target layer's class is
     * exactly {@code type}. Subclasses do not match.
     *
     * @param layer the source layer
     * @param type  the exact class of the target layer
     * @return the matching synapse, or {@code null}
     */
    public Synapse findNextSynapseByLayerType(Layer layer, Class<? extends Layer> type) {
        for (Synapse synapse : layer.getNext()) {
            if (synapse.getToLayer().getClass() == type) {
                return synapse;
            }
        }
        return null;
    }

    /**
     * Returns the first inbound synapse of {@code layer} whose source layer's class is
     * exactly {@code type}. Subclasses do not match.
     *
     * @param layer the target layer
     * @param type  the exact class of the source layer
     * @return the matching synapse, or {@code null}
     */
    public Synapse findPreviousSynapseByLayerType(Layer layer, Class<? extends Layer> type) {
        for (Synapse synapse : this.getPreviousSynapses(layer)) {
            if (synapse.getFromLayer().getClass() == type) {
                return synapse;
            }
        }
        return null;
    }

    /**
     * Finds the synapse that connects {@code fromLayer} to {@code toLayer}. Layers are
     * compared by identity.
     *
     * @param fromLayer the source layer
     * @param toLayer   the target layer
     * @param required  if true, a missing synapse is an error
     * @return the synapse, or {@code null} if none exists and {@code required} is false
     * @throws NeuralNetworkError if {@code required} is true and no synapse exists
     */
    public Synapse findSynapse(Layer fromLayer, Layer toLayer, boolean required) {
        Synapse result = null;
        for (Synapse synapse : this.getSynapses()) {
            if (synapse.getFromLayer() == fromLayer && synapse.getToLayer() == toLayer) {
                result = synapse;
                break;
            }
        }
        if (required && result == null) {
            String str = "This operation requires a network with a synapse between the " + this.nameLayer(fromLayer) + " layer to the " + this.nameLayer(toLayer) + " layer.";
            if (LOGGER.isErrorEnabled()) {
                LOGGER.error(str);
            }
            throw new NeuralNetworkError(str);
        }
        return result;
    }

    /**
     * Rebuilds {@link #flat} from the current layers.
     *
     * <ul>
     *   <li>If {@link ValidateForFlat#isValid} reports a problem (non-null result),
     *       {@link #flat} stays {@code null} and {@link #flatUpdate} becomes
     *       {@code Never}.</li>
     *   <li>A three-layer network whose middle layer is a
     *       {@link RadialBasisFunctionLayer} becomes a {@link FlatNetworkRBF}.</li>
     *   <li>Otherwise, every non-context layer becomes a {@link FlatLayer}. Context
     *       layers are wired up through {@link FlatLayer#setContextFedBy}, and each
     *       {@link ContextLayer} receives its offset into the flat network.</li>
     * </ul>
     *
     * @throws NeuralNetworkError if an RBF network uses bias, or if a context layer
     *         is not connected as required
     */
    public void flatten() {
        HashMap<Layer, FlatLayer> regular2flat = new HashMap<Layer, FlatLayer>();
        HashMap<FlatLayer, Layer> flat2regular = new HashMap<FlatLayer, Layer>();
        // Each pair is (layer feeding a context layer, layer the context layer feeds).
        ArrayList<ObjectPair<Layer, Layer>> contexts = new ArrayList<ObjectPair<Layer, Layer>>();
        this.flat = null;
        ValidateForFlat val = new ValidateForFlat();
        if (val.isValid(this.network) != null) {
            this.flatUpdate = FlatUpdateNeeded.Never;
            return;
        }

        if (this.layers.size() == 3 && this.layers.get(1) instanceof RadialBasisFunctionLayer) {
            RadialBasisFunctionLayer rbf = (RadialBasisFunctionLayer) this.layers.get(1);
            for (Layer layer : this.layers) {
                if (layer.hasBias()) {
                    throw new NeuralNetworkError("Bias cannot be used with an RBF neural network.");
                }
            }
            this.flat = new FlatNetworkRBF(this.network.getInputCount(), rbf.getNeuronCount(), this.network.getOutputCount(), rbf.getRadialBasisFunction());
            this.flattenWeights();
            this.flatUpdate = FlatUpdateNeeded.None;
            return;
        }

        int flatLayerCount = this.countNonContext();
        FlatLayer[] flatLayers = new FlatLayer[flatLayerCount];
        // Filled from the back, so flatLayers holds the non-context layers in reverse list order.
        int index = flatLayers.length - 1;
        for (Layer layer : this.layers) {
            if (layer instanceof ContextLayer) {
                Synapse inboundSynapse = this.network.getStructure().findPreviousSynapseByLayerType(layer, BasicLayer.class);
                Synapse outboundSynapse = this.network.getStructure().findNextSynapseByLayerType(layer, BasicLayer.class);
                if (inboundSynapse == null) {
                    throw new NeuralNetworkError("Context layer must be connected to by one BasicLayer.");
                }
                if (outboundSynapse == null) {
                    throw new NeuralNetworkError("Context layer must connect to by one BasicLayer.");
                }
                contexts.add(new ObjectPair<Layer, Layer>(inboundSynapse.getFromLayer(), outboundSynapse.getToLayer()));
                continue;
            }
            double bias = this.findNextBias(layer);
            ActivationFunction configured = layer.getActivationFunction();
            ActivationFunction activationType;
            double[] params;
            if (configured == null) {
                // Layers without an activation function are flattened as linear.
                activationType = new ActivationLinear();
                params = new double[]{1.0};
            } else {
                activationType = configured;
                params = configured.getParams();
            }
            FlatLayer flatLayer = new FlatLayer(activationType, layer.getNeuronCount(), bias, params);
            regular2flat.put(layer, flatLayer);
            flat2regular.put(flatLayer, layer);
            flatLayers[index--] = flatLayer;
        }

        // The BasicLayer that feeds the context's target layer is marked as fed by the context's source.
        for (ObjectPair<Layer, Layer> pair : contexts) {
            Layer target = pair.getB();
            Synapse feeder = this.network.getStructure().findPreviousSynapseByLayerType(target, BasicLayer.class);
            FlatLayer from = regular2flat.get(pair.getA());
            FlatLayer to = regular2flat.get(feeder.getFromLayer());
            to.setContextFedBy(from);
        }

        this.flat = new FlatNetwork(flatLayers);

        for (int i = 0; i < flatLayerCount; ++i) {
            FlatLayer fedBy = flatLayers[i].getContextFedBy();
            if (fedBy == null) {
                continue;
            }
            // NOTE(review): this assumes flatLayers[i + 1] is the layer the context feeds,
            // and that a context source is never set on the last flat layer. TODO confirm
            // against LayerComparator's ordering.
            Layer target = flat2regular.get(flatLayers[i + 1]);
            Synapse contextSynapse = this.findPreviousSynapseByLayerType(target, ContextLayer.class);
            if (contextSynapse == null) {
                throw new NeuralNetworkError("Can't find parent synapse to context layer.");
            }
            ContextLayer context = (ContextLayer) contextSynapse.getFromLayer();
            int fedByIndex = -1;
            for (int j = 0; j < flatLayerCount; ++j) {
                if (flatLayers[j] == fedBy) {
                    fedByIndex = j;
                    break;
                }
            }
            if (fedByIndex == -1) {
                throw new NeuralNetworkError("Can't find layer feeding context.");
            }
            context.setFlatContextIndex(this.flat.getContextTargetOffset()[fedByIndex]);
        }
        this.flattenWeights();
        this.flatUpdate = FlatUpdateNeeded.None;
    }

    /**
     * Copies the network's weights into {@link #flat}, if it exists.
     *
     * <p>This also copies each indexed {@link ContextLayer}'s context values into the
     * flat layer output, then applies or clears the connection limit on the flat
     * network.
     */
    public void flattenWeights() {
        if (this.flat == null) {
            return;
        }
        // NOTE(review): flatUpdate is set to Flatten only for the duration of the
        // networkToArray call. This is possibly a re-entrancy guard; its purpose is
        // not visible in this class.
        this.flatUpdate = FlatUpdateNeeded.Flatten;
        double[] targetWeights = this.flat.getWeights();
        double[] sourceWeights = NetworkCODEC.networkToArray(this.network);
        EngineArray.arrayCopy(sourceWeights, targetWeights);
        this.flatUpdate = FlatUpdateNeeded.None;
        for (Layer layer : this.layers) {
            if (!(layer instanceof ContextLayer)) {
                continue;
            }
            ContextLayer context = (ContextLayer) layer;
            if (context.getFlatContextIndex() == -1) {
                continue;
            }
            EngineArray.arrayCopy(context.getContext().getData(), 0, this.flat.getLayerOutput(), context.getFlatContextIndex(), context.getContext().size());
        }
        if (this.connectionLimited) {
            this.flat.setConnectionLimit(this.connectionLimit);
        } else {
            this.flat.clearConnectionLimit();
        }
    }

    /** @return the connection limit; meaningful only when {@link #isConnectionLimited()} */
    public double getConnectionLimit() {
        return this.connectionLimit;
    }

    /** @return the flattened network, or {@code null} if none has been built */
    public FlatNetwork getFlat() {
        return this.flat;
    }

    /** @return the pending flat-synchronization state (may be {@code null} after deserialization) */
    public FlatUpdateNeeded getFlatUpdate() {
        return this.flatUpdate;
    }

    /** @return the live, mutable list of layers held by this structure */
    public List<Layer> getLayers() {
        return this.layers;
    }

    /**
     * Adds {@code layer}, and every layer reachable from it through outbound synapses,
     * to {@code result}, skipping layers already present.
     *
     * @param result accumulator of visited layers
     * @param layer  the layer to start from
     */
    private void getLayers(List<Layer> result, Layer layer) {
        if (!result.contains(layer)) {
            result.add(layer);
        }
        for (Synapse synapse : layer.getNext()) {
            Layer nextLayer = synapse.getToLayer();
            if (!result.contains(nextLayer)) {
                this.getLayers(result, nextLayer);
            }
        }
    }

    /** @return the network this structure describes */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /** @return the next free layer ID; each call increments the counter */
    public int getNextID() {
        return this.nextID++;
    }

    /**
     * @param targetLayer the layer whose predecessors are wanted
     * @return the distinct layers with an outbound synapse to {@code targetLayer}
     */
    public Collection<Layer> getPreviousLayers(Layer targetLayer) {
        HashSet<Layer> result = new HashSet<Layer>();
        for (Layer layer : this.getLayers()) {
            for (Synapse synapse : layer.getNext()) {
                if (synapse.getToLayer() == targetLayer) {
                    result.add(synapse.getFromLayer());
                }
            }
        }
        return result;
    }

    /**
     * @param targetLayer the layer whose inbound synapses are wanted
     * @return the distinct synapses in {@link #synapses} that end at {@code targetLayer},
     *         in list order
     */
    public List<Synapse> getPreviousSynapses(Layer targetLayer) {
        ArrayList<Synapse> result = new ArrayList<Synapse>();
        for (Synapse synapse : this.synapses) {
            if (synapse.getToLayer() == targetLayer && !result.contains(synapse)) {
                result.add(synapse);
            }
        }
        return result;
    }

    /** @return the live, mutable list of synapses held by this structure */
    public List<Synapse> getSynapses() {
        return this.synapses;
    }

    /** @return true if a connection limit was read from the network properties */
    public boolean isConnectionLimited() {
        return this.connectionLimited;
    }

    /** @return true if any layer is a {@link ContextLayer} */
    public boolean isRecurrent() {
        for (Layer layer : this.getLayers()) {
            if (layer instanceof ContextLayer) {
                return true;
            }
        }
        return false;
    }

    /**
     * @param layer the layer to look up (by identity)
     * @return every tag under which the network registers {@code layer}; empty if none
     */
    public List<String> nameLayer(Layer layer) {
        ArrayList<String> result = new ArrayList<String>();
        for (Map.Entry<String, Layer> entry : this.network.getLayerTags().entrySet()) {
            if (entry.getValue() == layer) {
                result.add(entry.getKey());
            }
        }
        return result;
    }

    /** @param flatUpdate the new pending flat-synchronization state */
    public void setFlatUpdate(FlatUpdateNeeded flatUpdate) {
        this.flatUpdate = flatUpdate;
    }

    /** Orders the layers with {@link LayerComparator} and the synapses with {@link SynapseComparator}. */
    public void sort() {
        Collections.sort(this.layers, new LayerComparator(this));
        Collections.sort(this.synapses, new SynapseComparator(this));
    }

    /**
     * Copies the weights of {@link #flat}, if it exists, back into the network.
     *
     * <p>It also copies each indexed {@link ContextLayer}'s slice of the flat layer
     * output back into that layer's context.
     */
    public void unflattenWeights() {
        if (this.flat == null) {
            return;
        }
        double[] sourceWeights = this.flat.getWeights();
        NetworkCODEC.arrayToNetwork(sourceWeights, this.network);
        this.flatUpdate = FlatUpdateNeeded.None;
        for (Layer layer : this.layers) {
            if (!(layer instanceof ContextLayer)) {
                continue;
            }
            ContextLayer context = (ContextLayer) layer;
            if (context.getFlatContextIndex() == -1) {
                continue;
            }
            EngineArray.arrayCopy(this.flat.getLayerOutput(), context.getFlatContextIndex(), context.getContext().getData(), 0, context.getContext().size());
        }
    }

    /**
     * Performs whatever synchronization {@link #flatUpdate} requests.
     *
     * <ul>
     *   <li>{@code null} (e.g. after deserialization): flatten now.</li>
     *   <li>{@code Flatten}: copy the network into the flat copy.</li>
     *   <li>{@code Unflatten}: copy the flat copy back into the network.</li>
     *   <li>{@code None} and {@code Never}: do nothing.</li>
     * </ul>
     *
     * After a copy the state is reset to {@code None}.
     */
    public void updateFlatNetwork() {
        if (this.flatUpdate == null) {
            this.flattenWeights();
            this.flatUpdate = FlatUpdateNeeded.None;
        }
        switch (this.flatUpdate) {
            case Flatten: {
                this.flattenWeights();
                break;
            }
            case Unflatten: {
                this.unflattenWeights();
                break;
            }
            case None:
            case Never: {
                return;
            }
        }
        this.flatUpdate = FlatUpdateNeeded.None;
    }
}

