/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.example.nn;

import java.util.concurrent.atomic.AtomicBoolean;
import ml.shifu.guagua.example.nn.NNUtils;
import ml.shifu.guagua.example.nn.Weight;
import ml.shifu.guagua.example.nn.meta.NNParams;
import ml.shifu.guagua.master.MasterComputable;
import ml.shifu.guagua.master.MasterContext;
import ml.shifu.guagua.util.NumberFormatUtils;
import org.encog.neural.networks.BasicNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Master-side computation for the neural-network example.
 *
 * <p>The first call to {@link #compute} builds an initial network from configuration and returns
 * its flat weight vector as the master result. Every later call does the following:
 * <ul>
 *   <li>sums the gradients and train sizes reported by the worker results;</li>
 *   <li>applies one weight update through {@link Weight};</li>
 *   <li>returns the updated weights with the unweighted mean of the workers' train and
 *       validation errors.</li>
 * </ul>
 */
public class NNMaster implements MasterComputable<NNParams, NNParams> {

    private static final Logger LOG = LoggerFactory.getLogger(NNMaster.class);

    // Configuration keys and fallback values read once, on the first iteration.
    private static final String INPUT_NODES_KEY = "guagua.nn.input.nodes";
    private static final String HIDDEN_NODES_KEY = "guagua.nn.hidden.nodes";
    private static final String OUTPUT_NODES_KEY = "guagua.nn.output.nodes";
    private static final String LEARNING_RATE_KEY = "guagua.nn.learning.rate";
    private static final int DEFAULT_INPUT_NODES = 100;
    private static final int DEFAULT_HIDDEN_NODES = 2;
    private static final int DEFAULT_OUTPUT_NODES = 1;
    private static final String DEFAULT_LEARNING_RATE = "0.1";

    // Algorithm code passed verbatim to Weight.
    // NOTE(review): presumably selects quick propagation; confirm against Weight.
    private static final String PROPAGATION = "Q";

    /** Global state: the current weights plus gradients and train size summed over the workers. */
    private final NNParams globalParams = new NNParams();

    /** Set to true by the first compute call, so weight initialisation happens only once. */
    private final AtomicBoolean initialized = new AtomicBoolean(false);

    /** Weight updater, created on the first aggregation step once the gradient length is known. */
    private Weight weightUpdater;

    /** Learning rate read from configuration during initialisation. */
    private double learningRate;

    /**
     * Runs one master iteration.
     *
     * @param context master context supplying the configuration properties and the worker results
     * @return on the first call, the freshly initialised weights (errors set to 0.0). On later
     *         calls, the updated weights together with the mean train and validation error over
     *         all worker results.
     * @throws IllegalArgumentException on calls after the first, if the worker results are null or
     *         contain no elements
     */
    public NNParams compute(MasterContext<NNParams, NNParams> context) {
        if (!this.initialized.compareAndSet(false, true)) {
            return this.aggregate(context);
        }
        NNParams initial = this.initWeights(context);
        this.globalParams.setWeights(initial.getWeights());
        return initial;
    }

    /**
     * Folds all worker results into the global state, applies one weight update and builds the
     * result for this iteration.
     */
    private NNParams aggregate(MasterContext<NNParams, NNParams> context) {
        if (context.getWorkerResults() == null) {
            throw new IllegalArgumentException("workers' results are null.");
        }

        // Clear the previous iteration's sums before accumulating this iteration's values.
        this.globalParams.reset();
        double testErrorSum = 0.0;
        double trainErrorSum = 0.0;
        int workerCount = 0;
        for (NNParams workerResult : context.getWorkerResults()) {
            testErrorSum += workerResult.getTestError();
            trainErrorSum += workerResult.getTrainError();
            this.globalParams.accumulateGradients(workerResult.getGradients());
            this.globalParams.accumulateTrainSize(workerResult.getTrainSize());
            workerCount++;
        }
        if (workerCount == 0) {
            throw new IllegalArgumentException("workers' results are empty.");
        }

        double[] updatedWeights = this.updateWeights();
        this.globalParams.setWeights(updatedWeights);

        // Plain mean over worker results; it is not weighted by each worker's train size.
        double avgTestError = testErrorSum / workerCount;
        double avgTrainError = trainErrorSum / workerCount;
        LOG.info("NNMaster compute iteration {} ( avg train error {}, avg validation error {} )",
                new Object[] { context.getCurrentIteration(), avgTrainError, avgTestError });

        NNParams result = newResult(avgTrainError, avgTestError, updatedWeights);
        LOG.debug("master result {} in iteration {}", result, context.getCurrentIteration());
        return result;
    }

    /**
     * Creates the weight updater on first use, then computes new weights from the current global
     * weights and the summed gradients.
     */
    private double[] updateWeights() {
        if (this.weightUpdater == null) {
            // Gradient length and train size come from the first aggregation step only.
            this.weightUpdater = new Weight(this.globalParams.getGradients().length,
                    this.globalParams.getTrainSize(), this.learningRate, PROPAGATION);
        }
        return this.weightUpdater.calculateWeights(this.globalParams.getWeights(),
                this.globalParams.getGradients());
    }

    /**
     * Reads the network shape and learning rate from the context properties. Builds a network of
     * that shape and returns its flat weights with zero errors and an empty gradient array.
     */
    private NNParams initWeights(MasterContext<NNParams, NNParams> context) {
        int inputNodes = NumberFormatUtils.getInt(
                context.getProps().getProperty(INPUT_NODES_KEY), DEFAULT_INPUT_NODES);
        int hiddenNodes = NumberFormatUtils.getInt(
                context.getProps().getProperty(HIDDEN_NODES_KEY), DEFAULT_HIDDEN_NODES);
        int outputNodes = NumberFormatUtils.getInt(
                context.getProps().getProperty(OUTPUT_NODES_KEY), DEFAULT_OUTPUT_NODES);
        this.learningRate = NumberFormatUtils.getDouble(
                context.getProps().getProperty(LEARNING_RATE_KEY, DEFAULT_LEARNING_RATE));

        BasicNetwork network = NNUtils.generateNetwork(inputNodes, hiddenNodes, outputNodes);
        return newResult(0.0, 0.0, network.getFlat().getWeights());
    }

    /**
     * Builds a master result carrying the given errors and weights and an empty gradient array.
     */
    private static NNParams newResult(double trainError, double testError, double[] weights) {
        NNParams result = new NNParams();
        result.setTrainError(trainError);
        result.setTestError(testError);
        result.setGradients(new double[0]);
        result.setWeights(weights);
        return result;
    }
}

