/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.competitive;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.encog.engine.util.Format;
import org.encog.mathutil.matrices.Matrix;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.synapse.Synapse;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.competitive.BestMatchingUnit;
import org.encog.neural.networks.training.competitive.neighborhood.NeighborhoodFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Competitive (self-organizing map style) trainer for a {@link BasicNetwork}.
 *
 * <p>For every training pattern the best matching unit (BMU) among the output
 * neurons is determined, and the weights of each output neuron are moved
 * towards the pattern by an amount scaled by the learning rate and by the
 * {@link NeighborhoodFunction} evaluated for that neuron and the BMU. The new
 * weights are staged in a per-synapse correction matrix and copied into the
 * synapse afterwards.
 *
 * <p>The learning rate and the neighborhood radius can be reduced over time
 * with {@link #decay(double)}, {@link #decay(double, double)} or
 * {@link #setAutoDecay(int, double, double, double, double)} together with
 * {@link #autoDecay()}.
 */
public class CompetitiveTraining extends BasicTraining implements LearningRate {
    /** Neighborhood function used to scale the update of each output neuron. */
    private final NeighborhoodFunction neighborhood;

    /** Current learning rate. */
    private double learningRate;

    /** Network being trained. */
    private final BasicNetwork network;

    /** Layer registered under the tag "INPUT". */
    private final Layer inputLayer;

    /** Layer registered under the tag "OUTPUT". */
    private final Layer outputLayer;

    /** Synapses that feed the output layer. */
    private final Collection<Synapse> synapses;

    /** Neuron count of the input layer. */
    private final int inputNeuronCount;

    /** Neuron count of the output layer. */
    private final int outputNeuronCount;

    /** Helper that finds the best matching unit and tracks the worst distance. */
    private final BestMatchingUnit bmuUtil;

    /** Staged new weights, one matrix per synapse, sized like the synapse matrix. */
    private final Map<Synapse, Matrix> correctionMatrix = new HashMap<Synapse, Matrix>();

    /** Whether output neurons that never win are forced to take over a pattern. */
    private boolean forceWinner;

    /** Learning rate at the start of an auto-decay schedule. */
    private double startRate;

    /** Learning rate at which auto-decay stops. */
    private double endRate;

    /** Radius at the start of an auto-decay schedule. */
    private double startRadius;

    /** Radius at which auto-decay stops. */
    private double endRadius;

    /** Amount added to the learning rate by each {@link #autoDecay()} call. */
    private double autoDecayRate;

    /** Amount added to the radius by each {@link #autoDecay()} call. */
    private double autoDecayRadius;

    /** Logger named after the runtime class of this trainer. */
    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Current neighborhood radius as tracked by this trainer. */
    private double radius;

    /**
     * Creates a trainer for the given network.
     *
     * @param network      the network to train; its "INPUT" and "OUTPUT"
     *                     layers are looked up by tag
     * @param learningRate the initial learning rate
     * @param training     the training set
     * @param neighborhood the neighborhood function to use
     */
    public CompetitiveTraining(BasicNetwork network, double learningRate, NeuralDataSet training, NeighborhoodFunction neighborhood) {
        this.neighborhood = neighborhood;
        this.setTraining(training);
        this.learningRate = learningRate;
        this.network = network;
        this.inputLayer = network.getLayer("INPUT");
        this.outputLayer = network.getLayer("OUTPUT");
        this.synapses = network.getStructure().getPreviousSynapses(this.outputLayer);
        this.inputNeuronCount = this.inputLayer.getNeuronCount();
        this.outputNeuronCount = this.outputLayer.getNeuronCount();
        this.forceWinner = false;
        this.setError(0.0);
        // One staging matrix per synapse, with the same dimensions as the
        // synapse's own weight matrix.
        for (Synapse synapse : this.synapses) {
            Matrix weights = synapse.getMatrix();
            this.correctionMatrix.put(synapse, new Matrix(weights.getRows(), weights.getCols()));
        }
        this.bmuUtil = new BestMatchingUnit(this);
    }

    /**
     * Copies every staged correction matrix into its synapse and marks the
     * network structure as needing a flatten.
     */
    private void applyCorrection() {
        for (Map.Entry<Synapse, Matrix> entry : this.correctionMatrix.entrySet()) {
            entry.getKey().getMatrix().set(entry.getValue());
        }
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }

    /**
     * Copies the staged correction matrix of a single synapse into that
     * synapse and marks the network structure as needing a flatten.
     *
     * <p>Used by {@link #iteration()} so that finishing one synapse does not
     * overwrite other synapses with correction matrices that have not been
     * recomputed in the current pass.
     *
     * @param synapse the synapse whose staged weights should be applied
     */
    private void applyCorrection(Synapse synapse) {
        synapse.getMatrix().set(this.correctionMatrix.get(synapse));
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }

    /**
     * Advances the schedule configured by
     * {@link #setAutoDecay(int, double, double, double, double)} by one step.
     *
     * <p>The radius and the learning rate are each changed by their per-step
     * amount while they are still above their end value, and are clamped so
     * that a step never takes them below that end value. The resulting radius
     * is pushed to the neighborhood function.
     */
    public void autoDecay() {
        if (this.radius > this.endRadius) {
            // Clamp: the bound is checked before the step, so without this
            // the last step could undershoot the configured end value.
            this.radius = Math.max(this.endRadius, this.radius + this.autoDecayRadius);
        }
        if (this.learningRate > this.endRate) {
            this.learningRate = Math.max(this.endRate, this.learningRate + this.autoDecayRate);
        }
        this.getNeighborhood().setRadius(this.radius);
    }

    /**
     * Writes the input pattern directly into the synapse weights that lead to
     * one output neuron.
     *
     * @param synapse      the synapse whose weights are overwritten
     * @param outputNeuron the output neuron (matrix column) to overwrite
     * @param input        the pattern to copy
     */
    private void copyInputPattern(Synapse synapse, int outputNeuron, NeuralData input) {
        Matrix weights = synapse.getMatrix();
        for (int inputNeuron = 0; inputNeuron < this.inputNeuronCount; ++inputNeuron) {
            weights.set(inputNeuron, outputNeuron, input.getData(inputNeuron));
        }
    }

    /**
     * Scales both the radius and the learning rate by {@code 1 - d} and pushes
     * the new radius to the neighborhood function.
     *
     * @param d the fraction to remove from both values
     */
    public void decay(double d) {
        this.decay(d, d);
    }

    /**
     * Scales the learning rate by {@code 1 - decayRate} and the radius by
     * {@code 1 - decayRadius}, then pushes the new radius to the neighborhood
     * function.
     *
     * @param decayRate   the fraction to remove from the learning rate
     * @param decayRadius the fraction to remove from the radius
     */
    public void decay(double decayRate, double decayRadius) {
        this.radius *= 1.0 - decayRadius;
        this.learningRate *= 1.0 - decayRate;
        this.getNeighborhood().setRadius(this.radius);
    }

    /**
     * Computes the updated weight for one connection:
     * {@code weight + neighborhood(currentNeuron, bmu) * learningRate * (input - weight)}.
     *
     * @param weight        the current weight
     * @param input         the input value on this connection
     * @param currentNeuron the output neuron being updated
     * @param bmu           the best matching unit for the pattern
     * @return the new weight
     */
    private double determineNewWeight(double weight, double input, int currentNeuron, int bmu) {
        return weight + this.neighborhood.function(currentNeuron, bmu) * this.learningRate * (input - weight);
    }

    /**
     * Gives the least represented pattern to an output neuron that has not won
     * any pattern.
     *
     * <p>Among the neurons whose win count is zero, the one with the highest
     * network output for {@code leastRepresented} is selected and the pattern
     * is copied into its weights.
     *
     * @param synapse          the synapse whose weights may be overwritten
     * @param won              win count per output neuron
     * @param leastRepresented the pattern to assign; must not be {@code null}
     * @return {@code true} if a neuron was forced to win, {@code false} if
     *         every neuron already has at least one win
     */
    private boolean forceWinners(Synapse synapse, int[] won, NeuralData leastRepresented) {
        double maxActivation = Double.NEGATIVE_INFINITY;
        int maxActivationNeuron = -1;
        NeuralData output = this.network.compute(leastRepresented);
        for (int outputNeuron = 0; outputNeuron < won.length; ++outputNeuron) {
            if (won[outputNeuron] != 0) {
                continue;
            }
            double activation = output.getData(outputNeuron);
            // The first neuron without a win is always accepted; later ones
            // only replace it when their activation is strictly greater.
            if (maxActivationNeuron == -1 || activation > maxActivation) {
                maxActivation = activation;
                maxActivationNeuron = outputNeuron;
            }
        }
        if (maxActivationNeuron == -1) {
            return false;
        }
        this.copyInputPattern(synapse, maxActivationNeuron, leastRepresented);
        return true;
    }

    /**
     * @return the neuron count of the input layer
     */
    public int getInputNeuronCount() {
        return this.inputNeuronCount;
    }

    /**
     * @return the current learning rate
     */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * @return the neighborhood function in use
     */
    public NeighborhoodFunction getNeighborhood() {
        return this.neighborhood;
    }

    /**
     * @return the network being trained
     */
    @Override
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * @return the neuron count of the output layer
     */
    public int getOutputNeuronCount() {
        return this.outputNeuronCount;
    }

    /**
     * @return whether neurons without a win are forced to take over a pattern
     */
    public boolean isForceWinner() {
        return this.forceWinner;
    }

    /**
     * Performs one pass over the training set for every synapse feeding the
     * output layer.
     *
     * <p>For each synapse the correction matrix is cleared, every training
     * pattern is run through {@code train}, and then either a non-winning
     * neuron is forced to take the least represented pattern (when
     * {@code forceWinner} is set and such a neuron exists) or the staged
     * correction for that synapse is applied. If the training set yields no
     * patterns, the synapse weights are left untouched. The error is set to
     * the worst BMU distance seen during the pass.
     */
    @Override
    public void iteration() {
        if (this.logger.isInfoEnabled()) {
            this.logger.info("Performing Competitive Training iteration.");
        }
        this.preIteration();
        this.bmuUtil.reset();
        int[] won = new int[this.outputNeuronCount];
        double leastRepresentedActivation = Double.MAX_VALUE;
        NeuralData leastRepresented = null;
        for (Synapse synapse : this.synapses) {
            Matrix correction = this.correctionMatrix.get(synapse);
            correction.clear();
            boolean trained = false;
            for (NeuralDataPair pair : this.getTraining()) {
                NeuralData input = pair.getInput();
                int bmu = this.bmuUtil.calculateBMU(synapse, input);
                if (this.forceWinner) {
                    won[bmu]++;
                    // Remember the pattern whose winning neuron had the
                    // lowest output so far.
                    double activation = this.network.compute(input).getData(bmu);
                    if (activation < leastRepresentedActivation) {
                        leastRepresentedActivation = activation;
                        leastRepresented = input;
                    }
                }
                this.train(bmu, synapse, input);
                trained = true;
            }
            if (!trained) {
                // Nothing was staged; applying the cleared correction matrix
                // would overwrite the existing weights.
                continue;
            }
            // Only attempt to force a winner when a candidate pattern exists;
            // otherwise compute() would be handed null.
            boolean forced = this.forceWinner
                    && leastRepresented != null
                    && this.forceWinners(synapse, won, leastRepresented);
            if (!forced) {
                this.applyCorrection(synapse);
            }
        }
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
        this.setError(this.bmuUtil.getWorstDistance());
        this.postIteration();
    }

    /**
     * Configures a linear schedule for {@link #autoDecay()} and resets the
     * learning rate and radius to their start values.
     *
     * <p>The per-step change is {@code (end - start) / plannedIterations} for
     * both values. No validation is performed; {@code plannedIterations} is
     * expected to be positive.
     *
     * @param plannedIterations number of {@link #autoDecay()} calls over which
     *                          the values should move from start to end
     * @param startRate         initial learning rate
     * @param endRate           learning rate at which decay stops
     * @param startRadius       initial radius
     * @param endRadius         radius at which decay stops
     */
    public void setAutoDecay(int plannedIterations, double startRate, double endRate, double startRadius, double endRadius) {
        this.startRate = startRate;
        this.endRate = endRate;
        this.startRadius = startRadius;
        this.endRadius = endRadius;
        this.autoDecayRadius = (endRadius - startRadius) / (double)plannedIterations;
        this.autoDecayRate = (endRate - startRate) / (double)plannedIterations;
        this.setParams(this.startRate, this.startRadius);
    }

    /**
     * @param forceWinner whether neurons without a win should be forced to
     *                    take over a pattern during {@link #iteration()}
     */
    public void setForceWinner(boolean forceWinner) {
        this.forceWinner = forceWinner;
    }

    /**
     * @param rate the new learning rate
     */
    @Override
    public void setLearningRate(double rate) {
        this.learningRate = rate;
    }

    /**
     * Sets the learning rate and the radius, and pushes the radius to the
     * neighborhood function.
     *
     * @param rate   the new learning rate
     * @param radius the new radius
     */
    public void setParams(double rate, double radius) {
        this.radius = radius;
        this.learningRate = rate;
        this.getNeighborhood().setRadius(radius);
    }

    /**
     * @return a summary of the form {@code Rate=<percent>, Radius=<value>}
     */
    @Override
    public String toString() {
        StringBuilder result = new StringBuilder();
        result.append("Rate=");
        result.append(Format.formatPercent(this.learningRate));
        result.append(", Radius=");
        result.append(Format.formatDouble(this.radius, 2));
        return result.toString();
    }

    /**
     * Stages new weights for every output neuron for one pattern.
     *
     * @param bmu     the best matching unit for the pattern
     * @param synapse the synapse being trained
     * @param input   the pattern
     */
    private void train(int bmu, Synapse synapse, NeuralData input) {
        for (int outputNeuron = 0; outputNeuron < this.outputNeuronCount; ++outputNeuron) {
            this.trainPattern(synapse, input, outputNeuron, bmu);
        }
    }

    /**
     * Trains every synapse on a single pattern and applies the result
     * immediately.
     *
     * @param pattern the pattern to train on
     */
    public void trainPattern(NeuralData pattern) {
        for (Synapse synapse : this.synapses) {
            int bmu = this.bmuUtil.calculateBMU(synapse, pattern);
            this.train(bmu, synapse, pattern);
        }
        this.applyCorrection();
    }

    /**
     * Stages the new weights between every input neuron and one output neuron
     * in the synapse's correction matrix.
     *
     * <p>NOTE(review): each call overwrites the staged entries using the
     * synapse's current (not yet updated) weights, so after a full
     * {@link #iteration()} pass the staged values are those produced by the
     * last pattern processed. Confirm whether accumulation across patterns
     * was intended.
     *
     * @param synapse the synapse being trained
     * @param input   the pattern
     * @param current the output neuron being updated
     * @param bmu     the best matching unit for the pattern
     */
    private void trainPattern(Synapse synapse, NeuralData input, int current, int bmu) {
        Matrix correction = this.correctionMatrix.get(synapse);
        Matrix weights = synapse.getMatrix();
        for (int inputNeuron = 0; inputNeuron < this.inputNeuronCount; ++inputNeuron) {
            double currentWeight = weights.get(inputNeuron, current);
            double inputValue = input.getData(inputNeuron);
            correction.set(inputNeuron, current,
                    this.determineNewWeight(currentWeight, inputValue, current, bmu));
        }
    }
}

