/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.example.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import ml.shifu.guagua.GuaguaRuntimeException;
import ml.shifu.guagua.example.kmeans.KMeansMasterParams;
import ml.shifu.guagua.example.kmeans.KMeansWorkerParams;
import ml.shifu.guagua.master.MasterComputable;
import ml.shifu.guagua.master.MasterContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KMeansMaster
implements MasterComputable<KMeansMasterParams, KMeansWorkerParams> {
    private static final Logger LOG = LoggerFactory.getLogger(KMeansMaster.class);

    public KMeansMasterParams compute(MasterContext<KMeansMasterParams, KMeansWorkerParams> context) {
        if (context.getWorkerResults() == null) {
            throw new NullPointerException("No worker results received in Master.");
        }
        if (context.getCurrentIteration() == 1) {
            return this.doFirstIteration(context);
        }
        return this.doOtherIterations(context);
    }

    private KMeansMasterParams doFirstIteration(MasterContext<KMeansMasterParams, KMeansWorkerParams> context) {
        ArrayList<double[]> allInitialCentriods = new ArrayList<double[]>();
        boolean initilized = false;
        int k = 0;
        int c = 0;
        for (KMeansWorkerParams workerResult : context.getWorkerResults()) {
            allInitialCentriods.addAll(workerResult.getPointList());
            if (initilized) continue;
            k = workerResult.getK();
            c = workerResult.getC();
        }
        if (allInitialCentriods.size() < k) {
            throw new GuaguaRuntimeException("Error: data size is smaller than k, please check your input and k settings.");
        }
        Collections.sort(allInitialCentriods, new Comparator<double[]>(){

            @Override
            public int compare(double[] o1, double[] o2) {
                double dist = KMeansMaster.this.distance(o1) - KMeansMaster.this.distance(o2);
                return Double.valueOf(dist).compareTo(0.0);
            }
        });
        ArrayList<double[]> initialCentriods = new ArrayList<double[]>(k);
        int step = allInitialCentriods.size() / k;
        for (int i = 0; i < k; ++i) {
            initialCentriods.add((double[])allInitialCentriods.get(i * step));
        }
        KMeansMasterParams masterResult = new KMeansMasterParams();
        masterResult.setK(k);
        masterResult.setC(c);
        masterResult.setPointList(initialCentriods);
        return masterResult;
    }

    private double distance(double[] record) {
        double sumSquare = 0.0;
        for (int i = 0; i < record.length; ++i) {
            sumSquare += record[i] * record[i];
        }
        return Math.sqrt(sumSquare);
    }

    private KMeansMasterParams doOtherIterations(MasterContext<KMeansMasterParams, KMeansWorkerParams> context) {
        LinkedList<double[]> sumAllList = new LinkedList<double[]>();
        LinkedList<Long> countAllList = new LinkedList<Long>();
        boolean initilized = false;
        int k = 0;
        int c = 0;
        for (KMeansWorkerParams workerResult : context.getWorkerResults()) {
            LOG.debug("Worker result: %s", (Object)workerResult);
            if (!initilized) {
                k = workerResult.getK();
                c = workerResult.getC();
            }
            for (int i = 0; i < k; ++i) {
                if (!initilized) {
                    sumAllList.add(new double[c]);
                    countAllList.add(0L);
                }
                long currCount = (Long)countAllList.get(i);
                countAllList.set(i, currCount + (long)workerResult.getCountList().get(i).intValue());
                double[] sumAll = (double[])sumAllList.get(i);
                for (int j = 0; j < c; ++j) {
                    int n = j;
                    sumAll[n] = sumAll[n] + workerResult.getPointList().get(i)[j];
                }
            }
            initilized = true;
        }
        LOG.debug("sumList: %s", sumAllList);
        LOG.debug("countList: %s", countAllList);
        LinkedList<double[]> meanList = new LinkedList<double[]>();
        for (int i = 0; i < k; ++i) {
            double[] means = new double[c];
            for (int j = 0; j < c; ++j) {
                means[j] = ((double[])sumAllList.get(i))[j] / (double)((Long)countAllList.get(i)).longValue();
            }
            meanList.add(means);
        }
        LOG.debug("meanList: %s", meanList);
        KMeansMasterParams masterResult = new KMeansMasterParams();
        masterResult.setK(k);
        masterResult.setC(c);
        masterResult.setPointList(meanList);
        return masterResult;
    }
}

