/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.neat.training;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.encog.mathutil.randomize.RangeRandomizer;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.genetic.GeneticAlgorithm;
import org.encog.ml.genetic.genome.Chromosome;
import org.encog.ml.genetic.genome.Genome;
import org.encog.ml.genetic.genome.GenomeComparator;
import org.encog.ml.genetic.population.Population;
import org.encog.ml.genetic.species.BasicSpecies;
import org.encog.ml.genetic.species.Species;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.neural.neat.NEATNetwork;
import org.encog.neural.neat.NEATPopulation;
import org.encog.neural.neat.training.NEATGenome;
import org.encog.neural.neat.training.NEATInnovationList;
import org.encog.neural.neat.training.NEATLinkGene;
import org.encog.neural.neat.training.NEATParams;
import org.encog.neural.neat.training.NEATParent;
import org.encog.neural.networks.training.CalculateScore;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.genetic.GeneticScoreAdapter;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

public class NEATTraining
extends GeneticAlgorithm
implements MLTrain {
    /** Mean adjusted score per genome; zeroed by resetAndKill() and recomputed by speciateAndCalculateSpawnLevels(). */
    private double averageFitAdjustment;
    /** Best score recorded so far by sortAndRecord(); this is also what getError() reports. */
    private double bestEverScore;
    /** Organism of the genome that set bestEverScore; returned by getMethod(). */
    private NEATNetwork bestEverNetwork;
    /** Input count that every genome in the population must match (checked in init()). */
    private final int inputCount;
    /** Output count that every genome in the population must match (checked in init()). */
    private final int outputCount;
    /** Sum of the adjusted scores of all genomes; zeroed by resetAndKill(), accumulated by speciateAndCalculateSpawnLevels(). */
    private double totalFitAdjustment;
    /** Flag exposed through isSnapshot()/setSnapshot(); nothing else in this class reads it. */
    private boolean snapshot;
    /** Generation counter: incremented by iteration() and overwritable through setIteration(). */
    private int iteration;
    /** NEAT settings read throughout this class (rates, thresholds, species limits). */
    private final NEATParams params = new NEATParams();

    public NEATTraining(CalculateScore calculateScore, int inputCount, int outputCount, int populationSize) {
        this.inputCount = inputCount;
        this.outputCount = outputCount;
        this.setCalculateScore(new GeneticScoreAdapter(calculateScore));
        this.setComparator(new GenomeComparator(this.getCalculateScore()));
        this.setPopulation(new NEATPopulation(inputCount, outputCount, populationSize));
        this.init();
    }

    public NEATTraining(CalculateScore calculateScore, Population population) {
        if (population.size() < 1) {
            throw new TrainingError("Population can not be empty.");
        }
        NEATGenome genome = (NEATGenome)population.getGenomes().get(0);
        this.setCalculateScore(new GeneticScoreAdapter(calculateScore));
        this.setComparator(new GenomeComparator(this.getCalculateScore()));
        this.setPopulation(population);
        this.inputCount = genome.getInputCount();
        this.outputCount = genome.getOutputCount();
        this.init();
    }

    /**
     * Appends {@code nodeID} to {@code vec} unless the list already holds it.
     *
     * @param nodeID neuron id to record
     * @param vec    list of ids collected so far; modified in place
     */
    public void addNeuronID(long nodeID, List<Long> vec) {
        for (final Long known : vec) {
            // Long == long unboxes, so this compares values, not references.
            if (known == nodeID) {
                return;
            }
        }
        vec.add(nodeID);
    }

    /**
     * Always rejects the request; this trainer does not accept strategies.
     *
     * @param strategy ignored
     * @throws TrainingError on every call
     */
    @Override
    public void addStrategy(Strategy strategy) {
        throw new TrainingError("Strategies are not supported by this training method.");
    }

    /**
     * Nudges {@code params.compatibilityThreshold} by a fixed step: up when
     * there are more species than {@code params.maxNumberOfSpecies}, down when
     * fewer than two species exist. Does nothing when
     * {@code params.maxNumberOfSpecies} is below 1.
     */
    public void adjustCompatibilityThreshold() {
        final int speciesLimit = this.params.maxNumberOfSpecies;
        if (speciesLimit < 1) {
            return;
        }
        final double thresholdIncrement = 0.01;
        final int speciesCount = this.getPopulation().getSpecies().size();
        if (speciesCount > speciesLimit) {
            this.params.compatibilityThreshold += thresholdIncrement;
        } else if (speciesCount < 2) {
            this.params.compatibilityThreshold -= thresholdIncrement;
        }
    }

    /**
     * Computes the adjusted score of every genome in every species. The raw
     * score gets the population's young bonus when the species age is below
     * the young threshold and the old-age penalty when it is above the old
     * threshold, and is then divided by the number of members in the species.
     */
    public void adjustSpeciesScore() {
        for (final Species species : this.getPopulation().getSpecies()) {
            for (final Genome genome : species.getMembers()) {
                double shaped = genome.getScore();

                final boolean isYoung = species.getAge() < this.getPopulation().getYoungBonusAgeThreshold();
                if (isYoung) {
                    shaped = this.getComparator().applyBonus(shaped, this.getPopulation().getYoungScoreBonus());
                }

                final boolean isOld = species.getAge() > this.getPopulation().getOldAgeThreshold();
                if (isOld) {
                    shaped = this.getComparator().applyPenalty(shaped, this.getPopulation().getOldAgePenalty());
                }

                // Share the score among all members of the species.
                final int memberCount = species.getMembers().size();
                genome.setAdjustedScore(shaped / (double)memberCount);
            }
        }
    }

    /**
     * @return always {@code false}; in this class pause() returns null and
     *         resume() does nothing
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Decides which parent's unmatched genes are kept during crossover.
     * <p>
     * The parent with the better score (according to the comparator) wins. On
     * an exact score tie the parent with fewer genes wins; if the gene counts
     * are also equal the winner is picked at random with equal probability.
     *
     * @param mom first parent
     * @param dad second parent
     * @return {@link NEATParent#Mom} or {@link NEATParent#Dad}
     */
    private NEATParent favorParent(NEATGenome mom, NEATGenome dad) {
        if (mom.getScore() != dad.getScore()) {
            return this.getComparator().isBetterThan(mom.getScore(), dad.getScore())
                    ? NEATParent.Mom
                    : NEATParent.Dad;
        }
        if (mom.getNumGenes() != dad.getNumGenes()) {
            // Equal fitness: prefer the smaller genome.
            return mom.getNumGenes() < dad.getNumGenes() ? NEATParent.Mom : NEATParent.Dad;
        }
        // Full tie: fair coin flip. The previous test (random > 0.0) was true
        // for practically every draw and therefore almost always chose Mom.
        return Math.random() < 0.5 ? NEATParent.Mom : NEATParent.Dad;
    }

    /**
     * Produces a child genome from two parents.
     * <p>
     * Both parents' link lists are walked in index order and compared by
     * innovation id. A link present in only one parent (or with the lower id
     * at the current position) is kept only when it belongs to the favoured
     * parent (see favorParent); a link whose innovation id matches in both
     * parents is taken from either parent with equal probability. The child's
     * neurons are created from the sorted ids referenced by the kept links.
     * <p>
     * NOTE(review): the walk presumably relies on each parent's links being
     * ordered by innovation id - confirm against NEATGenome.sortGenes().
     * NOTE(review): the kept link objects are added to the child as-is, not
     * copied here; confirm the NEATGenome constructor does not share them
     * with the parents.
     *
     * @param mom first parent
     * @param dad second parent
     * @return the new genome, already attached to this trainer and its
     *         population and validated
     */
    public NEATGenome crossover(NEATGenome mom, NEATGenome dad) {
        final NEATParent best = this.favorParent(mom, dad);
        final Chromosome babyNeurons = new Chromosome();
        final Chromosome babyGenes = new Chromosome();
        final List<Long> vecNeurons = new ArrayList<Long>();
        int curMom = 0;
        int curDad = 0;
        while (curMom < mom.getNumGenes() || curDad < dad.getNumGenes()) {
            NEATLinkGene momGene = null;
            NEATLinkGene dadGene = null;
            if (curMom < mom.getNumGenes()) {
                momGene = (NEATLinkGene)mom.getLinks().get(curMom);
            }
            if (curDad < dad.getNumGenes()) {
                dadGene = (NEATLinkGene)dad.getLinks().get(curDad);
            }

            // Chosen afresh on every step. Previously the variable lived
            // outside the loop, so when the very first gene visited belonged
            // to the less fit parent a null was added to the child and then
            // dereferenced.
            NEATLinkGene selectedGene = null;
            if (momGene == null && dadGene != null) {
                // Mom is exhausted: only dad has genes left.
                if (best == NEATParent.Dad) {
                    selectedGene = dadGene;
                }
                ++curDad;
            } else if (dadGene == null && momGene != null) {
                // Dad is exhausted: only mom has genes left.
                if (best == NEATParent.Mom) {
                    selectedGene = momGene;
                }
                ++curMom;
            } else if (momGene.getInnovationId() < dadGene.getInnovationId()) {
                if (best == NEATParent.Mom) {
                    selectedGene = momGene;
                }
                ++curMom;
            } else if (dadGene.getInnovationId() < momGene.getInnovationId()) {
                if (best == NEATParent.Dad) {
                    selectedGene = dadGene;
                }
                ++curDad;
            } else if (dadGene.getInnovationId() == momGene.getInnovationId()) {
                // Same innovation in both parents: inherit from either one.
                selectedGene = Math.random() < 0.5 ? momGene : dadGene;
                ++curMom;
                ++curDad;
            }

            if (selectedGene == null) {
                // The gene belonged to the less fit parent and is dropped.
                continue;
            }

            // Never add the same innovation twice in a row.
            final boolean duplicateOfLast = babyGenes.size() > 0
                    && ((NEATLinkGene)babyGenes.get(babyGenes.size() - 1)).getInnovationId()
                            == selectedGene.getInnovationId();
            if (!duplicateOfLast) {
                babyGenes.add(selectedGene);
            }
            this.addNeuronID(selectedGene.getFromNeuronID(), vecNeurons);
            this.addNeuronID(selectedGene.getToNeuronID(), vecNeurons);
        }

        Collections.sort(vecNeurons);
        for (final Long neuronId : vecNeurons) {
            babyNeurons.add(this.getInnovations().createNeuronFromID(neuronId));
        }

        final NEATGenome babyGenome = new NEATGenome(this.getPopulation().assignGenomeID(), babyNeurons, babyGenes, mom.getInputCount(), mom.getOutputCount());
        babyGenome.setGeneticAlgorithm(this);
        babyGenome.setPopulation(this.getPopulation());
        babyGenome.validate();
        return babyGenome;
    }

    /** Does nothing in this class. */
    @Override
    public void finishTraining() {
    }

    /**
     * @return the best score recorded so far ({@code bestEverScore})
     */
    @Override
    public double getError() {
        return this.bestEverScore;
    }

    /**
     * @return always {@link TrainingImplementationType#Iterative}
     */
    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    /**
     * @return the population's innovations object, downcast to
     *         {@link NEATInnovationList} (the cast fails with a
     *         ClassCastException if the population holds another type)
     */
    public NEATInnovationList getInnovations() {
        return (NEATInnovationList)this.getPopulation().getInnovations();
    }

    /**
     * @return the input count fixed at construction time
     */
    public int getInputCount() {
        return this.inputCount;
    }

    /**
     * @return the current generation counter
     */
    @Override
    public int getIteration() {
        return this.iteration;
    }

    /**
     * @return the network stored by sortAndRecord() when a new best score was
     *         found; {@code null} if no score has yet beaten the initial one
     */
    @Override
    public MLMethod getMethod() {
        return this.bestEverNetwork;
    }

    /**
     * @return the output count fixed at construction time
     */
    public int getOutputCount() {
        return this.outputCount;
    }

    /**
     * @return a new, empty list on every call (addStrategy always throws, so
     *         no strategies are ever stored)
     */
    @Override
    public List<Strategy> getStrategies() {
        return new ArrayList<Strategy>();
    }

    /**
     * @return always {@code null}; this trainer holds no data set
     */
    @Override
    public MLDataSet getTraining() {
        return null;
    }

    /**
     * Shared constructor tail: seeds the best-score tracker, validates and
     * claims the population, then performs the initial scoring and speciation.
     *
     * @throws TrainingError if a genome is not a {@link NEATGenome} or its
     *                       input/output counts differ from this trainer's
     */
    private void init() {
        // Start from the worst possible value so the first real score always
        // wins. Double.MIN_VALUE is the smallest POSITIVE double, so using it
        // for maximisation meant scores <= 0 could never be recorded as best.
        this.bestEverScore = this.getCalculateScore().shouldMinimize() ? Double.MAX_VALUE : -Double.MAX_VALUE;
        for (final Genome obj : this.getPopulation().getGenomes()) {
            if (!(obj instanceof NEATGenome)) {
                throw new TrainingError("Population can only contain objects of NEATGenome.");
            }
            final NEATGenome neat = (NEATGenome)obj;
            if (neat.getInputCount() != this.inputCount || neat.getOutputCount() != this.outputCount) {
                throw new TrainingError("All NEATGenome's must have the same input and output sizes as the base network.");
            }
            neat.setGeneticAlgorithm(this);
        }
        this.getPopulation().claim(this);
        this.resetAndKill();
        this.sortAndRecord();
        this.speciateAndCalculateSpawnLevels();
    }

    /**
     * @return the current value of the snapshot flag
     */
    public boolean isSnapshot() {
        return this.snapshot;
    }

    /**
     * @return always {@code false}; this class never reports completion itself
     */
    @Override
    public boolean isTrainingDone() {
        return false;
    }

    /**
     * Runs one generation.
     * <p>
     * For each species, up to {@code round(getNumToSpawn())} slots are filled:
     * the first slot receives the species leader itself, the remaining ones
     * receive a bred and mutated child. Spawning stops once the new generation
     * reaches the current population size; any shortfall is filled by
     * tournament selection over the old population. The population is then
     * replaced and re-scored and re-speciated.
     */
    @Override
    public void iteration() {
        ++this.iteration;
        final ArrayList<NEATGenome> nextGeneration = new ArrayList<NEATGenome>();
        int spawnedSoFar = 0;

        for (final Species species : this.getPopulation().getSpecies()) {
            if (spawnedSoFar >= this.getPopulation().size()) {
                continue;
            }
            boolean leaderCarriedOver = false;
            for (int remaining = (int)Math.round(species.getNumToSpawn()); remaining > 0; --remaining) {
                NEATGenome child;
                if (leaderCarriedOver) {
                    child = this.breedNeatChild(species);
                    if (child != null) {
                        this.mutateNeatChild(child);
                    }
                } else {
                    // The first slot of every species goes to its leader, unmodified.
                    child = (NEATGenome)species.getLeader();
                    leaderCarriedOver = true;
                }
                if (child == null) {
                    // Breeding failed; the slot is still consumed.
                    continue;
                }
                child.sortGenes();
                nextGeneration.add(child);
                if (++spawnedSoFar == this.getPopulation().size()) {
                    break;
                }
            }
        }

        // Top up with tournament winners from the old population if needed.
        while (nextGeneration.size() < this.getPopulation().size()) {
            nextGeneration.add(this.tournamentSelection(this.getPopulation().size() / 5));
        }

        this.getPopulation().clear();
        this.getPopulation().addAll(nextGeneration);
        this.resetAndKill();
        this.sortAndRecord();
        this.speciateAndCalculateSpawnLevels();
    }

    /**
     * Creates one unmutated child for a species: a copy of a chosen parent, or
     * a crossover of two distinct parents. Returns null when crossover was
     * selected but no second, different parent was found within the retries.
     */
    private NEATGenome breedNeatChild(Species species) {
        if (species.getMembers().size() == 1) {
            return new NEATGenome((NEATGenome)species.chooseParent());
        }
        final NEATGenome first = (NEATGenome)species.chooseParent();
        if (!(Math.random() < this.params.crossoverRate)) {
            return new NEATGenome(first);
        }
        NEATGenome second = (NEATGenome)species.chooseParent();
        int retriesLeft = 5;
        while (first.getGenomeID() == second.getGenomeID() && retriesLeft-- > 0) {
            second = (NEATGenome)species.chooseParent();
        }
        if (first.getGenomeID() != second.getGenomeID()) {
            return this.crossover(first, second);
        }
        return null;
    }

    /** Assigns a new genome id to a bred child and applies the structural and weight mutations. */
    private void mutateNeatChild(NEATGenome child) {
        child.setGenomeID(this.getPopulation().assignGenomeID());
        if ((double)child.getNeurons().size() < this.params.maxPermittedNeurons) {
            child.addNeuron(this.params.chanceAddNode, this.params.numTrysToFindOldLink);
        }
        child.addLink(this.params.chanceAddLink, this.params.chanceAddRecurrentLink, this.params.numTrysToFindLoopedLink, this.params.numAddLinkAttempts);
        child.mutateWeights(this.params.mutationRate, this.params.probabilityWeightReplaced, this.params.maxWeightPerturbation);
        child.mutateActivationResponse(this.params.activationMutationRate, this.params.maxActivationPerturbation);
    }

    /**
     * Runs {@link #iteration()} the requested number of times.
     *
     * @param count number of generations to run; values of zero or less run none
     */
    @Override
    public void iteration(int count) {
        for (int remaining = count; remaining > 0; --remaining) {
            this.iteration();
        }
    }

    /**
     * @return always {@code null}; no continuation state is produced
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Zeroes the fitness-adjustment totals, purges every species and removes
     * the species that should not survive: those whose leader is no longer in
     * the population, and those that have gone more than
     * {@code params.numGensAllowedNoImprovement} generations without
     * improvement while the best-ever score is better than their own best.
     */
    public void resetAndKill() {
        this.totalFitAdjustment = 0.0;
        this.averageFitAdjustment = 0.0;

        // Iterate over an array snapshot because species are removed inside the loop.
        final Object[] snapshotOfSpecies = this.getPopulation().getSpecies().toArray();
        for (final Object entry : snapshotOfSpecies) {
            final Species species = (Species)entry;
            species.purge();

            final boolean leaderStillAlive = this.getPopulation().getGenomes().contains(species.getLeader());
            if (!leaderStillAlive) {
                this.getPopulation().getSpecies().remove(species);
            } else if (species.getGensNoImprovement() > this.params.numGensAllowedNoImprovement
                    && this.getComparator().isBetterThan(this.bestEverScore, species.getBestScore())) {
                this.getPopulation().getSpecies().remove(species);
            }
        }
    }

    /**
     * Does nothing in this class.
     *
     * @param state ignored
     */
    @Override
    public void resume(TrainingContinuation state) {
    }

    /**
     * Does nothing; the supplied value is discarded.
     *
     * @param error ignored
     */
    @Override
    public void setError(double error) {
    }

    /**
     * Overwrites the generation counter.
     *
     * @param iteration new counter value
     */
    @Override
    public void setIteration(int iteration) {
        this.iteration = iteration;
    }

    /**
     * Sets the snapshot flag.
     *
     * @param snapshot new flag value
     */
    public void setSnapshot(boolean snapshot) {
        this.snapshot = snapshot;
    }

    /**
     * Decodes and scores every genome, sorts the population, and records the
     * population's best genome as the best-ever result when its score beats
     * the one currently stored.
     */
    public void sortAndRecord() {
        for (final Genome candidate : this.getPopulation().getGenomes()) {
            candidate.decode();
            this.calculateScore(candidate);
        }
        this.getPopulation().sort();

        final Genome champion = this.getPopulation().getBest();
        final double championScore = champion.getScore();
        final boolean improved = this.getComparator().isBetterThan(championScore, this.bestEverScore);
        if (improved) {
            this.bestEverScore = championScore;
            this.bestEverNetwork = (NEATNetwork)champion.getOrganism();
        }

        // Kept from the original: re-derives the best score via the comparator.
        final double reportedError = this.getError();
        this.bestEverScore = this.getComparator().bestScore(reportedError, this.bestEverScore);
    }

    /**
     * Assigns every genome to a species and works out how many offspring each
     * genome and species should produce.
     * <p>
     * Steps: adjust the compatibility threshold; place each genome in the
     * first species whose leader is within the threshold, or found a new
     * {@link BasicSpecies} for it; compute adjusted scores; set each genome's
     * spawn amount to its adjusted score divided by the population average;
     * finally let each species calculate its own spawn amount.
     * <p>
     * NOTE(review): species membership is only added here, never cleared -
     * this presumably relies on resetAndKill() having purged the species
     * first; confirm against Species.purge().
     */
    public void speciateAndCalculateSpawnLevels() {
        this.adjustCompatibilityThreshold();

        for (final Genome g : this.getPopulation().getGenomes()) {
            final NEATGenome genome = (NEATGenome)g;
            boolean added = false;
            for (final Species s : this.getPopulation().getSpecies()) {
                final double compatibility = genome.getCompatibilityScore((NEATGenome)s.getLeader());
                if (compatibility <= this.params.compatibilityThreshold) {
                    this.addSpeciesMember(s, genome);
                    genome.setSpeciesID(s.getSpeciesID());
                    added = true;
                    break;
                }
            }
            if (!added) {
                // No compatible species: this genome founds a new one.
                this.getPopulation().getSpecies().add(new BasicSpecies(this.getPopulation(), genome, this.getPopulation().assignSpeciesID()));
            }
        }

        this.adjustSpeciesScore();

        // Sum into a local and assign, so the total is correct even if the
        // field was not zeroed by a preceding resetAndKill() call.
        double total = 0.0;
        for (final Genome g : this.getPopulation().getGenomes()) {
            total += ((NEATGenome)g).getAdjustedScore();
        }
        this.totalFitAdjustment = total;
        this.averageFitAdjustment = this.totalFitAdjustment / (double)this.getPopulation().size();

        // A zero average would turn every spawn amount into NaN or Infinity.
        // Fall back to one offspring per genome, which keeps the sum of spawn
        // amounts equal to the population size as in the normal case.
        final boolean degenerateAverage = this.averageFitAdjustment == 0.0;
        for (final Genome g : this.getPopulation().getGenomes()) {
            final NEATGenome genome = (NEATGenome)g;
            final double toSpawn = degenerateAverage
                    ? 1.0
                    : genome.getAdjustedScore() / this.averageFitAdjustment;
            genome.setAmountToSpawn(toSpawn);
        }

        for (final Species species : this.getPopulation().getSpecies()) {
            species.calculateSpawnAmount();
        }
    }

    /**
     * Picks {@code numComparisons} random genomes from the population and
     * returns the one with the best score according to this trainer's
     * comparator, so the result is right for both minimising and maximising
     * score functions and for non-positive scores.
     * <p>
     * NOTE(review): the index is drawn as
     * {@code (int) RangeRandomizer.randomize(0, size - 1)}; if that method
     * excludes its upper bound the last genome can never be drawn - confirm
     * against RangeRandomizer.
     *
     * @param numComparisons number of random contestants to draw
     * @return the best contestant drawn, or the genome at index 0 when
     *         {@code numComparisons} is zero or negative
     */
    public NEATGenome tournamentSelection(int numComparisons) {
        int chosenOne = 0;
        double bestScoreSoFar = 0.0;
        boolean haveContestant = false;
        for (int i = 0; i < numComparisons; ++i) {
            final int thisTry = (int)RangeRandomizer.randomize(0.0, this.getPopulation().size() - 1);
            final double score = this.getPopulation().get(thisTry).getScore();
            // The first contestant always becomes the provisional winner;
            // later ones must beat it according to the comparator. The old
            // "score > 0.0" test favoured the worst genome when minimising.
            if (!haveContestant || this.getComparator().isBetterThan(score, bestScoreSoFar)) {
                chosenOne = thisTry;
                bestScoreSoFar = score;
                haveContestant = true;
            }
        }
        return (NEATGenome)this.getPopulation().get(chosenOne);
    }
}

