/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.prune;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.encog.mathutil.matrices.Matrix;
import org.encog.mathutil.matrices.MatrixMath;
import org.encog.mathutil.randomize.Distort;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.synapse.Synapse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Selectively grows, shrinks and perturbs the neurons of individual layers of
 * a {@link BasicNetwork}.
 *
 * <p>Throughout this class a synapse matrix is indexed as
 * {@code (fromNeuron, toNeuron)}: a neuron of a layer owns one row in each of
 * the layer's outbound synapse matrices and one column in each of its inbound
 * synapse matrices. Synapses whose {@code getMatrix()} is {@code null} carry
 * no weights and are skipped everywhere.
 */
public class PruneSelective {
    private final BasicNetwork network;
    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Creates a pruning helper bound to the given network.
     *
     * @param network the network whose layers will be modified
     */
    public PruneSelective(BasicNetwork network) {
        this.network = network;
    }

    /**
     * Changes the number of neurons in a layer. When shrinking, the least
     * significant neurons (see {@link #determineNeuronSignificance}) are
     * pruned; when growing, existing weights are kept and the new entries are
     * left at the defaults of the freshly allocated matrices and bias array.
     *
     * @param layer       the layer to resize
     * @param neuronCount the desired neuron count, must be positive
     * @throws NeuralNetworkError if {@code neuronCount} is zero or negative
     */
    public void changeNeuronCount(Layer layer, int neuronCount) {
        if (neuronCount == 0) {
            throw new NeuralNetworkError("Can't decrease to zero neurons.");
        }
        if (neuronCount < 0) {
            throw new NeuralNetworkError("Can't change to a negative neuron count: " + neuronCount);
        }
        if (neuronCount == layer.getNeuronCount()) {
            return;
        }
        if (neuronCount > layer.getNeuronCount()) {
            this.increaseNeuronCount(layer, neuronCount);
        } else {
            this.decreaseNeuronCount(layer, neuronCount);
        }
    }

    /**
     * Shrinks a layer to {@code neuronCount} neurons by pruning its weakest
     * neurons.
     *
     * @param layer       the layer to shrink
     * @param neuronCount the target neuron count (smaller than the current one)
     */
    private void decreaseNeuronCount(Layer layer, int neuronCount) {
        int lostNeuronCount = layer.getNeuronCount() - neuronCount;
        int[] lostNeuron = this.findWeakestNeurons(layer, lostNeuronCount);
        // Prune from the highest index down: removing a neuron only shifts the
        // indices above it, so the remaining (lower) indices stay valid.
        Arrays.sort(lostNeuron);
        for (int i = lostNeuron.length - 1; i >= 0; --i) {
            this.prune(layer, lostNeuron[i]);
        }
    }

    /**
     * Computes a significance score for one neuron: the absolute value of the
     * sum of its bias weight (if the layer has bias), all of its outbound
     * weights and all of its inbound weights. A larger value means a more
     * significant neuron.
     *
     * <p>NOTE(review): the weights are summed with their sign before the
     * absolute value is taken, so weights of opposite sign cancel each other
     * out. Confirm this is the intended measure.
     *
     * @param layer  the layer that contains the neuron
     * @param neuron the index of the neuron within the layer
     * @return the non-negative significance score
     */
    public double determineNeuronSignificance(Layer layer, int neuron) {
        double result = 0.0;
        if (layer.hasBias()) {
            result += layer.getBiasWeight(neuron);
        }
        // Outbound weights: the neuron's row in each outbound matrix.
        for (Synapse synapse : layer.getNext()) {
            Matrix matrix = synapse.getMatrix();
            if (matrix == null) {
                continue;
            }
            for (int i = 0; i < synapse.getToNeuronCount(); ++i) {
                result += matrix.get(neuron, i);
            }
        }
        // Inbound weights: the neuron's column in each inbound matrix.
        List<Synapse> inboundSynapses = this.network.getStructure().getPreviousSynapses(layer);
        for (Synapse synapse : inboundSynapses) {
            Matrix matrix = synapse.getMatrix();
            if (matrix == null) {
                continue;
            }
            for (int i = 0; i < synapse.getFromNeuronCount(); ++i) {
                result += matrix.get(i, neuron);
            }
        }
        return Math.abs(result);
    }

    /**
     * Finds the {@code count} least significant neurons of a layer.
     *
     * @param layer the layer to inspect
     * @param count how many neurons to select
     * @return the indices of the selected neurons, ordered from least to most
     *         significant; ties are resolved in favour of the lower index
     * @throws NeuralNetworkError if {@code count} is negative or exceeds the
     *                            layer's neuron count
     */
    private int[] findWeakestNeurons(Layer layer, int count) {
        int total = layer.getNeuronCount();
        if (count < 0 || count > total) {
            throw new NeuralNetworkError(
                    "Can't select " + count + " neurons from a layer with " + total + " neurons.");
        }

        // Score every neuron once up front.
        double[] significance = new double[total];
        for (int i = 0; i < total; ++i) {
            significance[i] = this.determineNeuronSignificance(layer, i);
        }

        // Repeatedly pick the least significant neuron not yet selected.
        int[] weakest = new int[count];
        boolean[] selected = new boolean[total];
        for (int k = 0; k < count; ++k) {
            int best = -1;
            for (int i = 0; i < total; ++i) {
                if (selected[i]) {
                    continue;
                }
                if (best < 0 || significance[i] < significance[best]) {
                    best = i;
                }
            }
            selected[best] = true;
            weakest[k] = best;
        }
        return weakest;
    }

    /**
     * @return the network this object operates on
     */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * Grows a layer to {@code neuronCount} neurons. Outbound matrices gain
     * rows, inbound matrices gain columns and the bias array (if any) gains
     * entries; all existing values are copied over unchanged.
     *
     * @param layer       the layer to grow
     * @param neuronCount the target neuron count (larger than the current one)
     */
    private void increaseNeuronCount(Layer layer, int neuronCount) {
        // Outbound synapses: one extra row per new neuron.
        for (Synapse synapse : layer.getNext()) {
            Matrix oldMatrix = synapse.getMatrix();
            if (oldMatrix == null) {
                continue;
            }
            Matrix newMatrix = new Matrix(neuronCount, synapse.getToNeuronCount());
            for (int row = 0; row < layer.getNeuronCount(); ++row) {
                for (int col = 0; col < synapse.getToNeuronCount(); ++col) {
                    newMatrix.set(row, col, oldMatrix.get(row, col));
                }
            }
            synapse.setMatrix(newMatrix);
        }

        // Inbound synapses: one extra column per new neuron.
        List<Synapse> inboundSynapses = this.network.getStructure().getPreviousSynapses(layer);
        for (Synapse synapse : inboundSynapses) {
            Matrix oldMatrix = synapse.getMatrix();
            if (oldMatrix == null) {
                continue;
            }
            Matrix newMatrix = new Matrix(synapse.getFromNeuronCount(), neuronCount);
            for (int row = 0; row < synapse.getFromNeuronCount(); ++row) {
                for (int col = 0; col < synapse.getToNeuronCount(); ++col) {
                    newMatrix.set(row, col, oldMatrix.get(row, col));
                }
            }
            synapse.setMatrix(newMatrix);
        }

        // Bias weights: copy the existing ones into a larger array.
        if (layer.hasBias()) {
            double[] newBias = new double[neuronCount];
            for (int i = 0; i < layer.getNeuronCount(); ++i) {
                newBias[i] = layer.getBiasWeight(i);
            }
            layer.setBiasWeights(newBias);
        }

        // Update the count last: the loops above rely on the old count.
        layer.setNeuronCount(neuronCount);
    }

    /**
     * Removes a single neuron from a layer: its row is deleted from every
     * outbound synapse matrix, its column from every synapse matrix that
     * feeds {@code targetLayer}, its bias weight (if the layer has bias) is
     * dropped, and the layer's neuron count is decremented.
     *
     * @param targetLayer the layer to remove the neuron from
     * @param neuron      the index of the neuron to remove
     */
    public void prune(Layer targetLayer, int neuron) {
        // Outbound synapses: drop the neuron's row.
        for (Synapse synapse : targetLayer.getNext()) {
            Matrix matrix = synapse.getMatrix();
            if (matrix == null) {
                continue;
            }
            synapse.setMatrix(MatrixMath.deleteRow(matrix, neuron));
        }

        // Inbound synapses: drop the neuron's column. Only synapses that
        // actually lead into targetLayer are touched; other outbound synapses
        // of the previous layers must keep their shape.
        List<Synapse> inboundSynapses = this.network.getStructure().getPreviousSynapses(targetLayer);
        for (Synapse synapse : inboundSynapses) {
            Matrix matrix = synapse.getMatrix();
            if (matrix == null) {
                continue;
            }
            synapse.setMatrix(MatrixMath.deleteCol(matrix, neuron));
        }

        // Bias weights: copy every entry except the pruned neuron's.
        if (targetLayer.hasBias()) {
            double[] newBias = new double[targetLayer.getNeuronCount() - 1];
            int targetIndex = 0;
            for (int i = 0; i < targetLayer.getNeuronCount(); ++i) {
                if (i == neuron) {
                    continue;
                }
                newBias[targetIndex++] = targetLayer.getBiasWeight(i);
            }
            targetLayer.setBiasWeights(newBias);
        }

        targetLayer.setNeuronCount(targetLayer.getNeuronCount() - 1);
    }

    /**
     * Perturbs every weight attached to one neuron (bias, outbound and
     * inbound) by passing it through a {@link Distort} randomizer created
     * with {@code percent}.
     *
     * @param percent the value handed to the {@link Distort} constructor
     * @param layer   the layer that contains the neuron
     * @param neuron  the index of the neuron to stimulate
     */
    public void stimulateNeuron(double percent, Layer layer, int neuron) {
        Distort d = new Distort(percent);
        if (layer.hasBias()) {
            layer.setBiasWeight(neuron, d.randomize(layer.getBiasWeight(neuron)));
        }
        // Outbound weights: the neuron's row.
        for (Synapse synapse : layer.getNext()) {
            Matrix matrix = synapse.getMatrix();
            if (matrix == null) {
                continue;
            }
            for (int i = 0; i < synapse.getToNeuronCount(); ++i) {
                double v = matrix.get(neuron, i);
                matrix.set(neuron, i, d.randomize(v));
            }
        }
        // Inbound weights: the neuron's column.
        List<Synapse> inboundSynapses = this.network.getStructure().getPreviousSynapses(layer);
        for (Synapse synapse : inboundSynapses) {
            Matrix matrix = synapse.getMatrix();
            if (matrix == null) {
                continue;
            }
            for (int i = 0; i < synapse.getFromNeuronCount(); ++i) {
                double v = matrix.get(i, neuron);
                matrix.set(i, neuron, d.randomize(v));
            }
        }
    }

    /**
     * Stimulates the {@code count} least significant neurons of a layer.
     *
     * @param layer   the layer to work on
     * @param count   how many of the weakest neurons to stimulate
     * @param percent the value handed to {@link #stimulateNeuron}
     * @throws NeuralNetworkError if {@code count} is negative or exceeds the
     *                            layer's neuron count
     */
    public void stimulateWeakNeurons(Layer layer, int count, double percent) {
        int[] weak = this.findWeakestNeurons(layer, count);
        for (int element : weak) {
            this.stimulateNeuron(percent, layer, element);
        }
    }
}

