/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.yarn.example.nn;

import com.google.common.base.Splitter;
import java.io.IOException;
import ml.shifu.guagua.GuaguaRuntimeException;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.io.GuaguaRecordReader;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;
import ml.shifu.guagua.yarn.GuaguaLineRecordReader;
import ml.shifu.guagua.yarn.GuaguaWritableAdapter;
import ml.shifu.guagua.yarn.example.nn.Gradient;
import ml.shifu.guagua.yarn.example.nn.NNUtils;
import ml.shifu.guagua.yarn.example.nn.meta.NNParams;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Guagua worker for the neural-network example.
 *
 * <p>During loading, each pipe-delimited text record is parsed into an Encog data pair and
 * randomly assigned to an in-memory training or testing set. During computation, the worker
 * applies the weights received from the master, runs {@link Gradient} over the training set and
 * returns the resulting gradients together with the training and testing errors.
 */
public class NNWorker
extends AbstractWorkerComputable<NNParams, NNParams, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<Text>> {
    private static final Logger LOG = LoggerFactory.getLogger(NNWorker.class);

    /** Property holding the number of input nodes (also the number of feature columns per record). */
    private static final String INPUT_NODES_KEY = "guagua.nn.input.nodes";
    /** Property holding the number of hidden nodes. */
    private static final String HIDDEN_NODES_KEY = "guagua.nn.hidden.nodes";
    /** Property holding the number of output nodes. */
    private static final String OUTPUT_NODES_KEY = "guagua.nn.output.nodes";
    /** Property holding how many times each input record is replicated while loading. */
    private static final String RECORD_SCALE_KEY = "nn.record.scale";

    /** Column splitter, built once instead of once per record. */
    private static final Splitter COLUMN_SPLITTER = Splitter.on("|");

    /** Records used to compute gradients; created in {@link #init}. */
    private MLDataSet trainingData = null;
    /** Records used to compute the testing (validation) error; created in {@link #init}. */
    private MLDataSet testingData = null;
    /** Lazily created on the first iteration that carries master weights. */
    private Gradient gradient;
    /** Number of raw records passed to {@link #load}, before replication. */
    private long count;
    private int inputs;
    private int hiddens;
    private int outputs;

    /** Creates empty in-memory training and testing data sets. */
    private void initMemoryDataSet() {
        this.trainingData = new BasicMLDataSet();
        this.testingData = new BasicMLDataSet();
    }

    /**
     * Reads the network topology from the worker properties (defaults: 100 input, 2 hidden and
     * 1 output node) and prepares the empty in-memory data sets that {@link #load} fills.
     *
     * @param workerContext context supplying the configuration properties
     */
    public void init(WorkerContext<NNParams, NNParams> workerContext) {
        this.inputs = NumberFormatUtils.getInt(workerContext.getProps().getProperty(INPUT_NODES_KEY), 100);
        this.hiddens = NumberFormatUtils.getInt(workerContext.getProps().getProperty(HIDDEN_NODES_KEY), 2);
        this.outputs = NumberFormatUtils.getInt(workerContext.getProps().getProperty(OUTPUT_NODES_KEY), 1);
        LOG.info("NNWorker is loading data into memory.");
        this.initMemoryDataSet();
    }

    /**
     * Runs one worker iteration.
     *
     * <p>On iteration 1 an empty result (no weights, no gradients, zero errors) is returned.
     * On later iterations the weights from the last master result are applied, gradients are
     * computed over the training set and the testing error is evaluated (0.0 when the testing
     * set is empty).
     *
     * @param workerContext context supplying the current iteration and the last master result
     * @return the gradients, errors and training-set size for this iteration, or {@code null}
     *         when the last master result is missing on an iteration other than the first
     */
    public NNParams doCompute(WorkerContext<NNParams, NNParams> workerContext) {
        if (workerContext.getCurrentIteration() == 1) {
            return this.buildEmptyNNParams(workerContext);
        }
        NNParams lastMasterResult = workerContext.getLastMasterResult();
        if (lastMasterResult == null) {
            LOG.warn("Master result of last iteration is null.");
            return null;
        }
        LOG.debug("Set current model with params {}", lastMasterResult);

        double[] masterWeights = lastMasterResult.getWeights();
        if (this.gradient == null) {
            this.initGradient(this.trainingData, masterWeights);
        }
        this.gradient.setWeights(masterWeights);
        this.gradient.run();

        double trainError = this.gradient.getError();
        double testError = this.testingData.getRecordCount() > 0L
                ? this.gradient.getNetwork().calculateError(this.testingData)
                : 0.0;
        LOG.info("NNWorker compute iteration {} (train error {} validation error {})",
                new Object[]{workerContext.getCurrentIteration(), trainError, testError});

        NNParams params = new NNParams();
        params.setTestError(testError);
        params.setTrainError(trainError);
        params.setGradients(this.gradient.getGradients());
        // Only gradients are reported back; the weights array is left empty.
        params.setWeights(new double[0]);
        params.setTrainSize(this.trainingData.getRecordCount());
        return params;
    }

    /**
     * Builds the network from the configured topology, seeds it with the given weights and
     * creates the {@link Gradient} that is reused on subsequent iterations.
     *
     * @param training data set the gradient is computed over
     * @param weights initial flat-network weights
     */
    private void initGradient(MLDataSet training, double[] weights) {
        BasicNetwork network = NNUtils.generateNetwork(this.inputs, this.hiddens, this.outputs);
        FlatNetwork flat = network.getFlat();
        flat.setWeights(weights);

        // One flat-spot value per activation function: 0.1 for sigmoid, 0.0 otherwise.
        int activationCount = flat.getActivationFunctions().length;
        double[] flatSpot = new double[activationCount];
        for (int i = 0; i < activationCount; ++i) {
            flatSpot[i] = flat.getActivationFunctions()[i] instanceof ActivationSigmoid ? 0.1 : 0.0;
        }
        this.gradient = new Gradient(flat, training.openAdditional(), flatSpot, new LinearErrorFunction());
    }

    /**
     * Builds the result returned on the first iteration: empty weights and gradients and zero
     * errors.
     *
     * @param workerContext unused
     * @return an empty parameter object
     */
    private NNParams buildEmptyNNParams(WorkerContext<NNParams, NNParams> workerContext) {
        NNParams params = new NNParams();
        params.setWeights(new double[0]);
        params.setGradients(new double[0]);
        params.setTestError(0.0);
        params.setTrainError(0.0);
        return params;
    }

    /**
     * Logs the number of raw records read and the sizes of the training and testing sets.
     *
     * @param workerContext unused
     */
    protected void postLoad(WorkerContext<NNParams, NNParams> workerContext) {
        LOG.info("- # Records of the whole data set: {}.", this.count);
        LOG.info("- # Records of the training data set: {}.", this.trainingData.getRecordCount());
        LOG.info("- # Records of the testing data set: {}.", this.testingData.getRecordCount());
    }

    /**
     * Parses one pipe-delimited record and adds it to the in-memory data sets.
     *
     * <p>The first column is the ideal (target) value and the following
     * {@code guagua.nn.input.nodes} columns are the inputs; any further columns are ignored and
     * values that cannot be parsed fall back to 0.0. The record is added
     * {@code nn.record.scale} times (default 1), and each copy is independently assigned to the
     * training set or the testing set with equal probability.
     *
     * @param currentKey key of the record (unused)
     * @param currentValue text line holding the record
     * @param workerContext context supplying the configuration properties
     * @throws GuaguaRuntimeException if the record has fewer than one target column plus
     *         {@code guagua.nn.input.nodes} input columns
     */
    public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<NNParams, NNParams> workerContext) {
        ++this.count;
        if (this.count % 100000L == 0L) {
            LOG.info("Read {} records.", this.count);
        }

        int inputNodes = NumberFormatUtils.getInt(workerContext.getProps().getProperty(INPUT_NODES_KEY), 100);
        double[] ideal = new double[1];
        double[] features = new double[inputNodes];

        // Number of columns consumed so far: column 0 is the target, columns 1..inputNodes the features.
        int columns = 0;
        String line = currentValue.getWritable().toString();
        for (String column : COLUMN_SPLITTER.split(line)) {
            if (columns == 0) {
                ideal[0] = NumberFormatUtils.getDouble(column, 0.0);
            } else {
                int featureIndex = columns - 1;
                if (featureIndex >= inputNodes) {
                    // Extra trailing columns are ignored.
                    break;
                }
                features[featureIndex] = NumberFormatUtils.getDouble(column, 0.0);
            }
            ++columns;
        }
        if (columns < inputNodes + 1) {
            throw new GuaguaRuntimeException(String.format("Not enough data columns, input nodes setting:%s, data column:%s", inputNodes, columns));
        }

        int scale = NumberFormatUtils.getInt(workerContext.getProps().getProperty(RECORD_SCALE_KEY), 1);
        for (int j = 0; j < scale; ++j) {
            // The first copy reuses the parsed arrays; replicas get their own copies of BOTH
            // arrays so that they keep the parsed target value instead of a zeroed one.
            double[] tmpInputs = j == 0 ? features : features.clone();
            double[] tmpIdeal = j == 0 ? ideal : ideal.clone();
            MLDataPair pair = new BasicMLDataPair(new BasicMLData(tmpInputs), new BasicMLData(tmpIdeal));
            if (Math.random() >= 0.5) {
                this.trainingData.add(pair);
            } else {
                this.testingData.add(pair);
            }
        }
    }

    /**
     * Installs a line-based record reader and initializes it on the given split.
     *
     * @param fileSplit the file split this worker reads
     * @throws IOException if the reader cannot be initialized on the split
     */
    public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException {
        this.setRecordReader(new GuaguaLineRecordReader());
        this.getRecordReader().initialize(fileSplit);
    }

    /** @return the in-memory training data set */
    public MLDataSet getTrainingData() {
        return this.trainingData;
    }

    /** @param trainingData the training data set to use */
    public void setTrainingData(MLDataSet trainingData) {
        this.trainingData = trainingData;
    }

    /** @return the in-memory testing data set */
    public MLDataSet getTestingData() {
        return this.testingData;
    }

    /** @param testingData the testing data set to use */
    public void setTestingData(MLDataSet testingData) {
        this.testingData = testingData;
    }
}

