/*
 * Decompiled with CFR 0.152.
 */
package org.encog.mathutil.randomize;

import org.encog.engine.network.activation.ActivationReLU;
import org.encog.mathutil.matrices.Matrix;
import org.encog.mathutil.randomize.Randomizer;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.MLMethod;
import org.encog.neural.networks.BasicNetwork;

/**
 * Randomizer backed by a {@link GenerateRandom} source.
 *
 * <p>Scalars, arrays and matrices are filled with values taken directly from
 * {@code nextDouble()} of the generator. Weights of a {@link BasicNetwork} are
 * instead drawn from a symmetric range whose half-width depends on the neuron
 * counts of the two connected layers (see
 * {@link #randomize(BasicNetwork, int)}).
 */
public class XaiverRandomizer implements Randomizer {
    /** Source of random numbers; replaceable through {@link #setRandom(GenerateRandom)}. */
    private GenerateRandom rnd;

    /**
     * Creates a randomizer seeded with the current wall-clock time in
     * milliseconds.
     */
    public XaiverRandomizer() {
        this(System.currentTimeMillis());
    }

    /**
     * Creates a randomizer using a {@link MersenneTwisterGenerateRandom}
     * initialized with the given seed.
     *
     * @param seed seed passed to the generator
     */
    public XaiverRandomizer(long seed) {
        this.rnd = new MersenneTwisterGenerateRandom(seed);
    }

    /**
     * Produces a new random value from the generator.
     *
     * @param d ignored; the result does not depend on it
     * @return the next double from the generator
     */
    @Override
    public double randomize(double d) {
        return this.rnd.nextDouble();
    }

    /**
     * Randomizes the weights between {@code fromLayer} and the layer after it.
     *
     * <p>Each regular weight is drawn from {@code [-limit, limit)} as returned
     * by {@code nextDouble(-limit, limit)}, where {@code limit} is
     * {@code 2 / sqrt(fromCount)} when the activation of {@code fromLayer} is
     * an {@link ActivationReLU} and {@code 2 / sqrt(fromCount + toCount)}
     * otherwise. Weights leaving the bias slot are set to zero.
     *
     * @param network   network whose weights are modified in place
     * @param fromLayer index of the source layer; {@code fromLayer + 1} must
     *                  also be a valid layer
     */
    public void randomize(BasicNetwork network, int fromLayer) {
        int fromCount = network.getLayerNeuronCount(fromLayer);
        int toCount = network.getLayerNeuronCount(fromLayer + 1);

        // Source index fromCount is one past the regular neurons, i.e. the bias
        // slot. Only touch it when the layer really has an extra neuron;
        // otherwise the write would land on an unrelated weight.
        // Done once here: it does not depend on the source neuron.
        if (network.getLayerTotalNeuronCount(fromLayer) != fromCount) {
            for (int toNeuron = 0; toNeuron < toCount; ++toNeuron) {
                network.setWeight(fromLayer, fromCount, toNeuron, 0.0);
            }
        }

        // The range half-width is the same for every weight of this layer pair.
        double limit = network.getActivation(fromLayer) instanceof ActivationReLU
                ? 2.0 / Math.sqrt(fromCount)
                : 2.0 / Math.sqrt(fromCount + toCount);

        for (int fromNeuron = 0; fromNeuron < fromCount; ++fromNeuron) {
            for (int toNeuron = 0; toNeuron < toCount; ++toNeuron) {
                double w = this.rnd.nextDouble(-limit, limit);
                network.setWeight(fromLayer, fromNeuron, toNeuron, w);
            }
        }
    }

    /**
     * Randomizes every pair of adjacent layers of the given network.
     *
     * @param method the method to randomize; must be a {@link BasicNetwork}
     * @throws ClassCastException if {@code method} is not a {@link BasicNetwork}
     */
    @Override
    public void randomize(MLMethod method) {
        BasicNetwork network = (BasicNetwork) method;
        int lastSourceLayer = network.getLayerCount() - 1;
        for (int i = 0; i < lastSourceLayer; ++i) {
            this.randomize(network, i);
        }
    }

    /**
     * Overwrites every element of the array with a new random value.
     *
     * @param d array to fill in place
     */
    @Override
    public void randomize(double[] d) {
        this.randomize(d, 0, d.length);
    }

    /**
     * Overwrites every element of a two-dimensional array with a new random
     * value. Rows may have different lengths.
     *
     * @param d array of rows to fill in place
     */
    @Override
    public void randomize(double[][] d) {
        for (double[] row : d) {
            // Bound each row by its own length so non-square and jagged
            // arrays are handled correctly.
            for (int j = 0; j < row.length; ++j) {
                row[j] = this.rnd.nextDouble();
            }
        }
    }

    /**
     * Overwrites every element of the matrix data returned by
     * {@code m.getData()} with a new random value.
     *
     * @param m matrix to fill
     */
    @Override
    public void randomize(Matrix m) {
        this.randomize(m.getData());
    }

    /**
     * Overwrites {@code size} consecutive elements of the array, starting at
     * {@code begin}, with new random values.
     *
     * @param d     array to modify in place
     * @param begin index of the first element to overwrite
     * @param size  number of elements to overwrite
     */
    @Override
    public void randomize(double[] d, int begin, int size) {
        for (int i = 0; i < size; ++i) {
            d[begin + i] = this.rnd.nextDouble();
        }
    }

    /**
     * Replaces the random number generator used by this randomizer.
     *
     * @param theRandom the generator to use from now on
     */
    @Override
    public void setRandom(GenerateRandom theRandom) {
        this.rnd = theRandom;
    }

    /**
     * @return the random number generator currently in use
     */
    @Override
    public GenerateRandom getRandom() {
        return this.rnd;
    }
}

