/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.mapreduce.example.lnr;

import com.google.common.base.Splitter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.io.GuaguaRecordReader;
import ml.shifu.guagua.mapreduce.GuaguaLineRecordReader;
import ml.shifu.guagua.mapreduce.GuaguaWritableAdapter;
import ml.shifu.guagua.mapreduce.example.lnr.LinearRegressionParams;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Guagua worker for the linear-regression example.
 *
 * <p>Each record of the worker's split is a comma-separated line holding {@code inputNum} feature
 * columns followed by one target column. Records are loaded into memory once. On every iteration
 * after the first, the worker evaluates the weights received from the master against all loaded
 * records. It returns the summed gradient of the squared error together with the mean of
 * {@code error^2 / 2} per record.
 */
public class LinearRegressionWorker
extends AbstractWorkerComputable<LinearRegressionParams, LinearRegressionParams, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<Text>> {
    private static final Logger LOG = LoggerFactory.getLogger(LinearRegressionWorker.class);

    /** Property key holding the number of feature columns per record. */
    private static final String LR_INPUT_NUM = "lr.input.num";

    /** Feature count used when {@link #LR_INPUT_NUM} is absent or cannot be parsed. */
    private static final int LR_INPUT_DEFAULT_NUM = 2;

    /** Number of feature columns read from each record; the intercept slot is not counted. */
    private int inputNum;

    /** Number of target columns read from each record; this example always uses one. */
    private int outputNum;

    /** Records loaded from this worker's split; re-scanned on every iteration. */
    private List<Data> dataList;

    private final Splitter splitter = Splitter.on(',');

    /**
     * Installs a line-oriented record reader over the given split.
     *
     * @param fileSplit the split assigned to this worker
     * @throws IOException if the reader cannot be created
     */
    @Override
    public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException {
        this.setRecordReader(new GuaguaLineRecordReader(fileSplit));
    }

    /**
     * Reads the feature count from the context properties and prepares the in-memory record list.
     *
     * @param context worker context whose properties may define {@code lr.input.num}
     * @throws IllegalArgumentException if the configured feature count is negative
     */
    @Override
    public void init(WorkerContext<LinearRegressionParams, LinearRegressionParams> context) {
        // Look the key up in the properties. Parsing the key name itself can never succeed,
        // which would silently pin the feature count to the default.
        String configured = String.valueOf(LR_INPUT_DEFAULT_NUM);
        if (context.getProps() != null) {
            configured = context.getProps().getProperty(LR_INPUT_NUM, configured);
        }
        this.inputNum = NumberFormatUtils.getInt(configured, LR_INPUT_DEFAULT_NUM);
        if (this.inputNum < 0) {
            throw new IllegalArgumentException(LR_INPUT_NUM + " must be non-negative, got " + this.inputNum);
        }
        this.outputNum = 1;
        this.dataList = new ArrayList<Data>();
    }

    /**
     * Computes the summed gradient and mean half-squared error for the master's latest weights.
     *
     * <p>On the first iteration no computation is done and empty params are returned.
     *
     * @param context worker context providing the master's last result and the iteration number
     * @return gradients (one per weight, intercept first) and the mean error over the loaded records;
     *     the error is 0.0 when this worker has no records
     */
    @Override
    public LinearRegressionParams doCompute(WorkerContext<LinearRegressionParams, LinearRegressionParams> context) {
        if (context.isFirstIteration()) {
            return new LinearRegressionParams();
        }
        double[] weights = context.getLastMasterResult().getParameters();
        double[] gradients = new double[this.inputNum + 1];
        double sumHalfSquaredError = 0.0;
        for (Data data : this.dataList) {
            double error = dot(data.inputs, weights) - data.outputs[0];
            sumHalfSquaredError += error * error / 2.0;
            for (int i = 0; i < gradients.length; ++i) {
                gradients[i] += error * data.inputs[i];
            }
        }
        int size = this.dataList.size();
        double meanError;
        if (size == 0) {
            // Without this guard the result would be 0.0 / 0 = NaN, which propagates through any sum.
            LOG.warn("Iteration {} has no records on this worker; reporting zero error", context.getCurrentIteration());
            meanError = 0.0;
        } else {
            meanError = sumHalfSquaredError / size;
        }
        LOG.info("Iteration {} with error {}", context.getCurrentIteration(), meanError);
        return new LinearRegressionParams(gradients, meanError);
    }

    /**
     * Returns the dot product over the first {@code weights.length} entries.
     *
     * @param inputs feature vector, expected to be at least as long as {@code weights}
     * @param weights model weights
     * @return the sum over i of {@code weights[i] * inputs[i]}
     */
    private static double dot(double[] inputs, double[] weights) {
        double value = 0.0;
        for (int i = 0; i < weights.length; ++i) {
            value += weights[i] * inputs[i];
        }
        return value;
    }

    /**
     * Parses one comma-separated line into a record and keeps it in memory.
     *
     * <p>The first {@code inputNum} columns are features and the next {@code outputNum} columns are
     * targets. Any further columns are ignored. Lines with too few columns (including blank lines)
     * are skipped with a warning. A non-numeric value still throws {@link NumberFormatException}.
     *
     * @param currentKey record key supplied by the reader; unused
     * @param currentValue the text line to parse
     * @param context worker context; unused
     */
    @Override
    public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<LinearRegressionParams, LinearRegressionParams> context) {
        String line = currentValue.getWritable().toString();
        int expected = this.inputNum + this.outputNum;
        String[] fields = new String[expected];
        int found = 0;
        for (String unit : this.splitter.split(line)) {
            if (found == expected) {
                break; // extra trailing columns are ignored
            }
            fields[found++] = unit;
        }
        if (found < expected) {
            LOG.warn("Skipping malformed record '{}': expected at least {} comma-separated fields", line, expected);
            return;
        }

        double[] inputData = new double[this.inputNum + 1];
        // Constant 1.0 in slot 0 so that weights[0] acts as the intercept in dot().
        inputData[0] = 1.0;
        for (int i = 0; i < this.inputNum; ++i) {
            inputData[i + 1] = Double.parseDouble(fields[i]);
        }
        double[] outputData = new double[this.outputNum];
        for (int j = 0; j < this.outputNum; ++j) {
            outputData[j] = Double.parseDouble(fields[this.inputNum + j]);
        }
        this.dataList.add(new Data(inputData, outputData));
    }

    /** One training record: features (intercept slot first) and target values. */
    private static final class Data {
        private final double[] inputs;
        private final double[] outputs;

        Data(double[] inputs, double[] outputs) {
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }
}

