/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.example.lr;

import com.google.common.base.Splitter;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import ml.shifu.guagua.example.lr.LogisticRegressionParams;
import ml.shifu.guagua.hadoop.io.GuaguaLineRecordReader;
import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.io.GuaguaRecordReader;
import ml.shifu.guagua.util.MemoryDiskList;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LogisticRegressionWorker
extends AbstractWorkerComputable<LogisticRegressionParams, LogisticRegressionParams, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<Text>> {
    private static final Logger LOG = LoggerFactory.getLogger(LogisticRegressionWorker.class);
    private int inputNum;
    private int outputNum;
    private MemoryDiskList<Data> dataList;
    private double[] weights;
    private Splitter splitter = Splitter.on((String)",");

    public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException {
        this.setRecordReader((GuaguaRecordReader)new GuaguaLineRecordReader(fileSplit));
    }

    public void init(WorkerContext<LogisticRegressionParams, LogisticRegressionParams> context) {
        this.inputNum = NumberFormatUtils.getInt((String)"lr.input.num", (int)2);
        this.outputNum = 1;
        double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.5"));
        String tmpFolder = context.getProps().getProperty("guagua.data.tmpfolder", System.getProperty("user.dir"));
        this.dataList = new MemoryDiskList((long)((double)Runtime.getRuntime().maxMemory() * memoryFraction), tmpFolder + File.separator + System.currentTimeMillis());
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable(){

            @Override
            public void run() {
                LogisticRegressionWorker.this.dataList.close();
                LogisticRegressionWorker.this.dataList.clear();
            }
        }));
    }

    public LogisticRegressionParams doCompute(WorkerContext<LogisticRegressionParams, LogisticRegressionParams> context) {
        if (context.isFirstIteration()) {
            return new LogisticRegressionParams();
        }
        this.weights = ((LogisticRegressionParams)context.getLastMasterResult()).getParameters();
        double[] gradients = new double[this.inputNum + 1];
        double finalError = 0.0;
        int size = 0;
        this.dataList.reOpen();
        for (Data data : this.dataList) {
            double error = this.sigmoid(data.inputs, this.weights) - data.outputs[0];
            finalError += error * error / 2.0;
            for (int i = 0; i < gradients.length; ++i) {
                int n = i;
                gradients[n] = gradients[n] + error * data.inputs[i];
            }
            ++size;
        }
        LOG.info("Iteration {} with error {}", (Object)context.getCurrentIteration(), (Object)(finalError / (double)size));
        return new LogisticRegressionParams(gradients, finalError / (double)size);
    }

    private double sigmoid(double[] inputs, double[] weights) {
        double value = 0.0;
        for (int i = 0; i < weights.length; ++i) {
            value += weights[i] * inputs[i];
        }
        return 1.0 / (1.0 + Math.exp(-value));
    }

    protected void postLoad(WorkerContext<LogisticRegressionParams, LogisticRegressionParams> context) {
        this.dataList.switchState();
    }

    public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<LogisticRegressionParams, LogisticRegressionParams> context) {
        String line = ((Text)currentValue.getWritable()).toString();
        double[] inputData = new double[this.inputNum + 1];
        double[] outputData = new double[this.outputNum];
        int count = 0;
        int inputIndex = 0;
        int outputIndex = 0;
        inputData[inputIndex++] = 1.0;
        for (String unit : this.splitter.split((CharSequence)line)) {
            if (count < this.inputNum) {
                inputData[inputIndex++] = Double.valueOf(unit);
            } else {
                if (count < this.inputNum || count >= this.inputNum + this.outputNum) break;
                outputData[outputIndex++] = Double.valueOf(unit);
            }
            ++count;
        }
        this.dataList.append((Serializable)new Data(inputData, outputData));
    }

    private static class Data
    implements Serializable {
        private static final long serialVersionUID = 903201066309036170L;
        private final double[] inputs;
        private final double[] outputs;

        public Data(double[] inputs, double[] outputs) {
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }
}

