/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.example.kmeans;

import com.google.common.base.Splitter;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import ml.shifu.guagua.example.kmeans.KMeansMasterParams;
import ml.shifu.guagua.example.kmeans.KMeansWorkerParams;
import ml.shifu.guagua.example.kmeans.TaggedRecord;
import ml.shifu.guagua.hadoop.io.GuaguaLineRecordReader;
import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.io.GuaguaRecordReader;
import ml.shifu.guagua.util.MemoryDiskList;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Worker side of the k-means example.
 *
 * <p>Each input line is parsed into a {@link TaggedRecord} and appended to a {@link MemoryDiskList}.
 * In iteration 1 the worker samples candidate centers from its own data; in every later iteration
 * it assigns each record to a center received from the master and reports per-center column sums
 * and record counts.
 *
 * <p>Configuration properties read in {@link #init(WorkerContext)}: {@code kmeans.k.number},
 * {@code kmeans.column.number}, {@code kmeans.data.seperator}, {@code guagua.data.memoryFraction}
 * (default {@code 0.5}) and {@code guagua.data.tmpfolder} (default {@code tmp}).
 */
public class KMeansWorker
extends AbstractWorkerComputable<KMeansMasterParams, KMeansWorkerParams, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<Text>> {
    private static final Logger LOG = LoggerFactory.getLogger(KMeansWorker.class);

    /** All records loaded by this worker. */
    private MemoryDiskList<TaggedRecord> dataList;

    /** Number of clusters, from {@code kmeans.k.number}. */
    private int k;

    /** Number of columns per record, from {@code kmeans.column.number}. */
    private int c;

    /** Field separator of an input line, from {@code kmeans.data.seperator}. */
    private String separator;

    /**
     * Installs a line-based record reader and initializes it on the given split.
     *
     * @param fileSplit the split this worker reads
     * @throws IOException if the reader cannot be initialized
     */
    public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException {
        this.setRecordReader(new GuaguaLineRecordReader());
        this.getRecordReader().initialize(fileSplit);
    }

    /**
     * Reads the configuration, creates the backing data list and registers a JVM shutdown hook
     * that closes and clears it.
     *
     * @param context worker context whose properties hold the configuration
     * @throws NumberFormatException if a numeric property is missing or malformed
     */
    public void init(WorkerContext<KMeansMasterParams, KMeansWorkerParams> context) {
        this.k = Integer.parseInt(context.getProps().getProperty("kmeans.k.number"));
        this.c = Integer.parseInt(context.getProps().getProperty("kmeans.column.number"));
        // The misspelled key is the runtime contract with existing job configurations; keep it.
        this.separator = context.getProps().getProperty("kmeans.data.seperator");
        double memoryFraction = Double.parseDouble(context.getProps().getProperty("guagua.data.memoryFraction", "0.5"));
        String tmpFolder = context.getProps().getProperty("guagua.data.tmpfolder", "tmp");
        long maxBytes = (long)((double)Runtime.getRuntime().maxMemory() * memoryFraction);
        // The timestamp suffix gives each worker start its own file name under the tmp folder.
        this.dataList = new MemoryDiskList<TaggedRecord>(maxBytes, tmpFolder + File.separator + System.currentTimeMillis());
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable(){

            @Override
            public void run() {
                KMeansWorker.this.dataList.close();
                KMeansWorker.this.dataList.clear();
            }
        }));
        context.setAttachment(this.dataList);
    }

    /**
     * Dispatches to center sampling on iteration 1 and to center assignment afterwards.
     *
     * @param context worker context for the current iteration
     * @return the result to send to the master
     */
    public KMeansWorkerParams doCompute(WorkerContext<KMeansMasterParams, KMeansWorkerParams> context) {
        if (context.getCurrentIteration() == 1) {
            return this.doFirstIteration(context);
        }
        this.dataList.reOpen();
        return this.doOtherIterations(context);
    }

    /**
     * Samples candidate centers from the local data: every record if there are at most {@code k}
     * of them, otherwise every {@code (dataSize / k)}-th record.
     *
     * @param workerContext worker context (not read here)
     * @return worker result flagged as first iteration and carrying the sampled points
     */
    private KMeansWorkerParams doFirstIteration(WorkerContext<KMeansMasterParams, KMeansWorkerParams> workerContext) {
        KMeansWorkerParams workerResult = new KMeansWorkerParams();
        workerResult.setK(this.k);
        workerResult.setC(this.c);
        workerResult.setFirstIteration(true);
        int dataSize = this.dataList.size();
        // Only about k points are sampled, so do not presize to the whole data set.
        List<double[]> pointList = new ArrayList<double[]>(Math.max(0, Math.min(dataSize, this.k)));
        if (this.k >= dataSize) {
            // NOTE(review): unlike the sampling branch below, this branch iterates without
            // calling reOpen() first - confirm that is fine for MemoryDiskList.
            for (TaggedRecord record : this.dataList) {
                pointList.add(this.toDouble(record));
            }
        } else {
            // Integer division: the stride may yield somewhat more than k points.
            int m = dataSize / this.k;
            int i = 0;
            this.dataList.reOpen();
            for (TaggedRecord record : this.dataList) {
                if (i++ % m != 0) continue;
                pointList.add(this.toDouble(record));
            }
        }
        workerResult.setPointList(pointList);
        return workerResult;
    }

    /**
     * Copies a record's values into a primitive array, mapping {@code null} entries to 0.0.
     *
     * @param record the record to convert
     * @return a new array with the same length as the record's value array
     */
    private double[] toDouble(TaggedRecord record) {
        Double[] data = record.getRecord();
        double[] newData = new double[data.length];
        for (int i = 0; i < data.length; ++i) {
            Double value = data[i];
            newData[i] = value == null ? 0.0 : value;
        }
        return newData;
    }

    /**
     * Tags every record with the index chosen by {@link #findClosedCenter} and accumulates, per
     * center, the column-wise sum and the number of records assigned to it.
     *
     * @param workerContext worker context providing the master's current centers
     * @return worker result carrying the {@code k} sum vectors and the {@code k} counts
     */
    private KMeansWorkerParams doOtherIterations(WorkerContext<KMeansMasterParams, KMeansWorkerParams> workerContext) {
        List<double[]> centers = workerContext.getLastMasterResult().getPointList();
        LOG.debug("Initial centers:{}", centers);
        // Accumulate in arrays: indexed get/set on a LinkedList would cost O(k) per record.
        double[][] sums = new double[this.k][this.c];
        int[] counts = new int[this.k];
        for (TaggedRecord record : this.dataList) {
            Double[] values = record.getRecord();
            int index = this.findClosedCenter(values, centers);
            record.setTag(index);
            counts[index]++;
            double[] sum = sums[index];
            for (int i = 0; i < this.c; ++i) {
                Double value = values[i];
                // A null (unparsable or missing) field contributes 0.
                if (value != null) {
                    sum[i] += value;
                }
            }
        }
        LinkedList<double[]> sumList = new LinkedList<double[]>();
        LinkedList<Integer> countList = new LinkedList<Integer>();
        for (int i = 0; i < this.k; ++i) {
            sumList.add(sums[i]);
            countList.add(counts[i]);
        }
        LOG.debug("sumList:{}", sumList);
        LOG.debug("countList:{}", countList);
        KMeansWorkerParams workerResult = new KMeansWorkerParams();
        workerResult.setK(this.k);
        workerResult.setC(this.c);
        workerResult.setFirstIteration(false);
        workerResult.setPointList(sumList);
        workerResult.setCountList(countList);
        return workerResult;
    }

    /**
     * Called after all records are loaded; switches the data list's state.
     * Presumably this moves it from write mode to read mode - confirm against MemoryDiskList.
     *
     * @param context worker context (not read here)
     */
    protected void postLoad(WorkerContext<KMeansMasterParams, KMeansWorkerParams> context) {
        this.dataList.switchState();
    }

    /**
     * Returns the index of the center with the smallest {@link #distance} value to the record.
     * Ties keep the lowest index; a NaN value never replaces the current best.
     *
     * @param record  the record's values; {@code null} entries count as 0
     * @param centers candidate centers, must be non-empty
     * @return index into {@code centers} of the chosen center
     */
    private int findClosedCenter(Double[] record, List<double[]> centers) {
        int index = 0;
        double minDist = this.distance(record, centers.get(0));
        for (int i = 1; i < centers.size(); ++i) {
            double distance = this.distance(record, centers.get(i));
            if (distance < minDist) {
                // Track the running minimum; otherwise every center is only compared to center 0.
                minDist = distance;
                index = i;
            }
        }
        return index;
    }

    /**
     * Computes {@code dot(record, center) / (|record| * |center|)} over the first
     * {@code center.length} components, treating {@code null} record entries as 0.
     * The result is NaN when either vector has zero norm.
     *
     * <p>NOTE(review): this is the cosine similarity (larger means more alike), yet the caller
     * selects the minimum. Confirm whether {@code 1 - similarity} or a Euclidean distance was
     * intended; the metric is left unchanged here to keep clustering results stable.
     *
     * @param record the record's values
     * @param center the center's coordinates
     * @return the value described above
     */
    private double distance(Double[] record, double[] center) {
        double dotProduct = 0.0;
        double sqW1 = 0.0;
        double sqW2 = 0.0;
        for (int i = 0; i < center.length; ++i) {
            double value = record[i] == null ? 0.0 : record[i];
            dotProduct += value * center[i];
            sqW1 += value * value;
            sqW2 += center[i] * center[i];
        }
        return dotProduct / (Math.sqrt(sqW1) * Math.sqrt(sqW2));
    }

    /**
     * Parses one input line into a record of {@code c} values and appends it to the data list.
     * Fields that are not valid numbers are stored as {@code null}; missing trailing fields stay
     * {@code null}; fields beyond the {@code c}-th are ignored.
     *
     * @param currentKey    key of the current line (not read here)
     * @param currentValue  the text of the current line
     * @param workerContext worker context (not read here)
     */
    public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<KMeansMasterParams, KMeansWorkerParams> workerContext) {
        String line = currentValue.getWritable().toString();
        Double[] record = new Double[this.c];
        int i = 0;
        for (String input : Splitter.on(this.separator).split(line)) {
            if (i >= record.length) {
                // More fields than configured columns: ignore the surplus instead of overflowing.
                break;
            }
            try {
                record[i] = Double.parseDouble(input);
            }
            catch (NumberFormatException e) {
                // Unparsable field: keep it as null, which downstream code treats as 0.
                record[i] = null;
            }
            ++i;
        }
        this.dataList.append(new TaggedRecord(record));
    }
}

