/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.pnn;

import org.encog.mathutil.EncogMath;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.pnn.AbstractPNN;
import org.encog.neural.pnn.PNNKernelType;
import org.encog.neural.pnn.PNNOutputMode;

/**
 * Basic probabilistic neural network (PNN) that keeps its training samples verbatim.
 *
 * <p>For each query, every stored sample (except the one at index {@code getExclude()})
 * is weighted by a kernel of its sigma-scaled squared distance to the query. The weights
 * are then combined according to the output mode:
 * <ul>
 *   <li>{@code Classification}: weights are summed per class (class index taken from
 *       {@code ideal[0]}), optionally rescaled by per-class priors, and the index of the
 *       highest-scoring class is returned as a one-element vector.</li>
 *   <li>{@code Regression}: the kernel-weighted mean of the samples' ideal vectors.</li>
 *   <li>{@code Unsupervised}: the kernel-weighted mean of the samples' input vectors.</li>
 * </ul>
 */
public class BasicPNN
extends AbstractPNN
implements MLRegression {
    private static final long serialVersionUID = -7990707837655024635L;

    /** Lower bound applied to each kernel activation and to the classification normaliser. */
    private static final double MIN_VALUE = 1.0E-40;

    /** Per-input smoothing widths; each coordinate difference is divided by its sigma. */
    private final double[] sigma;
    /** Stored samples that every query is compared against. */
    private BasicMLDataSet samples;
    /** Number of stored samples per class (populated in classification mode only). */
    private int[] countPer;
    /** Per-class priors; a negative entry disables prior rescaling for that class. */
    private double[] priors;

    /**
     * Creates a network with no samples and an all-zero sigma array.
     *
     * @param kernel      kernel applied to the sigma-scaled squared distance
     * @param outmodel    output mode (classification, regression or unsupervised)
     * @param inputCount  number of input values; also the length of the sigma array
     * @param outputCount number of output values
     */
    public BasicPNN(PNNKernelType kernel, PNNOutputMode outmodel, int inputCount, int outputCount) {
        super(kernel, outmodel, inputCount, outputCount);
        // NOTE(review): the meaning of the separate-class flag is defined in AbstractPNN.
        this.setSeparateClass(false);
        // Zero-initialised: compute() divides by these values, so they must be filled in
        // (e.g. through getSigma()) before results are meaningful.
        this.sigma = new double[inputCount];
    }

    /**
     * Evaluates the network for one input vector.
     *
     * <p>The sample at position {@code getExclude()} in the stored sample set is skipped,
     * presumably to support leave-one-out evaluation during training (TODO confirm).
     *
     * @param input query vector; its first {@code getInputCount()} values are used
     * @return in classification mode, a one-element vector holding the winning class index;
     *         otherwise a vector of length {@code getOutputCount()} holding the kernel-weighted
     *         mean of the samples' ideal (regression) or input (unsupervised) vectors. In the
     *         latter two modes the result is NaN when no sample contributes.
     * @throws NeuralNetworkError in unsupervised mode when the input count exceeds the output
     *         count, because the reconstructed input would not fit in the output vector
     */
    @Override
    public final MLData compute(MLData input) {
        PNNOutputMode mode = this.getOutputMode();
        double[] out = new double[this.getOutputCount()];
        if (mode == PNNOutputMode.Unsupervised && this.getInputCount() > out.length) {
            throw new NeuralNetworkError("Unsupervised PNN needs at least as many outputs ("
                    + out.length + ") as inputs (" + this.getInputCount() + ").");
        }

        double psum = 0.0;
        int r = -1;
        for (MLDataPair pair : this.samples) {
            if (++r == this.getExclude()) {
                continue;
            }
            double weight = this.kernelActivation(input, pair.getInput());
            if (mode == PNNOutputMode.Classification) {
                out[(int) pair.getIdeal().getData(0)] += weight;
            } else if (mode == PNNOutputMode.Unsupervised) {
                accumulate(out, weight, pair.getInput(), this.getInputCount());
                psum += weight;
            } else if (mode == PNNOutputMode.Regression) {
                accumulate(out, weight, pair.getIdeal(), this.getOutputCount());
                psum += weight;
            }
        }

        if (mode == PNNOutputMode.Classification) {
            return this.classify(out);
        }
        // psum is not floored here: with no contributing samples the mean is 0/0 = NaN.
        if (mode == PNNOutputMode.Unsupervised) {
            divide(out, psum, this.getInputCount());
        } else if (mode == PNNOutputMode.Regression) {
            divide(out, psum, this.getOutputCount());
        }
        return new BasicMLData(out);
    }

    /**
     * Computes the kernel weight of one stored sample relative to the query.
     *
     * @param input  query vector
     * @param sample stored sample's input vector
     * @return the kernel applied to the sigma-scaled squared distance, floored at
     *         {@link #MIN_VALUE}
     */
    private double kernelActivation(MLData input, MLData sample) {
        double dist = 0.0;
        for (int i = 0; i < this.getInputCount(); ++i) {
            double diff = (input.getData(i) - sample.getData(i)) / this.sigma[i];
            dist += diff * diff;
        }
        if (this.getKernel() == PNNKernelType.Gaussian) {
            dist = Math.exp(-dist);
        } else if (this.getKernel() == PNNKernelType.Reciprocal) {
            dist = 1.0 / (1.0 + dist);
        }
        // The floor keeps a far-away sample from contributing exactly zero.
        return Math.max(dist, MIN_VALUE);
    }

    /** Adds {@code weight * vector[i]} to {@code out[i]} for the first {@code count} entries. */
    private static void accumulate(double[] out, double weight, MLData vector, int count) {
        for (int i = 0; i < count; ++i) {
            out[i] += weight * vector.getData(i);
        }
    }

    /** Divides the first {@code count} entries of {@code values} by {@code divisor} in place. */
    private static void divide(double[] values, double divisor, int count) {
        for (int i = 0; i < count; ++i) {
            values[i] /= divisor;
        }
    }

    /**
     * Applies priors to the per-class kernel sums, normalises them and picks the winner.
     *
     * <p>Classes with a non-negative prior have their summed weight rescaled by
     * {@code prior / countPer}. Classes with no stored samples are skipped, because their
     * sum is zero and dividing by a zero count would turn every score into NaN.
     *
     * @param scores per-class kernel sums; modified in place
     * @return a one-element vector holding the index of the highest normalised score
     */
    private MLData classify(double[] scores) {
        double total = 0.0;
        for (int i = 0; i < scores.length; ++i) {
            if (this.priors[i] >= 0.0 && this.countPer[i] > 0) {
                scores[i] *= this.priors[i] / (double) this.countPer[i];
            }
            total += scores[i];
        }
        divide(scores, Math.max(total, MIN_VALUE), scores.length);
        BasicMLData result = new BasicMLData(1);
        result.setData(0, EncogMath.maxIndex(scores));
        return result;
    }

    /**
     * Returns the per-class sample counts computed by {@link #setSamples}.
     *
     * @return the live internal array (not a copy), or {@code null} if no samples have
     *         been set in classification mode
     */
    public final int[] getCountPer() {
        return this.countPer;
    }

    /**
     * Returns the per-class priors used by classification.
     *
     * <p>{@link #setSamples} initialises every entry to -1.0, which disables prior
     * rescaling. Writing a non-negative value into the returned (live) array enables it
     * for that class.
     *
     * @return the live internal array, or {@code null} if no samples have been set in
     *         classification mode
     */
    public final double[] getPriors() {
        return this.priors;
    }

    /**
     * Returns the stored sample set.
     *
     * @return the sample set passed to {@link #setSamples}, or {@code null} if none
     */
    public final BasicMLDataSet getSamples() {
        return this.samples;
    }

    /**
     * Returns the per-input smoothing widths.
     *
     * @return the live internal array of length {@code getInputCount()}; writes to it take
     *         effect on the next {@link #compute} call
     */
    public final double[] getSigma() {
        return this.sigma;
    }

    /**
     * Installs the sample set that queries are compared against.
     *
     * <p>In classification mode this also counts the samples per class (class index read
     * from {@code ideal[0]}) and resets every prior to -1.0. The counts are built before
     * any field is assigned, so a rejected sample set leaves the network unchanged.
     *
     * @param samples the stored samples
     * @throws NeuralNetworkError in classification mode if any class index is negative or
     *         not smaller than {@code getOutputCount()}
     */
    public final void setSamples(BasicMLDataSet samples) {
        if (this.getOutputMode() != PNNOutputMode.Classification) {
            this.samples = samples;
            return;
        }
        int[] counts = new int[this.getOutputCount()];
        for (MLDataPair pair : samples) {
            int cls = (int) pair.getIdeal().getData(0);
            if (cls < 0) {
                throw new NeuralNetworkError("Training data contains a negative class index (" + cls + ").");
            }
            if (cls >= counts.length) {
                throw new NeuralNetworkError("Training data contains more classes than neural network has output neurons to hold.");
            }
            ++counts[cls];
        }
        double[] newPriors = new double[counts.length];
        for (int i = 0; i < newPriors.length; ++i) {
            newPriors[i] = -1.0;
        }
        this.samples = samples;
        this.countPer = counts;
        this.priors = newPriors;
    }

    /** Intentionally empty: this implementation has nothing to recompute here. */
    @Override
    public void updateProperties() {
    }
}

