/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.prune;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.encog.EncogError;
import org.encog.StatusReportable;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.buffer.BufferedMLDataSet;
import org.encog.ml.train.strategy.StopTrainingStrategy;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.neural.pattern.NeuralNetworkPattern;
import org.encog.neural.prune.HiddenLayerParams;
import org.encog.util.concurrency.job.ConcurrentJob;
import org.encog.util.concurrency.job.JobUnitContext;
import org.encog.util.logging.EncogLogging;

public class PruneIncremental
extends ConcurrentJob {
    private boolean done = false;
    private final MLDataSet training;
    private final NeuralNetworkPattern pattern;
    private final List<HiddenLayerParams> hidden = new ArrayList<HiddenLayerParams>();
    private final int iterations;
    private final BasicNetwork[] topNetworks;
    private final double[] topErrors;
    private BasicNetwork bestNetwork;
    private int currentTry;
    private final StatusReportable report;
    private int[] hiddenCounts;
    private double high;
    private double low;
    private double[][] results;
    private int hidden1Size;
    private int hidden2Size;
    private final int weightTries;

    public static String networkToString(BasicNetwork network) {
        if (network != null) {
            StringBuilder result = new StringBuilder();
            int num = 1;
            for (int i = 1; i < network.getLayerCount() - 1; ++i) {
                if (result.length() > 0) {
                    result.append(",");
                }
                result.append("H");
                result.append(num++);
                result.append("=");
                result.append(network.getLayerNeuronCount(i));
            }
            return result.toString();
        }
        return "N/A";
    }

    public PruneIncremental(MLDataSet training, NeuralNetworkPattern pattern, int iterations, int weightTries, int numTopResults, StatusReportable report) {
        super(report);
        if (numTopResults < 1) {
            throw new EncogError("There must be at least one best network.  numTopResults must be >0.");
        }
        this.training = training;
        this.pattern = pattern;
        this.iterations = iterations;
        this.report = report;
        this.weightTries = weightTries;
        this.topNetworks = new BasicNetwork[numTopResults];
        this.topErrors = new double[numTopResults];
    }

    public final void addHiddenLayer(int min, int max) {
        HiddenLayerParams param = new HiddenLayerParams(min, max);
        this.hidden.add(param);
    }

    private BasicNetwork generateNetwork() {
        this.pattern.clear();
        for (int element : this.hiddenCounts) {
            if (element <= 0) continue;
            this.pattern.addHiddenLayer(element);
        }
        return (BasicNetwork)this.pattern.generate();
    }

    public final BasicNetwork getBestNetwork() {
        return this.bestNetwork;
    }

    public final List<HiddenLayerParams> getHidden() {
        return this.hidden;
    }

    public final int getHidden1Size() {
        return this.hidden1Size;
    }

    public final int getHidden2Size() {
        return this.hidden2Size;
    }

    public final double getHigh() {
        return this.high;
    }

    public final int getIterations() {
        return this.iterations;
    }

    public final double getLow() {
        return this.low;
    }

    public final NeuralNetworkPattern getPattern() {
        return this.pattern;
    }

    public final double[][] getResults() {
        return this.results;
    }

    public final double[] getTopErrors() {
        return this.topErrors;
    }

    public final BasicNetwork[] getTopNetworks() {
        return this.topNetworks;
    }

    public final MLDataSet getTraining() {
        return this.training;
    }

    private boolean increaseHiddenCounts() {
        int i = 0;
        do {
            HiddenLayerParams param = this.hidden.get(i);
            int n = i;
            this.hiddenCounts[n] = this.hiddenCounts[n] + 1;
            if (this.hiddenCounts[i] <= param.getMax()) {
                return true;
            }
            this.hiddenCounts[i] = param.getMin();
        } while (++i < this.hiddenCounts.length);
        return false;
    }

    public final void init() {
        if (this.hidden.size() == 1) {
            this.hidden1Size = this.hidden.get(0).getMax() - this.hidden.get(0).getMin() + 1;
            this.hidden2Size = 0;
            this.results = new double[this.hidden1Size][1];
        } else if (this.hidden.size() == 2) {
            this.hidden1Size = this.hidden.get(0).getMax() - this.hidden.get(0).getMin() + 1;
            this.hidden2Size = this.hidden.get(1).getMax() - this.hidden.get(1).getMin() + 1;
            this.results = new double[this.hidden1Size][this.hidden2Size];
        } else {
            this.hidden1Size = 0;
            this.hidden2Size = 0;
            this.results = null;
        }
        this.high = Double.NEGATIVE_INFINITY;
        this.low = Double.POSITIVE_INFINITY;
    }

    @Override
    public final int loadWorkload() {
        int result = 1;
        for (HiddenLayerParams param : this.hidden) {
            result *= param.getMax() - param.getMin() + 1;
        }
        this.init();
        return result;
    }

    @Override
    public final void performJobUnit(JobUnitContext context) {
        BasicNetwork network = (BasicNetwork)context.getJobUnit();
        BufferedMLDataSet buffer = null;
        MLDataSet useTraining = this.training;
        if (this.training instanceof BufferedMLDataSet) {
            buffer = (BufferedMLDataSet)this.training;
            useTraining = buffer.openAdditional();
        }
        double error = Double.POSITIVE_INFINITY;
        for (int z = 0; z < this.weightTries; ++z) {
            network.reset();
            ResilientPropagation train = new ResilientPropagation(network, useTraining);
            StopTrainingStrategy strat = new StopTrainingStrategy(0.001, 5);
            train.addStrategy(strat);
            train.setThreadCount(1);
            for (int i = 0; i < this.iterations && !this.getShouldStop() && !strat.shouldStop(); ++i) {
                train.iteration();
            }
            error = Math.min(error, train.getError());
        }
        if (buffer != null) {
            buffer.close();
        }
        if (!this.getShouldStop()) {
            this.high = Math.max(this.high, error);
            this.low = Math.min(this.low, error);
            if (this.hidden1Size > 0) {
                int col;
                int row;
                int networkHidden1Count;
                int networkHidden2Count;
                if (network.getLayerCount() > 3) {
                    networkHidden2Count = network.getLayerNeuronCount(2);
                    networkHidden1Count = network.getLayerNeuronCount(1);
                } else {
                    networkHidden2Count = 0;
                    networkHidden1Count = network.getLayerNeuronCount(1);
                }
                if (this.hidden2Size == 0) {
                    row = networkHidden1Count - this.hidden.get(0).getMin();
                    col = 0;
                } else {
                    row = networkHidden1Count - this.hidden.get(0).getMin();
                    col = networkHidden2Count - this.hidden.get(1).getMin();
                }
                if (row < 0 || col < 0) {
                    System.out.println("STOP");
                }
                this.results[row][col] = error;
            }
            ++this.currentTry;
            this.updateBest(network, error);
            this.reportStatus(context, "Current: " + PruneIncremental.networkToString(network) + "; Best: " + PruneIncremental.networkToString(this.bestNetwork));
        }
    }

    @Override
    public final void process() {
        if (this.hidden.size() == 0) {
            throw new EncogError("To calculate the optimal hidden size, at least one hidden layer must be defined.");
        }
        this.hiddenCounts = new int[this.hidden.size()];
        this.bestNetwork = null;
        int i = 0;
        for (HiddenLayerParams parm : this.hidden) {
            this.hiddenCounts[i++] = parm.getMin();
        }
        if (this.hiddenCounts[0] == 0) {
            throw new EncogError("To calculate the optimal hidden size, at least one neuron must be the minimum for the first hidden layer.");
        }
        super.process();
    }

    @Override
    public final Object requestNextTask() {
        if (this.done || this.getShouldStop()) {
            return null;
        }
        BasicNetwork network = this.generateNetwork();
        if (!this.increaseHiddenCounts()) {
            this.done = true;
        }
        return network;
    }

    private synchronized void updateBest(BasicNetwork network, double error) {
        this.high = Math.max(this.high, error);
        this.low = Math.min(this.low, error);
        int selectedIndex = -1;
        for (int i = 0; i < this.topNetworks.length; ++i) {
            if (this.topNetworks[i] == null) {
                selectedIndex = i;
                break;
            }
            if (!(this.topErrors[i] > error) || selectedIndex != -1 && !(this.topErrors[selectedIndex] < this.topErrors[i])) continue;
            selectedIndex = i;
        }
        if (selectedIndex != -1) {
            this.topErrors[selectedIndex] = error;
            this.topNetworks[selectedIndex] = network;
        }
        BasicNetwork choice = null;
        for (BasicNetwork n : this.topNetworks) {
            if (n == null) continue;
            if (choice == null) {
                choice = n;
                continue;
            }
            if (n.getStructure().calculateSize() >= choice.getStructure().calculateSize()) continue;
            choice = n;
        }
        if (choice != this.bestNetwork) {
            EncogLogging.log(0, "Prune found new best network: error=" + error + ", network=" + choice);
            this.bestNetwork = choice;
        }
    }
}

