/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.structure;

import java.util.Arrays;
import java.util.List;
import org.encog.engine.util.EngineArray;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.layers.ContextLayer;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.synapse.Synapse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that copy the weights of a {@link BasicNetwork} into a flat
 * {@code double[]}, write such an array back into a network, and compare two
 * networks by their weight arrays.
 *
 * <p>Array layout produced and consumed by this class: layers are visited in the
 * order returned by {@code network.getStructure().getLayers()}. For each layer
 * that has a usable incoming synapse, and for each target neuron {@code x} of
 * that synapse, the array holds
 * <ol>
 *   <li>one weight per source neuron {@code y} of the synapse
 *       ({@code matrix.get(y, x)}),</li>
 *   <li>the bias weight of neuron {@code x}, if the target layer has bias,</li>
 *   <li>one weight per source neuron {@code z} of the context synapse, if the
 *       layer has one.</li>
 * </ol>
 *
 * <p>This class is not instantiable.
 */
public final class NetworkCODEC {
    private static final Logger LOGGER = LoggerFactory.getLogger(NetworkCODEC.class);

    /**
     * 2^63 as a double. Scaled values whose magnitude reaches this bound cannot be
     * converted to {@code long} without saturating.
     */
    private static final double LONG_RANGE_LIMIT = 9.223372036854776E18;

    /**
     * Writes the weights in {@code array} into {@code network}, layer by layer, using
     * the layout described on this class, and then sets the structure's flat-update
     * state to {@link FlatUpdateNeeded#Flatten}.
     *
     * @param array   the weights to store; must hold at least as many values as the
     *                network consumes, otherwise an
     *                {@link ArrayIndexOutOfBoundsException} is raised while copying
     * @param network the network whose synapse matrices and bias weights are overwritten
     */
    public static void arrayToNetwork(double[] array, BasicNetwork network) {
        int index = 0;
        for (Layer layer : network.getStructure().getLayers()) {
            index = processLayer(network, layer, array, index);
        }
        // NOTE(review): presumably marks the flat copy as stale so it is rebuilt
        // from the layers just written -- confirm against FlatUpdateNeeded.
        network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }

    /**
     * Compares two networks by their weight arrays, to a number of decimal places.
     *
     * <p>Each weight is multiplied by {@code 10^precision} and truncated toward zero
     * before comparison. Weights whose scaled value does not fit in a {@code long}
     * (including NaN and infinities) are instead compared exactly, so that distinct
     * large values are never reported as equal merely because the conversion to
     * {@code long} saturated.
     *
     * @param network1  the first network
     * @param network2  the second network
     * @param precision the number of decimal places to compare
     * @return {@code false} if the weight arrays differ in length or any pair of
     *         weights differs at the requested precision; {@code true} otherwise
     * @throws NeuralNetworkError if the arrays have equal length and
     *         {@code 10^precision} is infinite or larger than the {@code long} range
     */
    public static boolean equals(BasicNetwork network1, BasicNetwork network2, int precision) {
        double[] array1 = networkToArray(network1);
        double[] array2 = networkToArray(network2);
        // The length check deliberately precedes the precision check: networks of
        // different size are reported as unequal without validating the precision.
        if (array1.length != array2.length) {
            return false;
        }
        double scale = Math.pow(10.0, precision);
        if (Double.isInfinite(scale) || scale > LONG_RANGE_LIMIT) {
            String str = "Precision of " + precision + " decimal places is not supported.";
            LOGGER.error(str);
            throw new NeuralNetworkError(str);
        }
        for (int i = 0; i < array1.length; ++i) {
            if (!sameAtScale(array1[i], array2[i], scale)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compares two networks by their weight arrays using exact {@code double}
     * comparison as defined by {@link Arrays#equals(double[], double[])}.
     *
     * @param network1 the first network
     * @param network2 the second network
     * @return {@code true} if both weight arrays have the same length and contents
     */
    public static boolean equals(BasicNetwork network1, BasicNetwork network2) {
        // Arrays.equals already returns false for arrays of different length.
        return Arrays.equals(networkToArray(network1), networkToArray(network2));
    }

    /**
     * Returns the number of values {@link #networkToArray(BasicNetwork)} produces for
     * {@code network}.
     *
     * <p>If the structure's flat weights can be used (see
     * {@link #canUseFlatWeights(BasicNetwork)}), this is the length of that array;
     * otherwise it is counted from the layers' synapses.
     *
     * @param network the network to measure
     * @return the length of the network's weight array
     */
    public static int networkSize(BasicNetwork network) {
        if (canUseFlatWeights(network)) {
            return network.getStructure().getFlat().getWeights().length;
        }
        int size = 0;
        for (Layer layer : network.getStructure().getLayers()) {
            Synapse contextSynapse = network.getStructure().findPreviousSynapseByLayerType(layer, ContextLayer.class);
            Synapse synapse = findPrimarySynapse(network, layer, contextSynapse);
            if (synapse == null || synapse.getMatrix() == null) {
                continue;
            }
            // Every target neuron contributes the same number of values, so count
            // one neuron's share and multiply.
            int perNeuron = synapse.getFromNeuronCount();
            if (synapse.getToLayer().hasBias()) {
                ++perNeuron;
            }
            if (contextSynapse != null) {
                perNeuron += contextSynapse.getFromNeuronCount();
            }
            size += synapse.getToNeuronCount() * perNeuron;
        }
        return size;
    }

    /**
     * Returns the weights of {@code network} as a newly allocated array.
     *
     * <p>If the structure's flat weights can be used (see
     * {@link #canUseFlatWeights(BasicNetwork)}), a copy of that array is returned;
     * otherwise the array is assembled from the layers using the layout described on
     * this class.
     *
     * @param network the network to encode
     * @return a new array holding the network's weights
     */
    public static double[] networkToArray(BasicNetwork network) {
        if (canUseFlatWeights(network)) {
            return EngineArray.arrayCopy(network.getStructure().getFlat().getWeights());
        }
        double[] result = new double[networkSize(network)];
        int index = 0;
        for (Layer layer : network.getStructure().getLayers()) {
            Synapse contextSynapse = network.getStructure().findPreviousSynapseByLayerType(layer, ContextLayer.class);
            Synapse synapse = findPrimarySynapse(network, layer, contextSynapse);
            if (synapse == null || synapse.getMatrix() == null) {
                continue;
            }
            for (int x = 0; x < synapse.getToNeuronCount(); ++x) {
                for (int y = 0; y < synapse.getFromNeuronCount(); ++y) {
                    result[index++] = synapse.getMatrix().get(y, x);
                }
                if (synapse.getToLayer().hasBias()) {
                    result[index++] = synapse.getToLayer().getBiasWeights()[x];
                }
                if (contextSynapse != null) {
                    for (int z = 0; z < contextSynapse.getFromNeuronCount(); ++z) {
                        result[index++] = contextSynapse.getMatrix().get(z, x);
                    }
                }
            }
        }
        return result;
    }

    /**
     * Copies the weights for one layer out of {@code array} into the layer's incoming
     * synapse matrices and bias weights.
     *
     * <p>For context-synapse weights, the value read from the array is replaced by
     * {@code 0.0} when the absolute value of the weight currently stored in the
     * matrix is below {@code network.getStructure().getConnectionLimit()}; the array
     * position is consumed either way.
     *
     * @param network the network that owns {@code layer}
     * @param layer   the layer whose incoming weights are written
     * @param array   the source of the weights
     * @param index   the position in {@code array} of the first weight for this layer
     * @return the position in {@code array} following the last weight consumed;
     *         equal to {@code index} if the layer has no usable incoming synapse
     */
    private static int processLayer(BasicNetwork network, Layer layer, double[] array, int index) {
        int result = index;
        Synapse contextSynapse = network.getStructure().findPreviousSynapseByLayerType(layer, ContextLayer.class);
        Synapse synapse = findPrimarySynapse(network, layer, contextSynapse);
        if (synapse == null || synapse.getMatrix() == null) {
            return result;
        }
        for (int x = 0; x < synapse.getToNeuronCount(); ++x) {
            for (int y = 0; y < synapse.getFromNeuronCount(); ++y) {
                synapse.getMatrix().set(y, x, array[result++]);
            }
            if (synapse.getToLayer().hasBias()) {
                synapse.getToLayer().getBiasWeights()[x] = array[result++];
            }
            if (contextSynapse == null) {
                continue;
            }
            for (int z = 0; z < contextSynapse.getFromNeuronCount(); ++z) {
                double value = array[result++];
                double oldValue = contextSynapse.getMatrix().get(z, x);
                // A stored weight below the connection limit is kept at zero
                // instead of taking the incoming value.
                if (Math.abs(oldValue) < network.getStructure().getConnectionLimit()) {
                    value = 0.0;
                }
                contextSynapse.getMatrix().set(z, x, value);
            }
        }
        return result;
    }

    /**
     * Reports whether the structure's flat weights can be used directly: the
     * structure has a flat network and its flat-update state is
     * {@link FlatUpdateNeeded#None} or {@link FlatUpdateNeeded#Unflatten}.
     */
    private static boolean canUseFlatWeights(BasicNetwork network) {
        if (network.getStructure().getFlat() == null) {
            return false;
        }
        FlatUpdateNeeded state = network.getStructure().getFlatUpdate();
        return state == FlatUpdateNeeded.None || state == FlatUpdateNeeded.Unflatten;
    }

    /**
     * Finds the main incoming synapse of {@code layer}: the previous synapse found for
     * {@link BasicLayer}, or, when neither that nor {@code contextSynapse} exists, the
     * first of the layer's previous synapses. Returns {@code null} if none applies.
     */
    private static Synapse findPrimarySynapse(BasicNetwork network, Layer layer, Synapse contextSynapse) {
        Synapse synapse = network.getStructure().findPreviousSynapseByLayerType(layer, BasicLayer.class);
        if (synapse == null && contextSynapse == null) {
            List<Synapse> previous = network.getStructure().getPreviousSynapses(layer);
            if (previous.size() > 0) {
                synapse = previous.get(0);
            }
        }
        return synapse;
    }

    /**
     * Compares two weights after multiplying by {@code scale} and truncating toward
     * zero. Falls back to an exact comparison when either scaled value is NaN or lies
     * outside the open {@code long} range, where the conversion would saturate.
     */
    private static boolean sameAtScale(double a, double b, double scale) {
        double scaledA = a * scale;
        double scaledB = b * scale;
        if (fitsInLong(scaledA) && fitsInLong(scaledB)) {
            return (long) scaledA == (long) scaledB;
        }
        return Double.compare(a, b) == 0;
    }

    /** Returns whether {@code value} lies strictly inside the {@code long} range; false for NaN. */
    private static boolean fitsInLong(double value) {
        return value > -LONG_RANGE_LIMIT && value < LONG_RANGE_LIMIT;
    }

    private NetworkCODEC() {
    }
}

