/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.pso;

import java.util.Arrays;
import org.encog.mathutil.VectorAlgebra;
import org.encog.mathutil.randomize.NguyenWidrowRandomizer;
import org.encog.mathutil.randomize.Randomizer;
import org.encog.ml.CalculateScore;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.structure.NetworkCODEC;
import org.encog.neural.networks.training.TrainingSetScore;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.pso.NeuralPSOWorker;
import org.encog.util.concurrency.EngineConcurrency;
import org.encog.util.concurrency.TaskGroup;

/**
 * Particle swarm trainer for a {@link BasicNetwork}.
 *
 * <p>Every particle is a network whose weights, flattened through
 * {@link NetworkCODEC}, form the particle's position. Each iteration updates
 * every particle's velocity and position, re-scores it with the supplied
 * {@link CalculateScore}, and then copies the best personal-best vector of the
 * swarm into the network returned by {@link #getMethod()}.
 *
 * <p>Per-particle storage ({@code m_networks}, {@code m_bestVectors},
 * {@code m_bestErrors}, {@code m_velocities}) is grown on demand whenever the
 * population size or the initial population changes, so indices
 * {@code 0 .. m_populationSize - 1} are always addressable.
 */
public class NeuralPSO extends BasicTraining {
    /** When true, non-initialisation passes submit particle updates to {@link EngineConcurrency}. */
    protected boolean m_multiThreaded = true;
    /** Helper used for all in-place vector arithmetic. */
    protected VectorAlgebra m_va;
    /** Scores a particle's network; also decides whether lower or higher is better. */
    protected CalculateScore m_calculateScore;
    /** Randomises the weights of every freshly cloned particle except particle 0. */
    protected Randomizer m_randomizer;
    /** One network per particle; null entries are filled in during initialisation. */
    protected BasicNetwork[] m_networks;
    /** Per-particle velocity vectors; null until the swarm has been initialised. */
    protected double[][] m_velocities;
    /** Per-particle personal-best positions. */
    protected double[][] m_bestVectors;
    /** Per-particle personal-best scores; 0.0 is treated as "not yet set". */
    protected double[] m_bestErrors;
    /** Index of the particle holding the global best, or -1 before the first evaluation. */
    protected int m_bestVectorIndex;
    /** Copy of the global best position, refreshed after every swarm pass. */
    private double[] m_bestVector;
    /** The network being trained; receives the global best weights. */
    BasicNetwork m_bestNetwork = null;
    /** Number of particles; always overwritten by the constructor argument. */
    protected int m_populationSize = 30;
    /** Bound passed to the position clamp. NOTE(review): -1 presumably disables clamping; confirm in VectorAlgebra. */
    protected double m_maxPosition = -1.0;
    /** Bound passed to the velocity clamp and to the initial velocity randomisation. */
    protected double m_maxVelocity = 2.0;
    /** Weight of the pull towards a particle's personal best. */
    protected double m_c1 = 2.0;
    /** Weight of the pull towards the global best. */
    protected double m_c2 = 2.0;
    /** Factor applied to the previous velocity before the pulls are added. */
    protected double m_inertiaWeight = 0.4;
    /**
     * Selects which global-best vector the velocity update reads. No setter is
     * visible in this class, so it stays false and the snapshot
     * {@code m_bestVector} is used.
     */
    private boolean m_pseudoAsynchronousUpdate = false;

    /**
     * Creates a trainer with an explicit randomizer, score function and swarm size.
     *
     * @param network        network to train; also the template cloned for each particle
     * @param randomizer     randomizer applied to every new particle except particle 0
     * @param calculateScore score function used to evaluate particles
     * @param populationSize number of particles
     */
    public NeuralPSO(BasicNetwork network, Randomizer randomizer, CalculateScore calculateScore, int populationSize) {
        super(TrainingImplementationType.Iterative);
        this.m_populationSize = populationSize;
        this.m_randomizer = randomizer;
        this.m_calculateScore = calculateScore;
        this.m_bestNetwork = network;
        this.m_networks = new BasicNetwork[this.m_populationSize];
        // Velocities are allocated lazily; null doubles as the "not initialised" marker.
        this.m_velocities = null;
        this.m_bestVectors = new double[this.m_populationSize][];
        this.m_bestErrors = new double[this.m_populationSize];
        this.m_bestVectorIndex = -1;
        this.m_bestVector = NetworkCODEC.networkToArray(this.m_bestNetwork);
        this.m_va = new VectorAlgebra();
    }

    /**
     * Creates a trainer with a {@link NguyenWidrowRandomizer}, a
     * {@link TrainingSetScore} over {@code trainingSet} and 20 particles.
     *
     * @param network     network to train
     * @param trainingSet data used to score particles
     */
    public NeuralPSO(BasicNetwork network, MLDataSet trainingSet) {
        this(network, new NguyenWidrowRandomizer(), new TrainingSetScore(trainingSet), 20);
    }

    /**
     * Allocates the velocity vectors and runs one initialisation pass over the
     * swarm. Does nothing once the velocities exist.
     */
    void initPopulation() {
        if (this.m_velocities != null) {
            return;
        }
        final int dimensionality = this.m_bestVector.length;
        this.m_velocities = new double[this.m_populationSize][dimensionality];
        this.iterationPSO(true);
    }

    /**
     * Runs one training iteration, initialising the swarm first if needed.
     */
    @Override
    public void iteration() {
        this.initPopulation();
        this.preIteration();
        this.iterationPSO(false);
        this.postIteration();
    }

    /**
     * Updates every particle once and then refreshes the global best.
     *
     * @param init true for the initialisation pass, which always runs on the calling thread
     */
    protected void iterationPSO(boolean init) {
        final TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
        final boolean runConcurrently = !init && this.isMultiThreaded();
        for (int i = 0; i < this.m_populationSize; ++i) {
            final NeuralPSOWorker worker = new NeuralPSOWorker(this, i, init);
            if (runConcurrently) {
                EngineConcurrency.getInstance().processTask(worker, group);
            } else {
                worker.run();
            }
        }
        // Kept as in the original: the wait depends only on the multi-thread flag,
        // so it is also reached after an initialisation pass that submitted nothing.
        if (this.isMultiThreaded()) {
            group.waitForComplete();
        }
        this.updateGlobalBestPosition();
    }

    /**
     * Initialises or advances a single particle and re-evaluates its personal best.
     *
     * @param particleIndex index of the particle
     * @param init          true to create/seed the particle, false to move it
     */
    protected void updateParticle(int particleIndex, boolean init) {
        final int i = particleIndex;
        final double[] particlePosition;
        if (init) {
            if (this.m_networks[i] == null) {
                this.m_networks[i] = (BasicNetwork)this.m_bestNetwork.clone();
                // Particle 0 keeps the weights of the network handed to the trainer.
                if (i > 0) {
                    this.m_randomizer.randomize(this.m_networks[i]);
                }
            }
            particlePosition = this.getNetworkState(i);
            this.m_bestVectors[i] = particlePosition;
            this.m_va.randomise(this.m_velocities[i], this.m_maxVelocity);
        } else {
            particlePosition = this.getNetworkState(i);
            this.updateVelocity(i, particlePosition);
            this.m_va.clampComponents(this.m_velocities[i], this.m_maxVelocity);
            this.m_va.add(particlePosition, this.m_velocities[i]);
            this.m_va.clampComponents(particlePosition, this.m_maxPosition);
            this.setNetworkState(i, particlePosition);
        }
        this.updatePersonalBestPosition(i, particlePosition);
    }

    /**
     * Recomputes a particle's velocity in place: the old velocity scaled by the
     * inertia weight, plus a pull towards the personal best (weight c1), plus a
     * pull towards the global best (weight c2). The global pull is skipped for
     * the particle that currently holds the global best.
     *
     * @param particleIndex    index of the particle
     * @param particlePosition the particle's current position (not modified)
     */
    protected void updateVelocity(int particleIndex, double[] particlePosition) {
        final int i = particleIndex;
        final double[] velocity = this.m_velocities[i];
        final double[] pull = new double[particlePosition.length];

        this.m_va.mul(velocity, this.m_inertiaWeight);

        // Pull towards this particle's own best position.
        this.m_va.copy(pull, this.m_bestVectors[i]);
        this.m_va.sub(pull, particlePosition);
        this.m_va.mulRand(pull, this.m_c1);
        this.m_va.add(velocity, pull);

        if (i != this.m_bestVectorIndex) {
            // Pull towards the swarm's best position.
            final double[] globalBest = this.m_pseudoAsynchronousUpdate
                    ? this.m_bestVectors[this.m_bestVectorIndex]
                    : this.m_bestVector;
            this.m_va.copy(pull, globalBest);
            this.m_va.sub(pull, particlePosition);
            this.m_va.mulRand(pull, this.m_c2);
            this.m_va.add(velocity, pull);
        }
    }

    /**
     * Scores the particle's network and records the position as its personal
     * best when no best is recorded yet or the new score is better.
     *
     * @param particleIndex    index of the particle
     * @param particlePosition position that produced the network's current weights
     */
    protected void updatePersonalBestPosition(int particleIndex, double[] particlePosition) {
        final double score = this.m_calculateScore.calculateScore(this.m_networks[particleIndex]);
        // NOTE(review): 0.0 doubles as the "unset" marker, so a genuine personal
        // best of exactly 0.0 is overwritten by the next evaluation.
        final boolean unset = this.m_bestErrors[particleIndex] == 0.0;
        if (unset || this.isScoreBetter(score, this.m_bestErrors[particleIndex])) {
            this.m_bestErrors[particleIndex] = score;
            this.m_va.copy(this.m_bestVectors[particleIndex], particlePosition);
        }
    }

    /**
     * Scans the personal bests and, if one beats the current training error (or
     * no global best exists yet), copies it into the global best vector, decodes
     * it into the trained network and stores its score as the training error.
     */
    protected void updateGlobalBestPosition() {
        boolean bestUpdated = false;
        double currentBestError = this.getError();
        for (int i = 0; i < this.m_populationSize; ++i) {
            final boolean noBestYet = this.m_bestVectorIndex == -1;
            if (noBestYet || this.isScoreBetter(this.m_bestErrors[i], currentBestError)) {
                this.m_bestVectorIndex = i;
                bestUpdated = true;
                currentBestError = this.m_bestErrors[i];
            }
        }
        if (bestUpdated) {
            this.m_va.copy(this.m_bestVector, this.m_bestVectors[this.m_bestVectorIndex]);
            this.m_bestNetwork.decodeFromArray(this.m_bestVector);
            this.setError(this.m_bestErrors[this.m_bestVectorIndex]);
        }
    }

    /**
     * Compares two scores in the direction requested by the score function.
     *
     * @param score1 candidate score
     * @param score2 reference score
     * @return true when {@code score1} is strictly better than {@code score2}
     */
    boolean isScoreBetter(double score1, double score2) {
        if (this.m_calculateScore.shouldMinimize()) {
            return score1 < score2;
        }
        return score1 > score2;
    }

    /**
     * Pausing is not supported.
     *
     * @return always null
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * @return always false; this trainer cannot be resumed
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Does nothing; resuming is not supported.
     *
     * @param state ignored
     */
    @Override
    public void resume(TrainingContinuation state) {
    }

    /**
     * @param particleIndex index of the particle
     * @return the particle's network flattened to an array by {@link NetworkCODEC}
     */
    protected double[] getNetworkState(int particleIndex) {
        return NetworkCODEC.networkToArray(this.m_networks[particleIndex]);
    }

    /**
     * Writes a position vector back into the particle's network.
     *
     * @param particleIndex index of the particle
     * @param state         flattened network state to load
     */
    protected void setNetworkState(int particleIndex, double[] state) {
        NetworkCODEC.arrayToNetwork(state, this.m_networks[particleIndex]);
    }

    /**
     * Changes the number of particles.
     *
     * <p>Per-particle storage is enlarged when the new size exceeds it, so a
     * larger swarm no longer fails with an index error on the next iteration.
     * If the swarm has already been initialised, each added particle receives a
     * velocity vector and is initialised immediately, which scores its network
     * once. Shrinking only lowers the number of particles visited; storage and
     * the recorded global best are left untouched.
     *
     * @param populationSize the new number of particles
     */
    public void setPopulationSize(int populationSize) {
        this.m_populationSize = populationSize;
        this.growSwarmStorage(populationSize);
        if (this.m_velocities != null && this.m_velocities.length < populationSize) {
            final int firstNew = this.m_velocities.length;
            final int dimensionality = this.m_bestVector.length;
            this.m_velocities = Arrays.copyOf(this.m_velocities, populationSize);
            for (int i = firstNew; i < populationSize; ++i) {
                this.m_velocities[i] = new double[dimensionality];
                this.updateParticle(i, true);
            }
        }
    }

    /** @return the number of particles */
    public int getPopulationSize() {
        return this.m_populationSize;
    }

    /** @param maxVelocity bound used when clamping and randomising velocities */
    public void setMaxVelocity(double maxVelocity) {
        this.m_maxVelocity = maxVelocity;
    }

    /** @return the velocity bound */
    public double getMaxVelocity() {
        return this.m_maxVelocity;
    }

    /** @param maxPosition bound used when clamping positions */
    public void setMaxPosition(double maxPosition) {
        this.m_maxPosition = maxPosition;
    }

    /** @return the position bound */
    public double getMaxPosition() {
        return this.m_maxPosition;
    }

    /** @param c1 weight of the pull towards the personal best */
    public void setC1(double c1) {
        this.m_c1 = c1;
    }

    /** @return weight of the pull towards the personal best */
    public double getC1() {
        return this.m_c1;
    }

    /** @param c2 weight of the pull towards the global best */
    public void setC2(double c2) {
        this.m_c2 = c2;
    }

    /** @return weight of the pull towards the global best */
    public double getC2() {
        return this.m_c2;
    }

    /** @param inertiaWeight factor applied to the previous velocity */
    public void setInertiaWeight(double inertiaWeight) {
        this.m_inertiaWeight = inertiaWeight;
    }

    /** @return factor applied to the previous velocity */
    public double getInertiaWeight() {
        return this.m_inertiaWeight;
    }

    /** @return a one-line summary of the swarm parameters */
    public String getDescription() {
        return String.format("pop = %d, w = %.2f, c1 = %.2f, c2 = %.2f, Xmax = %.2f, Vmax = %.2f", this.m_populationSize, this.m_inertiaWeight, this.m_c1, this.m_c2, this.m_maxPosition, this.m_maxVelocity);
    }

    /** @return the network being trained, holding the best weights found so far */
    @Override
    public MLMethod getMethod() {
        return this.m_bestNetwork;
    }

    /**
     * Supplies the networks to use as particles. Null entries are replaced by
     * clones of the trained network during initialisation.
     *
     * <p>An array at least as long as the population size is used as is. A
     * shorter array is copied into a new, null-padded array of the population
     * size, so in that case the caller's array is not the one the trainer fills.
     *
     * @param initialPopulation the particle networks; must not be null
     * @throws NullPointerException if {@code initialPopulation} is null
     */
    public void setInitialPopulation(BasicNetwork[] initialPopulation) {
        if (initialPopulation == null) {
            // Fail here rather than inside a worker on the next iteration.
            throw new NullPointerException("initialPopulation must not be null");
        }
        this.m_networks = initialPopulation;
        this.growSwarmStorage(this.m_populationSize);
    }

    /** @return true when particle updates are dispatched to the concurrency engine */
    public boolean isMultiThreaded() {
        return this.m_multiThreaded;
    }

    /**
     * Enlarges the network, personal-best vector and personal-best score arrays
     * to hold at least {@code size} particles, keeping existing entries. Arrays
     * that are already long enough are left as they are.
     */
    private void growSwarmStorage(int size) {
        if (this.m_networks.length < size) {
            this.m_networks = Arrays.copyOf(this.m_networks, size);
        }
        if (this.m_bestVectors.length < size) {
            this.m_bestVectors = Arrays.copyOf(this.m_bestVectors, size);
        }
        if (this.m_bestErrors.length < size) {
            this.m_bestErrors = Arrays.copyOf(this.m_bestErrors, size);
        }
    }
}

