/*
 * Decompiled with CFR 0.152.
 */
package weka.clusterers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import weka.clusterers.NumberOfClustersRequestable;
import weka.clusterers.RandomizableDensityBasedClusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.estimators.DiscreteEstimator;
import weka.estimators.Estimator;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

public class EM
extends RandomizableDensityBasedClusterer
implements NumberOfClustersRequestable,
WeightedInstancesHandler {
    static final long serialVersionUID = 8348181483812829475L;
    // Previous-iteration copy of m_model (nominal estimators), saved by new_estimators().
    private Estimator[][] m_modelPrev;
    // Previous-iteration copy of m_modelNormal, saved by new_estimators().
    private double[][][] m_modelNormalPrev;
    // Copy of m_priors taken at the start of estimate_priors().
    private double[] m_priorsPrev;
    // Per cluster, per attribute: DiscreteEstimator for nominal attributes (numeric slots stay unused).
    private Estimator[][] m_model;
    // Per cluster, per numeric attribute: {mean, std. dev., weight sum}. While M() accumulates, the
    // slots temporarily hold {weighted sum, weighted sum of squares, total weight}.
    private double[][][] m_modelNormal;
    // Global lower bound for numeric standard deviations (used when m_minStdDevPerAtt is null).
    private double m_minStdDev = 1.0E-6;
    // Optional per-attribute lower bounds for standard deviations; overrides m_minStdDev when set.
    private double[] m_minStdDevPerAtt;
    // Per instance, per cluster membership weights (replaced by distributionForInstance in E()).
    private double[][] m_weights;
    // Normalised cluster prior probabilities.
    private double[] m_priors;
    // Training data; its header supplies attribute names/types for the model and output.
    private Instances m_theInstances = null;
    // Current number of clusters; -1 until a number has been chosen.
    private int m_num_clusters;
    // Number of clusters requested by the user; -1 means "select by cross-validation".
    private int m_initialNumClusters;
    // Largest cluster count tried during cross-validation; values <= 0 mean no upper bound.
    private int m_upperBoundNumClustersCV = -1;
    // Number of attributes in the model.
    private int m_num_attribs;
    // NOTE(review): not referenced in the visible portion of this file.
    private int m_num_instances;
    // Maximum number of E/M iterations.
    private int m_max_iterations;
    // NOTE(review): m_minValues/m_maxValues are not referenced in the visible portion of this file.
    private double[] m_minValues;
    private double[] m_maxValues;
    // Random source used to seed the k-means runs in EM_Init (reseeded per CV fold).
    private Random m_rr;
    // Verbose/debug output flag.
    private boolean m_verbose;
    // NOTE(review): presumably the missing-value filter used when building; not referenced in the visible code.
    private ReplaceMissingValues m_replaceMissing;
    // Whether toString() delegates to toStringOriginal().
    private boolean m_displayModelInOldFormat;
    // Number of threads; also passed to SimpleKMeans and used to size m_executorPool.
    protected int m_executionSlots = 1;
    // Thread pool created by startExecutorPool(); transient, so not serialized.
    protected transient ExecutorService m_executorPool;
    // NOTE(review): not referenced in the visible portion of this file.
    protected boolean m_training;
    // Number of E/M iterations reported by toString().
    protected int m_iterationsPerformed;
    // Minimum log-likelihood gain required to run another E/M iteration.
    protected double m_minLogLikelihoodImprovementIterating = 1.0E-6;
    // Minimum cross-validated log-likelihood gain required to try one more cluster.
    protected double m_minLogLikelihoodImprovementCV = 1.0E-6;
    // Requested number of cross-validation folds (reduced to the instance count if fewer instances).
    protected int m_cvFolds = 10;
    // Number of k-means runs used to initialise EM; the lowest squared error wins.
    protected int m_NumKMeansRuns = 10;
    // log(sqrt(2*pi)): normalising constant of the Gaussian log-density in logNormalDens().
    private static double m_normConst = Math.log(Math.sqrt(Math.PI * 2));

    /**
     * Returns a description of the clusterer for display in the GUI.
     *
     * The previous text claimed the number of cross-validation folds was "fixed to 10", but the
     * fold count is configurable (numFolds / -X) and is only capped by the number of instances.
     *
     * @return a description of this clusterer
     */
    public String globalInfo() {
        return "Simple EM (expectation maximisation) class.\n\n"
            + "EM assigns a probability distribution to each instance which indicates the probability of it belonging to each of the clusters. "
            + "EM can decide how many clusters to create by cross validation, or you may specify apriori how many clusters to generate.\n\n"
            + "The cross validation performed to determine the number of clusters is done in the following steps:\n"
            + "1. the number of clusters is set to 1\n"
            + "2. the training set is split randomly into folds (10 by default, see numFolds).\n"
            + "3. EM is performed once per fold, the usual CV way.\n"
            + "4. the loglikelihood is averaged over all folds.\n"
            + "5. if the loglikelihood has increased by more than minLogLikelihoodImprovementCV, the number of clusters is increased by 1 "
            + "and the program continues at step 2 (up to maximumNumberOfClusters, if set).\n\n"
            + "The number of folds defaults to 10. If the training set contains fewer instances than the number of folds, "
            + "the number of folds is set equal to the number of instances.\n\n"
            + "Missing values are globally replaced with ReplaceMissingValues.";
    }

    /**
     * Lists the command-line options accepted by this clusterer, followed by those of the superclass.
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.add(new Option("\tnumber of clusters. If omitted or -1 specified, then \n\tcross validation is used to select the number of clusters.", "N", 1, "-N <num>"));
        options.add(new Option("\tNumber of folds to use when cross-validating to find the best number of clusters.", "X", 1, "-X <num>"));
        options.add(new Option("\tNumber of runs of k-means to perform.\n\t(default 10)", "K", 1, "-K <num>"));
        options.add(new Option("\tMaximum number of clusters to consider during cross-validation. If omitted or -1 specified, then \n\tthere is no upper limit on the number of clusters.", "max", 1, "-max <num>"));
        options.add(new Option("\tMinimum improvement in cross-validated log likelihood required\n\tto consider increasing the number of clusters.\n\t(default 1e-6)", "ll-cv", 1, "-ll-cv <num>"));
        options.add(new Option("\tmax iterations.\n\t(default 100)", "I", 1, "-I <num>"));
        options.add(new Option("\tMinimum improvement in log likelihood required\n\tto perform another iteration of the E and M steps.\n\t(default 1e-6)", "ll-iter", 1, "-ll-iter <num>"));
        options.add(new Option("\tverbose.", "V", 0, "-V"));
        options.add(new Option("\tminimum allowable standard deviation for normal density\n\tcomputation\n\t(default 1e-6)", "M", 1, "-M <num>"));
        options.add(new Option("\tDisplay model in old format (good when there are many clusters)\n", "O", 0, "-O"));
        options.add(new Option("\tNumber of execution slots.\n\t(default 1 - i.e. no parallelism)", "num-slots", 1, "-num-slots <num>"));
        // Append the inherited options after our own, preserving their order.
        Enumeration<Option> inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            options.add(inherited.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the given command-line options. Defaults are restored first via resetOptions(), so
     * any option not present in {@code options} takes its default value.
     *
     * @param options the options to parse; recognised entries are removed from the array
     * @throws Exception if an option value is malformed or out of range, or unknown options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.resetOptions();
        this.setDebug(Utils.getFlag('V', options));
        String optionString = Utils.getOption('I', options);
        if (optionString.length() != 0) {
            this.setMaxIterations(Integer.parseInt(optionString));
        }
        if ((optionString = Utils.getOption('X', options)).length() > 0) {
            this.setNumFolds(Integer.parseInt(optionString));
        }
        if ((optionString = Utils.getOption("ll-iter", options)).length() > 0) {
            this.setMinLogLikelihoodImprovementIterating(Double.parseDouble(optionString));
        }
        if ((optionString = Utils.getOption("ll-cv", options)).length() > 0) {
            this.setMinLogLikelihoodImprovementCV(Double.parseDouble(optionString));
        }
        if ((optionString = Utils.getOption('N', options)).length() != 0) {
            this.setNumClusters(Integer.parseInt(optionString));
        }
        if ((optionString = Utils.getOption("max", options)).length() > 0) {
            this.setMaximumNumberOfClusters(Integer.parseInt(optionString));
        }
        if ((optionString = Utils.getOption('M', options)).length() != 0) {
            // parseDouble has the same semantics as the deprecated Double(String) constructor.
            this.setMinStdDev(Double.parseDouble(optionString));
        }
        if ((optionString = Utils.getOption('K', options)).length() != 0) {
            // parseInt has the same semantics as the deprecated Integer(String) constructor.
            this.setNumKMeansRuns(Integer.parseInt(optionString));
        }
        this.setDisplayModelInOldFormat(Utils.getFlag('O', options));
        String slotsS = Utils.getOption("num-slots", options);
        if (slotsS.length() > 0) {
            this.setNumExecutionSlots(Integer.parseInt(slotsS));
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Tip text for the numKMeansRuns property.
     *
     * @return tip text suitable for display in a GUI
     */
    public String numKMeansRunsTipText() {
        final String tip = "The number of runs of k-means to perform.";
        return tip;
    }

    /**
     * @return how many k-means runs are used to initialise EM
     */
    public int getNumKMeansRuns() {
        return m_NumKMeansRuns;
    }

    /**
     * Sets how many k-means runs are used to initialise EM; the run with the lowest squared
     * error is kept.
     *
     * @param intValue the number of k-means runs
     */
    public void setNumKMeansRuns(int intValue) {
        m_NumKMeansRuns = intValue;
    }

    /**
     * Tip text for the numFolds property.
     *
     * @return tip text suitable for display in a GUI
     */
    public String numFoldsTipText() {
        final String tip = "The number of folds to use when cross-validating to find the best number of clusters (default = 10)";
        return tip;
    }

    /**
     * Sets the number of cross-validation folds used when selecting the number of clusters.
     *
     * @param folds the number of folds
     */
    public void setNumFolds(int folds) {
        m_cvFolds = folds;
    }

    /**
     * @return the number of cross-validation folds
     */
    public int getNumFolds() {
        return m_cvFolds;
    }

    /**
     * Tip text for the minLogLikelihoodImprovementCV property (typo "cross-validiting" fixed).
     *
     * @return tip text suitable for display in a GUI
     */
    public String minLogLikelihoodImprovementCVTipText() {
        return "The minimum improvement in cross-validated log likelihood required in order to consider increasing the number of clusters when cross-validating to find the best number of clusters";
    }

    /**
     * Sets the minimum gain in mean cross-validated log-likelihood needed to try one more cluster.
     *
     * @param min the minimum improvement
     */
    public void setMinLogLikelihoodImprovementCV(double min) {
        this.m_minLogLikelihoodImprovementCV = min;
    }

    /**
     * @return the minimum cross-validated log-likelihood improvement
     */
    public double getMinLogLikelihoodImprovementCV() {
        return this.m_minLogLikelihoodImprovementCV;
    }

    /**
     * Tip text for the minLogLikelihoodImprovementIterating property.
     *
     * @return tip text suitable for display in a GUI
     */
    public String minLogLikelihoodImprovementIteratingTipText() {
        final String tip = "The minimum improvement in log likelihood required to perform another iteration of the E and M steps";
        return tip;
    }

    /**
     * Sets the minimum log-likelihood gain required to run another E/M iteration.
     *
     * @param min the minimum improvement
     */
    public void setMinLogLikelihoodImprovementIterating(double min) {
        m_minLogLikelihoodImprovementIterating = min;
    }

    /**
     * @return the minimum per-iteration log-likelihood improvement
     */
    public double getMinLogLikelihoodImprovementIterating() {
        return m_minLogLikelihoodImprovementIterating;
    }

    /**
     * Tip text for the numExecutionSlots property.
     *
     * @return tip text suitable for display in a GUI
     */
    public String numExecutionSlotsTipText() {
        final String tip = "The number of execution slots (threads) to use. Set equal to the number of available cpu/cores";
        return tip;
    }

    /**
     * Sets the number of threads; it is also forwarded to the k-means initialisation.
     *
     * @param slots the number of execution slots
     */
    public void setNumExecutionSlots(int slots) {
        m_executionSlots = slots;
    }

    /**
     * @return the number of execution slots
     */
    public int getNumExecutionSlots() {
        return m_executionSlots;
    }

    /**
     * Tip text for the displayModelInOldFormat property.
     *
     * @return tip text suitable for display in a GUI
     */
    public String displayModelInOldFormatTipText() {
        final String tip = "Use old format for model output. The old format is better when there are many clusters. The new format is better when there are fewer clusters and many attributes.";
        return tip;
    }

    /**
     * Chooses whether toString() uses the original (per-cluster) output format.
     *
     * @param d true to use the old format
     */
    public void setDisplayModelInOldFormat(boolean d) {
        m_displayModelInOldFormat = d;
    }

    /**
     * @return true if the old output format is selected
     */
    public boolean getDisplayModelInOldFormat() {
        return m_displayModelInOldFormat;
    }

    /**
     * Tip text for the minStdDev property.
     *
     * @return tip text suitable for display in a GUI
     */
    public String minStdDevTipText() {
        final String tip = "set minimum allowable standard deviation";
        return tip;
    }

    /**
     * Sets the global lower bound for numeric standard deviations.
     *
     * @param m the minimum standard deviation
     */
    public void setMinStdDev(double m) {
        m_minStdDev = m;
    }

    /**
     * Sets per-attribute lower bounds for standard deviations; when non-null these take precedence
     * over the global minimum. The array is stored by reference, not copied.
     *
     * @param m one minimum per attribute, or null to use the global minimum
     */
    public void setMinStdDevPerAtt(double[] m) {
        m_minStdDevPerAtt = m;
    }

    /**
     * @return the global minimum standard deviation
     */
    public double getMinStdDev() {
        return m_minStdDev;
    }

    /**
     * Tip text for the numClusters property.
     *
     * @return tip text suitable for display in a GUI
     */
    public String numClustersTipText() {
        final String tip = "set number of clusters. -1 to select number of clusters automatically by cross validation.";
        return tip;
    }

    /**
     * Sets the requested number of clusters. Any negative value is normalised to -1, meaning
     * "select the number of clusters by cross-validation".
     *
     * @param n the number of clusters (must not be 0)
     * @throws Exception if {@code n} is 0
     */
    @Override
    public void setNumClusters(int n) throws Exception {
        if (n == 0) {
            throw new Exception("Number of clusters must be > 0. (or -1 to select by cross validation).");
        }
        final int requested = n < 0 ? -1 : n;
        m_num_clusters = requested;
        m_initialNumClusters = requested;
    }

    /**
     * @return the number of clusters requested by the user (-1 for cross-validation)
     */
    public int getNumClusters() {
        return m_initialNumClusters;
    }

    /**
     * Sets the largest cluster count tried during cross-validation; values &lt;= 0 mean no limit.
     *
     * @param n the upper bound on the number of clusters
     */
    public void setMaximumNumberOfClusters(int n) {
        m_upperBoundNumClustersCV = n;
    }

    /**
     * @return the upper bound on the number of clusters considered during cross-validation
     */
    public int getMaximumNumberOfClusters() {
        return m_upperBoundNumClustersCV;
    }

    /**
     * Tip text for the maximumNumberOfClusters property.
     *
     * @return tip text suitable for display in a GUI
     */
    public String maximumNumberOfClustersTipText() {
        final String tip = "The maximum number of clusters to consider during cross-validation to select the best number of clusters";
        return tip;
    }

    /**
     * Tip text for the maxIterations property.
     *
     * @return tip text suitable for display in a GUI
     */
    public String maxIterationsTipText() {
        final String tip = "maximum number of iterations";
        return tip;
    }

    /**
     * Sets the maximum number of E/M iterations.
     *
     * @param i the maximum number of iterations (at least 1)
     * @throws Exception if {@code i} is less than 1
     */
    public void setMaxIterations(int i) throws Exception {
        if (i >= 1) {
            m_max_iterations = i;
            return;
        }
        throw new Exception("Maximum number of iterations must be > 0!");
    }

    /**
     * @return the maximum number of E/M iterations
     */
    public int getMaxIterations() {
        return m_max_iterations;
    }

    /**
     * Tip text for the debug property.
     *
     * @return tip text suitable for display in a GUI
     */
    @Override
    public String debugTipText() {
        final String tip = "If set to true, clusterer may output additional info to the console.";
        return tip;
    }

    /**
     * Enables or disables verbose console output.
     *
     * @param v true for verbose output
     */
    @Override
    public void setDebug(boolean v) {
        m_verbose = v;
    }

    /**
     * @return true if verbose output is enabled
     */
    @Override
    public boolean getDebug() {
        return m_verbose;
    }

    /**
     * Returns the current settings as command-line options, followed by the superclass options.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        ArrayList<String> opts = new ArrayList<String>();
        opts.add("-I");
        opts.add(String.valueOf(m_max_iterations));
        opts.add("-N");
        opts.add(String.valueOf(getNumClusters()));
        opts.add("-X");
        opts.add(String.valueOf(getNumFolds()));
        opts.add("-max");
        opts.add(String.valueOf(getMaximumNumberOfClusters()));
        opts.add("-ll-cv");
        opts.add(String.valueOf(getMinLogLikelihoodImprovementCV()));
        opts.add("-ll-iter");
        opts.add(String.valueOf(getMinLogLikelihoodImprovementIterating()));
        opts.add("-M");
        opts.add(String.valueOf(getMinStdDev()));
        opts.add("-K");
        opts.add(String.valueOf(getNumKMeansRuns()));
        if (m_displayModelInOldFormat) {
            opts.add("-O");
        }
        opts.add("-num-slots");
        opts.add(String.valueOf(getNumExecutionSlots()));
        Collections.addAll(opts, super.getOptions());
        return opts.toArray(new String[0]);
    }

    /**
     * Initialises the mixture model for {@code inst} from the best of several k-means runs.
     *
     * Runs SimpleKMeans {@code m_NumKMeansRuns} times (at least once) with seeds drawn from
     * {@code m_rr} and keeps the run with the lowest squared error. Its centroids and standard
     * deviations seed the numeric Gaussians, its nominal counts seed the discrete estimators and
     * its cluster sizes become the (normalised) priors. Also allocates all per-cluster model arrays.
     *
     * @param inst the training instances
     * @throws Exception if k-means fails to build
     */
    private void EM_Init(Instances inst) throws Exception {
        SimpleKMeans bestK = null;
        double bestSqE = Double.MAX_VALUE;
        // Run at least once: with m_NumKMeansRuns < 1 the loop used to be skipped entirely and
        // bestK.numberOfClusters() below threw a NullPointerException.
        int runs = Math.max(1, this.m_NumKMeansRuns);
        for (int run = 0; run < runs; ++run) {
            SimpleKMeans sk = new SimpleKMeans();
            sk.setSeed(this.m_rr.nextInt());
            sk.setNumClusters(this.m_num_clusters);
            sk.setNumExecutionSlots(this.m_executionSlots);
            sk.setDisplayStdDevs(true);
            sk.setDoNotCheckCapabilities(true);
            sk.setDontReplaceMissingValues(true);
            sk.buildClusterer(inst);
            double sqE = sk.getSquaredError();
            // Always accept the first run (a NaN/huge error never compares as smaller, which used to
            // leave bestK null) and let any later run replace a NaN best; otherwise keep the minimum.
            if (bestK == null || sqE < bestSqE || Double.isNaN(bestSqE)) {
                bestSqE = sqE;
                bestK = sk;
            }
        }
        this.m_num_clusters = bestK.numberOfClusters();
        this.m_weights = new double[inst.numInstances()][this.m_num_clusters];
        this.m_model = new DiscreteEstimator[this.m_num_clusters][this.m_num_attribs];
        this.m_modelNormal = new double[this.m_num_clusters][this.m_num_attribs][3];
        this.m_priors = new double[this.m_num_clusters];
        this.m_modelPrev = new DiscreteEstimator[this.m_num_clusters][this.m_num_attribs];
        this.m_modelNormalPrev = new double[this.m_num_clusters][this.m_num_attribs][3];
        this.m_priorsPrev = new double[this.m_num_clusters];
        Instances centers = bestK.getClusterCentroids();
        Instances stdD = bestK.getClusterStandardDevs();
        double[][][] nominalCounts = bestK.getClusterNominalCounts();
        double[] clusterSizes = bestK.getClusterSizes();
        for (int i = 0; i < this.m_num_clusters; ++i) {
            Instance center = centers.instance(i);
            for (int j = 0; j < this.m_num_attribs; ++j) {
                if (inst.attribute(j).isNominal()) {
                    this.m_model[i][j] = new DiscreteEstimator(this.m_theInstances.attribute(j).numValues(), true);
                    for (int k = 0; k < inst.attribute(j).numValues(); ++k) {
                        this.m_model[i][j].addValue(k, nominalCounts[i][j][k]);
                    }
                    continue;
                }
                double minStdD = this.m_minStdDevPerAtt != null ? this.m_minStdDevPerAtt[j] : this.m_minStdDev;
                this.m_modelNormal[i][j][0] = center.value(j);
                double stdv = stdD.instance(i).value(j);
                // Too-small cluster std. dev.: fall back to the attribute's overall std. dev., then to the minimum.
                if (stdv < minStdD) {
                    stdv = Math.sqrt(inst.variance(j));
                    if (Double.isInfinite(stdv)) {
                        stdv = minStdD;
                    }
                    if (stdv < minStdD) {
                        stdv = minStdD;
                    }
                }
                if (stdv <= 0.0 || Double.isNaN(stdv)) {
                    stdv = this.m_minStdDev;
                }
                this.m_modelNormal[i][j][1] = stdv;
                // Weight sum slot; starts at 1.0 before the first M step.
                this.m_modelNormal[i][j][2] = 1.0;
            }
        }
        for (int i = 0; i < this.m_num_clusters; ++i) {
            this.m_priors[i] = clusterSizes[i];
        }
        Utils.normalize(this.m_priors);
    }

    /**
     * Re-estimates the cluster priors as the normalised, instance-weighted sum of membership
     * weights. The old priors are saved to {@code m_priorsPrev} first.
     *
     * @param inst the instances whose membership weights are in {@code m_weights}
     * @throws Exception propagated from the normalisation step
     */
    private void estimate_priors(Instances inst) throws Exception {
        for (int c = 0; c < m_num_clusters; c++) {
            m_priorsPrev[c] = m_priors[c];
            m_priors[c] = 0.0;
        }
        final int numInst = inst.numInstances();
        for (int row = 0; row < numInst; row++) {
            final double w = inst.instance(row).weight();
            final double[] membership = m_weights[row];
            for (int c = 0; c < m_num_clusters; c++) {
                m_priors[c] += w * membership[c];
            }
        }
        Utils.normalize(m_priors);
    }

    /**
     * Natural log of the normal density N(mean, stdDev^2) evaluated at {@code x}.
     *
     * @param x the point to evaluate
     * @param mean the mean of the distribution
     * @param stdDev the standard deviation of the distribution
     * @return the log-density at {@code x}
     */
    private double logNormalDens(double x, double mean, double stdDev) {
        final double z = x - mean;
        final double exponent = z * z / (2.0 * stdDev * stdDev);
        return -exponent - m_normConst - Math.log(stdDev);
    }

    /**
     * Prepares for an M step: saves the current model into the "Prev" arrays, gives nominal
     * attributes fresh DiscreteEstimators and zeroes the numeric accumulators.
     */
    private void new_estimators() {
        for (int c = 0; c < m_num_clusters; c++) {
            for (int a = 0; a < m_num_attribs; a++) {
                Attribute att = m_theInstances.attribute(a);
                if (!att.isNominal()) {
                    double[] current = m_modelNormal[c][a];
                    System.arraycopy(current, 0, m_modelNormalPrev[c][a], 0, 3);
                    current[0] = current[1] = current[2] = 0.0;
                } else {
                    m_modelPrev[c][a] = m_model[c][a];
                    m_model[c][a] = new DiscreteEstimator(att.numValues(), true);
                }
            }
        }
    }

    /**
     * (Re)creates the fixed-size thread pool with {@code m_executionSlots} threads, shutting down
     * any previously created pool first.
     */
    protected void startExecutorPool() {
        ExecutorService previous = m_executorPool;
        if (previous != null) {
            previous.shutdownNow();
        }
        m_executorPool = Executors.newFixedThreadPool(m_executionSlots);
    }

    /**
     * Converts the accumulated numeric sufficient statistics into mean and standard deviation.
     *
     * On entry each numeric {@code m_modelNormal[i][j]} holds {weighted sum, weighted sum of squares,
     * total weight}; on exit it holds {mean, std. dev., total weight}. Standard deviations at or below
     * the per-attribute (or global) minimum fall back to the attribute's overall std. dev. and then to
     * that minimum. Non-positive, infinite and NaN results are replaced by {@code m_minStdDev}.
     * The NaN guard is new and mirrors EM_Init: a NaN here (e.g. from {@code inst.variance(j)})
     * previously slipped through every check and poisoned all later log-densities.
     *
     * @param inst the instances the statistics were accumulated from
     */
    private void M_reEstimate(Instances inst) {
        for (int i = 0; i < this.m_num_clusters; ++i) {
            for (int j = 0; j < this.m_num_attribs; ++j) {
                if (inst.attribute(j).isNominal()) {
                    continue;
                }
                double[] stats = this.m_modelNormal[i][j];
                double sumW = stats[2];
                if (sumW <= 0.0) {
                    // Cluster received no weight for this attribute: effectively flat density.
                    stats[1] = Double.MAX_VALUE;
                    stats[0] = this.m_minStdDev;
                    continue;
                }
                double variance = (stats[1] - stats[0] * stats[0] / sumW) / sumW;
                if (variance < 0.0) {
                    // Guard against round-off producing a slightly negative variance.
                    variance = 0.0;
                }
                double minStdD = this.m_minStdDevPerAtt != null ? this.m_minStdDevPerAtt[j] : this.m_minStdDev;
                double stdDev = Math.sqrt(variance);
                if (stdDev <= minStdD) {
                    stdDev = Math.sqrt(inst.variance(j));
                    if (stdDev <= minStdD) {
                        stdDev = minStdD;
                    }
                }
                if (stdDev <= 0.0 || Double.isInfinite(stdDev) || Double.isNaN(stdDev)) {
                    stdDev = this.m_minStdDev;
                }
                stats[1] = stdDev;
                stats[0] = stats[0] / sumW;
            }
        }
    }

    /**
     * M step: re-estimates priors and all per-cluster attribute models from the current
     * membership weights in {@code m_weights}.
     *
     * @param inst the training instances
     * @throws Exception propagated from prior estimation
     */
    private void M(Instances inst) throws Exception {
        new_estimators();
        estimate_priors(inst);
        final int numInst = inst.numInstances();
        for (int row = 0; row < numInst; row++) {
            final Instance current = inst.instance(row);
            final double[] membership = m_weights[row];
            for (int c = 0; c < m_num_clusters; c++) {
                for (int a = 0; a < m_num_attribs; a++) {
                    if (inst.attribute(a).isNominal()) {
                        m_model[c][a].addValue(current.value(a), current.weight() * membership[c]);
                    } else {
                        // Accumulate weighted sum, total weight and weighted sum of squares
                        // (same evaluation order as before to keep results bit-identical).
                        final double v = current.value(a);
                        final double[] stats = m_modelNormal[c][a];
                        stats[0] += v * current.weight() * membership[c];
                        stats[2] += current.weight() * membership[c];
                        stats[1] += v * v * current.weight() * membership[c];
                    }
                }
            }
        }
        M_reEstimate(inst);
    }

    /**
     * E step: computes the weighted mean log-likelihood of {@code inst} under the current model
     * and, if requested, replaces each instance's membership weights.
     *
     * @param inst the instances to evaluate
     * @param change_weights whether to update {@code m_weights}
     * @return the instance-weighted mean log-likelihood, or 0 if the total weight is not positive
     * @throws Exception propagated from density/distribution computation
     */
    private double E(Instances inst, boolean change_weights) throws Exception {
        double weightedLogLik = 0.0;
        double totalWeight = 0.0;
        final int numInst = inst.numInstances();
        for (int row = 0; row < numInst; row++) {
            final Instance current = inst.instance(row);
            final double w = current.weight();
            weightedLogLik += w * logDensityForInstance(current);
            totalWeight += w;
            if (change_weights) {
                m_weights[row] = distributionForInstance(current);
            }
        }
        return totalWeight <= 0.0 ? 0.0 : weightedLogLik / totalWeight;
    }

    /**
     * Creates an EM clusterer with a default seed of 100 and all options at their defaults.
     */
    public EM() {
        m_SeedDefault = 100;
        resetOptions();
    }

    /**
     * Restores every command-line-settable option to its default.
     *
     * Previously -K (number of k-means runs), -max (CV cluster upper bound) and -O were not reset.
     * A second setOptions() call that omitted them therefore silently kept the earlier values.
     * The new defaults match the field initialisers, so construction behaves exactly as before.
     */
    protected void resetOptions() {
        this.m_minStdDev = 1.0E-6;
        this.m_max_iterations = 100;
        this.m_Seed = this.m_SeedDefault;
        this.m_num_clusters = -1;
        this.m_initialNumClusters = -1;
        this.m_verbose = false;
        this.m_minLogLikelihoodImprovementIterating = 1.0E-6;
        this.m_minLogLikelihoodImprovementCV = 1.0E-6;
        this.m_executionSlots = 1;
        this.m_cvFolds = 10;
        this.m_NumKMeansRuns = 10;
        this.m_upperBoundNumClustersCV = -1;
        this.m_displayModelInOldFormat = false;
    }

    /**
     * Returns the numeric-attribute model: per cluster, per attribute {mean, std. dev., weight sum}.
     * The internal array is returned directly, not a copy.
     *
     * @return the numeric model array
     */
    public double[][][] getClusterModelsNumericAtts() {
        final double[][][] model = m_modelNormal;
        return model;
    }

    /**
     * Returns the cluster prior probabilities. The internal array is returned directly, not a copy.
     *
     * @return the prior array
     */
    public double[] getClusterPriors() {
        final double[] priors = m_priors;
        return priors;
    }

    /**
     * Returns a tabular description of the fitted model: one column per cluster, showing priors,
     * means and standard deviations for numeric attributes, and per-value counts plus totals for
     * nominal attributes. Delegates to toStringOriginal() when the old display format is selected.
     *
     * @return the model description, or "No clusterer built yet!" if no model exists
     */
    public String toString() {
        int i;
        if (this.m_displayModelInOldFormat) {
            return this.toStringOriginal();
        }
        if (this.m_priors == null) {
            return "No clusterer built yet!";
        }
        StringBuffer temp = new StringBuffer();
        temp.append("\nEM\n==\n");
        if (this.m_initialNumClusters == -1) {
            temp.append("\nNumber of clusters selected by cross validation: " + this.m_num_clusters + "\n");
        } else {
            temp.append("\nNumber of clusters: " + this.m_num_clusters + "\n");
        }
        temp.append("Number of iterations performed: " + this.m_iterationsPerformed + "\n");
        // maxWidth: width of each cluster column; maxAttWidth: width of the left-hand label column.
        int maxWidth = 0;
        int maxAttWidth = 0;
        // The label column must fit every attribute name and every nominal value label (+2 spaces).
        for (i = 0; i < this.m_num_attribs; ++i) {
            Attribute a2 = this.m_theInstances.attribute(i);
            if (a2.name().length() > maxAttWidth) {
                maxAttWidth = this.m_theInstances.attribute(i).name().length();
            }
            if (!a2.isNominal()) continue;
            for (int j = 0; j < a2.numValues(); ++j) {
                String val = a2.value(j) + "  ";
                if (val.length() <= maxAttWidth) continue;
                maxAttWidth = val.length();
            }
        }
        // Cluster columns must fit every printed number. Numeric widths are estimated from the log10
        // magnitude of |mean| and |std. dev.| plus 6 (presumably sign, decimal point and 4 decimals).
        // Note: maxWidth is also passed to doubleToString while it is still being computed, so the
        // order of this loop affects the result.
        for (i = 0; i < this.m_num_clusters; ++i) {
            for (int j = 0; j < this.m_num_attribs; ++j) {
                if (this.m_theInstances.attribute(j).isNumeric()) {
                    double stdD;
                    double width;
                    double mean = Math.log(Math.abs(this.m_modelNormal[i][j][0])) / Math.log(10.0);
                    // 'd' is an unused local (decompiler artifact); width = max(mean, stdD) digits.
                    double d = width = mean > (stdD = Math.log(Math.abs(this.m_modelNormal[i][j][1])) / Math.log(10.0)) ? mean : stdD;
                    if (width < 0.0) {
                        width = 1.0;
                    }
                    if ((int)(width += 6.0) <= maxWidth) continue;
                    maxWidth = (int)width;
                    continue;
                }
                DiscreteEstimator d = (DiscreteEstimator)this.m_model[i][j];
                for (int k = 0; k < d.getNumSymbols(); ++k) {
                    String size = Utils.doubleToString(d.getCount(k), maxWidth, 4).trim();
                    if (size.length() <= maxWidth) continue;
                    maxWidth = size.length();
                }
                int sum = Utils.doubleToString(d.getSumOfCounts(), maxWidth, 4).trim().length();
                if (sum <= maxWidth) continue;
                maxWidth = sum;
            }
        }
        if (maxAttWidth < "Attribute".length()) {
            maxAttWidth = "Attribute".length();
        }
        // Header rows. Note: maxAttWidth is widened by 2 inside the argument below as a side effect;
        // every following row relies on the widened value.
        temp.append("\n\n");
        temp.append(this.pad("Cluster", " ", (maxAttWidth += 2) + maxWidth + 1 - "Cluster".length(), true));
        temp.append("\n");
        temp.append(this.pad("Attribute", " ", maxAttWidth - "Attribute".length(), false));
        for (i = 0; i < this.m_num_clusters; ++i) {
            String classL = "" + i;
            temp.append(this.pad(classL, " ", maxWidth + 1 - classL.length(), true));
        }
        temp.append("\n");
        // Row of cluster priors, shown in parentheses with 2 decimals.
        temp.append(this.pad("", " ", maxAttWidth, true));
        for (i = 0; i < this.m_num_clusters; ++i) {
            String priorP = Utils.doubleToString(this.m_priors[i], maxWidth, 2).trim();
            priorP = "(" + priorP + ")";
            temp.append(this.pad(priorP, " ", maxWidth + 1 - priorP.length(), true));
        }
        temp.append("\n");
        temp.append(this.pad("", "=", maxAttWidth + maxWidth * this.m_num_clusters + this.m_num_clusters + 1, true));
        temp.append("\n");
        // Body: one section per attribute.
        for (i = 0; i < this.m_num_attribs; ++i) {
            String attName = this.m_theInstances.attribute(i).name();
            temp.append(attName + "\n");
            if (this.m_theInstances.attribute(i).isNumeric()) {
                // Numeric attribute: a mean row and a std. dev. row.
                String meanL = "  mean";
                temp.append(this.pad(meanL, " ", maxAttWidth + 1 - meanL.length(), false));
                for (int j = 0; j < this.m_num_clusters; ++j) {
                    String mean = Utils.doubleToString(this.m_modelNormal[j][i][0], maxWidth, 4).trim();
                    temp.append(this.pad(mean, " ", maxWidth + 1 - mean.length(), true));
                }
                temp.append("\n");
                String stdDevL = "  std. dev.";
                temp.append(this.pad(stdDevL, " ", maxAttWidth + 1 - stdDevL.length(), false));
                for (int j = 0; j < this.m_num_clusters; ++j) {
                    String stdDev = Utils.doubleToString(this.m_modelNormal[j][i][1], maxWidth, 4).trim();
                    temp.append(this.pad(stdDev, " ", maxWidth + 1 - stdDev.length(), true));
                }
                temp.append("\n\n");
                continue;
            }
            // Non-numeric attribute (treated as nominal): one count row per value, then a total row.
            Attribute a3 = this.m_theInstances.attribute(i);
            for (int j = 0; j < a3.numValues(); ++j) {
                String val = "  " + a3.value(j);
                temp.append(this.pad(val, " ", maxAttWidth + 1 - val.length(), false));
                for (int k = 0; k < this.m_num_clusters; ++k) {
                    DiscreteEstimator d = (DiscreteEstimator)this.m_model[k][i];
                    String count = Utils.doubleToString(d.getCount(j), maxWidth, 4).trim();
                    temp.append(this.pad(count, " ", maxWidth + 1 - count.length(), true));
                }
                temp.append("\n");
            }
            String total = "  [total]";
            temp.append(this.pad(total, " ", maxAttWidth + 1 - total.length(), false));
            for (int k = 0; k < this.m_num_clusters; ++k) {
                DiscreteEstimator d = (DiscreteEstimator)this.m_model[k][i];
                String count = Utils.doubleToString(d.getSumOfCounts(), maxWidth, 4).trim();
                temp.append(this.pad(count, " ", maxWidth + 1 - count.length(), true));
            }
            temp.append("\n");
        }
        return temp.toString();
    }

    /**
     * Pads {@code source} with {@code length} copies of {@code padChar}, on the left or right.
     * A non-positive {@code length} adds no padding.
     *
     * @param source the text to pad
     * @param padChar the padding string, repeated {@code length} times
     * @param length the number of repetitions
     * @param leftPad true to put the padding before the text, false to put it after
     * @return the padded text
     */
    private String pad(String source, String padChar, int length, boolean leftPad) {
        StringBuilder padding = new StringBuilder();
        for (int n = 0; n < length; n++) {
            padding.append(padChar);
        }
        final String fill = padding.toString();
        return leftPad ? fill + source : source + fill;
    }

    /**
     * Returns the model description in the original per-cluster format: prior probability, then
     * each attribute's estimator or normal distribution parameters.
     *
     * @return the model description, or "No clusterer built yet!" if no model exists
     */
    protected String toStringOriginal() {
        if (m_priors == null) {
            return "No clusterer built yet!";
        }
        StringBuilder text = new StringBuilder();
        text.append("\nEM\n==\n");
        String heading = m_initialNumClusters == -1
            ? "\nNumber of clusters selected by cross validation: "
            : "\nNumber of clusters: ";
        text.append(heading + m_num_clusters + "\n");
        for (int c = 0; c < m_num_clusters; c++) {
            text.append("\nCluster: " + c + " Prior probability: " + Utils.doubleToString(m_priors[c], 4) + "\n\n");
            for (int a = 0; a < m_num_attribs; a++) {
                Attribute att = m_theInstances.attribute(a);
                text.append("Attribute: " + att.name() + "\n");
                if (!att.isNominal()) {
                    double[] stats = m_modelNormal[c][a];
                    text.append("Normal Distribution. Mean = " + Utils.doubleToString(stats[0], 4) + " StdDev = " + Utils.doubleToString(stats[1], 4) + "\n");
                } else if (m_model[c][a] != null) {
                    text.append(m_model[c][a].toString());
                }
            }
        }
        return text.toString();
    }

    /**
     * Debugging aid: prints every cluster's attribute models, then each instance's most likely
     * cluster and membership weights, to standard output.
     *
     * @param inst the instances whose weights are in {@code m_weights}
     */
    private void EM_Report(Instances inst) {
        System.out.println("======================================");
        for (int c = 0; c < m_num_clusters; c++) {
            for (int a = 0; a < m_num_attribs; a++) {
                System.out.println("Clust: " + c + " att: " + a + "\n");
                if (!m_theInstances.attribute(a).isNominal()) {
                    double[] stats = m_modelNormal[c][a];
                    System.out.println("Normal Distribution. Mean = " + Utils.doubleToString(stats[0], 8, 4) + " StandardDev = " + Utils.doubleToString(stats[1], 8, 4) + " WeightSum = " + Utils.doubleToString(stats[2], 8, 4));
                } else if (m_model[c][a] != null) {
                    System.out.println(m_model[c][a].toString());
                }
            }
        }
        final int numInst = inst.numInstances();
        for (int row = 0; row < numInst; row++) {
            double[] membership = m_weights[row];
            int best = Utils.maxIndex(membership);
            System.out.print("Inst " + Utils.doubleToString(row, 5, 0) + " Class " + best + "\t");
            for (int c = 0; c < m_num_clusters; c++) {
                System.out.print(Utils.doubleToString(membership[c], 7, 5) + "  ");
            }
            System.out.println();
        }
    }

    /**
     * Selects the number of clusters by cross-validation and stores it in {@code m_num_clusters}.
     *
     * Starting from one cluster, the data is shuffled and split into folds (at most
     * {@code m_cvFolds}, fewer when there are fewer instances). EM is trained on each training
     * split and the average held-out log-likelihood is computed. The cluster count keeps
     * increasing while that average improves by more than {@code m_minLogLikelihoodImprovementCV},
     * up to {@code m_upperBoundNumClustersCV} when it is positive.
     *
     * If EM fails on a fold, the same cluster count is retried with the next random seed, at most
     * five consecutive times, after which the search stops. Previously {@code ok} was never reset
     * and the loop condition was left false, so the first failure silently ended the search and
     * the retry logic was dead code. The result is also clamped to at least one cluster.
     *
     * @throws Exception if initialising EM on a training split fails
     */
    private void CVClusters() throws Exception {
        double CVLogLikely = -Double.MAX_VALUE;
        boolean CVincreased = true;
        this.m_num_clusters = 1;
        int upperBoundMaxClusters = this.m_upperBoundNumClustersCV > 0 ? this.m_upperBoundNumClustersCV : Integer.MAX_VALUE;
        int num_clusters = this.m_num_clusters;
        int numFolds = this.m_theInstances.numInstances() < this.m_cvFolds ? this.m_theInstances.numInstances() : this.m_cvFolds;
        // Seed for the k-means initialisation; bumped on failure so a retry uses different starts.
        int seed = this.getSeed();
        int restartCount = 0;
        clusterSearch:
        while (CVincreased && num_clusters <= upperBoundMaxClusters) {
            CVincreased = false;
            boolean ok = true;
            // Same seed every pass, so every cluster count is evaluated on identical folds.
            Random cvr = new Random(this.getSeed());
            Instances trainCopy = new Instances(this.m_theInstances);
            trainCopy.randomize(cvr);
            double templl = 0.0;
            for (int i = 0; i < numFolds; ++i) {
                Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
                if (num_clusters > cvTrain.numInstances()) {
                    break clusterSearch;
                }
                Instances cvTest = trainCopy.testCV(numFolds, i);
                this.m_rr = new Random(seed);
                for (int z = 0; z < 10; ++z) {
                    this.m_rr.nextDouble();
                }
                this.m_num_clusters = num_clusters;
                this.EM_Init(cvTrain);
                double tll;
                try {
                    this.iterate(cvTrain, false);
                    tll = this.E(cvTest, false);
                }
                catch (Exception ex) {
                    ex.printStackTrace();
                    ++seed;
                    ok = false;
                    if (++restartCount > 5) {
                        break clusterSearch;
                    }
                    break;
                }
                if (this.m_verbose) {
                    System.out.println("# clust: " + num_clusters + " Fold: " + i + " Loglikely: " + tll);
                }
                templl += tll;
            }
            if (!ok) {
                // Retry the same number of clusters with the bumped seed.
                CVincreased = true;
                continue;
            }
            restartCount = 0;
            seed = this.getSeed();
            templl /= (double)numFolds;
            if (this.m_verbose) {
                System.out.println("=================================================\n# clust: " + num_clusters + " Mean Loglikely: " + templl + "\n================================" + "=================");
            }
            if (templl - CVLogLikely > this.m_minLogLikelihoodImprovementCV) {
                CVLogLikely = templl;
                CVincreased = true;
                ++num_clusters;
            }
        }
        // num_clusters is one past the last accepted count; never select fewer than one cluster.
        int selected = Math.max(1, num_clusters - 1);
        if (this.m_verbose) {
            System.out.println("Number of clusters: " + selected);
        }
        this.m_num_clusters = selected;
    }

    /**
     * Returns the number of clusters in the built model.
     *
     * @return the current cluster count
     * @throws Exception if the count is still the {@code -1} "not built" marker
     */
    @Override
    public int numberOfClusters() throws Exception {
        int count = this.m_num_clusters;
        if (count != -1) {
            return count;
        }
        throw new Exception("Haven't generated any clusters!");
    }

    /**
     * Widens the running per-attribute minimum and maximum with the values of
     * one instance.
     *
     * @param instance the instance whose attribute values are folded into
     *                 {@code m_minValues} and {@code m_maxValues}
     */
    private void updateMinMax(Instance instance) {
        for (int j = 0; j < this.m_theInstances.numAttributes(); ++j) {
            double value = instance.value(j);
            if (value < this.m_minValues[j]) {
                this.m_minValues[j] = value;
            }
            // Checked independently of the minimum. With an else-branch the first
            // instance only ever sets the minimum, leaving the maximum at its
            // -Double.MAX_VALUE sentinel (e.g. single-instance data or a
            // non-increasing column).
            if (value > this.m_maxValues[j]) {
                this.m_maxValues[j] = value;
            }
        }
    }

    /**
     * Reports the same capabilities as {@link SimpleKMeans}, re-owned by this
     * clusterer.
     *
     * @return the capabilities of this clusterer
     */
    @Override
    public Capabilities getCapabilities() {
        SimpleKMeans template = new SimpleKMeans();
        Capabilities capabilities = template.getCapabilities();
        capabilities.setOwner(this);
        return capabilities;
    }

    /**
     * Builds the EM model from the given data.
     *
     * Steps:
     * <ul>
     * <li>Validates the data against {@link #getCapabilities()}.</li>
     * <li>Clears the class index and fits a {@link ReplaceMissingValues} filter,
     * which is kept for filtering instances at prediction time.</li>
     * <li>Records per-attribute min/max and runs EM.</li>
     * <li>Finally keeps only the dataset header.</li>
     * </ul>
     *
     * @param data the training data (not modified; a copy is filtered)
     * @throws Exception if the data is not supported or EM fails
     */
    @Override
    public void buildClusterer(Instances data) throws Exception {
        this.m_training = true;
        try {
            this.getCapabilities().testWithFail(data);
            this.m_replaceMissing = new ReplaceMissingValues();
            Instances instances = new Instances(data);
            instances.setClassIndex(-1);
            this.m_replaceMissing.setInputFormat(instances);
            this.m_theInstances = Filter.useFilter(instances, this.m_replaceMissing);
            int numAttributes = this.m_theInstances.numAttributes();
            this.m_minValues = new double[numAttributes];
            this.m_maxValues = new double[numAttributes];
            for (int i = 0; i < numAttributes; ++i) {
                this.m_minValues[i] = Double.MAX_VALUE;
                this.m_maxValues[i] = -Double.MAX_VALUE;
            }
            for (int i = 0; i < this.m_theInstances.numInstances(); ++i) {
                this.updateMinMax(this.m_theInstances.instance(i));
            }
            this.doEM();
            // Keep only the header; the training rows are no longer needed.
            this.m_theInstances = new Instances(this.m_theInstances, 0);
        } finally {
            // Always leave training mode, even on failure. Otherwise later density
            // queries would skip the missing-value filter.
            this.m_training = false;
        }
    }

    /**
     * Returns a defensive copy of the cluster prior probabilities.
     *
     * @return a new array with the same contents as {@code m_priors}
     */
    @Override
    public double[] clusterPriors() {
        return this.m_priors.clone();
    }

    /**
     * Computes, for every cluster, the log density of an instance under that
     * cluster's model.
     *
     * The log density is a sum over attributes:
     * <ul>
     * <li>nominal attributes contribute the log of the discrete estimator's
     * probability;</li>
     * <li>numeric attributes contribute {@code logNormalDens} with the cluster's
     * stored mean (index 0) and standard deviation (index 1).</li>
     * </ul>
     * Outside training, the instance first goes through the missing-value
     * filter fitted in {@code buildClusterer}.
     *
     * @param inst the instance to score
     * @return one log density per cluster
     * @throws Exception if filtering or density evaluation fails
     */
    @Override
    public double[] logDensityPerClusterForInstance(Instance inst) throws Exception {
        double[] logDensities = new double[this.m_num_clusters];
        Instance filtered = inst;
        if (!this.m_training) {
            this.m_replaceMissing.input(filtered);
            filtered = this.m_replaceMissing.output();
        }
        for (int cluster = 0; cluster < logDensities.length; ++cluster) {
            double sum = 0.0;
            for (int att = 0; att < this.m_num_attribs; ++att) {
                double value = filtered.value(att);
                if (filtered.attribute(att).isNominal()) {
                    sum += Math.log(this.m_model[cluster][att].getProbability(value));
                } else {
                    double[] params = this.m_modelNormal[cluster][att];
                    sum += this.logNormalDens(value, params[0], params[1]);
                }
            }
            logDensities[cluster] = sum;
        }
        return logDensities;
    }

    /**
     * Runs the full EM procedure on {@code m_theInstances}.
     *
     * Steps:
     * <ol>
     * <li>Seeds {@code m_rr} (discarding the first ten draws) and records the
     * instance/attribute counts.</li>
     * <li>Starts the executor pool.</li>
     * <li>If {@code m_initialNumClusters} is {@code -1}, picks the cluster count:
     * by cross-validation when there are more than nine instances, otherwise
     * one cluster.</li>
     * <li>Initialises the model and iterates to convergence.</li>
     * </ol>
     * The executor pool is shut down on every exit path.
     *
     * @throws Exception if cross-validation, initialisation or iteration fails
     */
    private void doEM() throws Exception {
        if (this.m_verbose) {
            System.out.println("Seed: " + this.getSeed());
        }
        this.m_rr = new Random(this.getSeed());
        for (int i = 0; i < 10; ++i) {
            this.m_rr.nextDouble();
        }
        this.m_num_instances = this.m_theInstances.numInstances();
        this.m_num_attribs = this.m_theInstances.numAttributes();
        if (this.m_verbose) {
            System.out.println("Number of instances: " + this.m_num_instances + "\nNumber of atts: " + this.m_num_attribs + "\n");
        }
        this.startExecutorPool();
        try {
            if (this.m_initialNumClusters == -1) {
                if (this.m_theInstances.numInstances() > 9) {
                    this.CVClusters();
                    // Re-seed so the final fit does not depend on draws consumed during CV.
                    this.m_rr = new Random(this.getSeed());
                    for (int i = 0; i < 10; ++i) {
                        this.m_rr.nextDouble();
                    }
                } else {
                    this.m_num_clusters = 1;
                }
            }
            this.EM_Init(this.m_theInstances);
            double loglikely = this.iterate(this.m_theInstances, this.m_verbose);
            if (this.m_Debug) {
                System.err.println("Current log-likelihood: " + loglikely);
            }
        } finally {
            // Previously skipped when any step above threw, leaking the pool's threads.
            if (this.m_executorPool != null) {
                this.m_executorPool.shutdown();
            }
        }
    }

    /**
     * Runs one E step over {@code inst}.
     *
     * With at most one execution slot, or fewer than two instances per slot, this
     * delegates to {@code E(inst, true)}. Otherwise the instances are split into
     * {@code m_executionSlots} contiguous ranges, each handled by an
     * {@link ETask} that also updates {@code m_weights}. The last range absorbs
     * the division remainder.
     *
     * @param inst the instances to evaluate
     * @return the log-likelihood. In the parallel path this is the sum of
     *         weight * log density divided by the sum of instance weights.
     * @throws Exception if the sequential E step fails or any task fails (the
     *         latter wrapped by {@link Future#get()})
     */
    protected double launchESteps(Instances inst) throws Exception {
        // Checked before computing the per-task share so zero slots cannot divide by zero.
        if (this.m_executionSlots <= 1 || inst.numInstances() < 2 * this.m_executionSlots) {
            return this.E(inst, true);
        }
        int numPerTask = inst.numInstances() / this.m_executionSlots;
        ArrayList<Future<double[]>> results = new ArrayList<Future<double[]>>(this.m_executionSlots);
        for (int i = 0; i < this.m_executionSlots; ++i) {
            int start = i * numPerTask;
            int end = i == this.m_executionSlots - 1 ? inst.numInstances() : start + numPerTask;
            results.add(this.m_executorPool.submit(new ETask(inst, start, end, true)));
        }
        double eStepLogL = 0.0;
        double eStepSow = 0.0;
        for (Future<double[]> result : results) {
            double[] partial = result.get();
            eStepLogL += partial[0];
            eStepSow += partial[1];
        }
        return eStepLogL / eStepSow;
    }

    /**
     * Runs one M step over {@code inst}.
     *
     * With at most one execution slot, or fewer than two instances per slot, this
     * delegates to {@code M(inst)}. Otherwise it:
     * <ol>
     * <li>resets the estimators and re-estimates the priors;</li>
     * <li>has one {@link MTask} per contiguous instance range accumulate private
     * sufficient statistics;</li>
     * <li>merges those statistics, in submission order, into {@code m_model}
     * and {@code m_modelNormal};</li>
     * <li>calls {@code M_reEstimate}.</li>
     * </ol>
     *
     * @param inst the instances whose current {@code m_weights} drive the update
     * @throws Exception if the sequential M step, a task, or re-estimation fails
     */
    protected void launchMSteps(Instances inst) throws Exception {
        if (this.m_executionSlots <= 1 || inst.numInstances() < 2 * this.m_executionSlots) {
            this.M(inst);
            return;
        }
        this.new_estimators();
        this.estimate_priors(inst);
        int numPerTask = inst.numInstances() / this.m_executionSlots;
        ArrayList<Future<MTask>> results = new ArrayList<Future<MTask>>(this.m_executionSlots);
        for (int i = 0; i < this.m_executionSlots; ++i) {
            int start = i * numPerTask;
            // The last task also takes the remainder left by the integer division.
            int end = i == this.m_executionSlots - 1 ? inst.numInstances() : start + numPerTask;
            DiscreteEstimator[][] model = new DiscreteEstimator[this.m_num_clusters][this.m_num_attribs];
            // Numeric accumulators start at 0.0, Java's default for new double arrays.
            double[][][] normal = new double[this.m_num_clusters][this.m_num_attribs][3];
            for (int c = 0; c < this.m_num_clusters; ++c) {
                for (int j = 0; j < this.m_num_attribs; ++j) {
                    if (this.m_theInstances.attribute(j).isNominal()) {
                        model[c][j] = new DiscreteEstimator(this.m_theInstances.attribute(j).numValues(), false);
                    }
                }
            }
            results.add(this.m_executorPool.submit(new MTask(inst, start, end, model, normal)));
        }
        for (Future<MTask> result : results) {
            MTask task = result.get();
            for (int c = 0; c < this.m_num_clusters; ++c) {
                for (int j = 0; j < this.m_num_attribs; ++j) {
                    if (this.m_theInstances.attribute(j).isNominal()) {
                        for (int k = 0; k < this.m_theInstances.attribute(j).numValues(); ++k) {
                            this.m_model[c][j].addValue(k, task.m_taskModel[c][j].getCount(k));
                        }
                        continue;
                    }
                    double[] target = this.m_modelNormal[c][j];
                    double[] partial = task.m_taskModelNormal[c][j];
                    target[0] += partial[0];
                    target[1] += partial[1];
                    target[2] += partial[2];
                }
            }
        }
        this.M_reEstimate(inst);
    }

    /**
     * Alternates E and M steps on {@code inst} for up to
     * {@code m_max_iterations} iterations.
     *
     * Iteration stops early once the log-likelihood improves by less than
     * {@code m_minLogLikelihoodImprovementIterating}. If it actually decreased,
     * the previous model ({@code m_modelNormalPrev}, {@code m_modelPrev},
     * {@code m_priorsPrev}) is restored.
     *
     * If a step throws, EM restarts:
     * <ul>
     * <li>{@code m_rr} is re-seeded with the next seed;</li>
     * <li>the model is re-initialised from {@code inst} and the executor pool
     * restarted;</li>
     * <li>after more than five restarts the cluster count is reduced by one.</li>
     * </ul>
     *
     * @param inst   the instances to fit (the full data or a CV training fold)
     * @param report whether to print the model before and after, plus each
     *               log-likelihood
     * @return the log-likelihood from the last E step
     * @throws Exception if re-initialisation during a restart fails
     */
    private double iterate(Instances inst, boolean report) throws Exception {
        double llkold = 0.0;
        double llk = 0.0;
        if (report) {
            this.EM_Report(inst);
        }
        boolean ok = false;
        int seed = this.getSeed();
        int restartCount = 0;
        this.m_iterationsPerformed = -1;
        while (!ok) {
            try {
                for (int i = 0; i < this.m_max_iterations; ++i) {
                    llkold = llk;
                    llk = this.launchESteps(inst);
                    if (report) {
                        System.out.println("Loglikely: " + llk);
                    }
                    if (i > 0 && llk - llkold < this.m_minLogLikelihoodImprovementIterating) {
                        if (llk - llkold < 0.0) {
                            // Likelihood went down: roll back to the model from the previous iteration.
                            this.m_modelNormal = this.m_modelNormalPrev;
                            this.m_model = this.m_modelPrev;
                            this.m_priors = this.m_priorsPrev;
                            this.m_iterationsPerformed = i - 1;
                        } else {
                            this.m_iterationsPerformed = i;
                        }
                        break;
                    }
                    this.launchMSteps(inst);
                }
                ok = true;
            } catch (Exception ex) {
                ex.printStackTrace();
                ++restartCount;
                this.m_rr = new Random(++seed);
                for (int z = 0; z < 10; ++z) {
                    this.m_rr.nextDouble();
                    this.m_rr.nextInt();
                }
                if (restartCount > 5) {
                    --this.m_num_clusters;
                    restartCount = 0;
                }
                // Re-initialise from the data actually being fitted. Using
                // m_theInstances here seeded CV restarts from the full dataset,
                // including the held-out fold.
                this.EM_Init(inst);
                this.startExecutorPool();
            }
        }
        if (this.m_iterationsPerformed == -1) {
            this.m_iterationsPerformed = this.m_max_iterations;
        }
        if (this.m_verbose) {
            System.out.println("# iterations performed: " + this.m_iterationsPerformed);
        }
        if (report) {
            this.EM_Report(inst);
        }
        return llk;
    }

    /**
     * Returns the revision string of this class.
     *
     * @return the revision extracted from the embedded SVN keyword
     */
    @Override
    public String getRevision() {
        final String keyword = "$Revision: 11451 $";
        return RevisionUtils.extract(keyword);
    }

    /**
     * Command-line entry point: runs a fresh EM clusterer with the given options.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        EM clusterer = new EM();
        EM.runClusterer(clusterer, argv);
    }

    /**
     * Parallel M-step worker.
     *
     * Accumulates weighted sufficient statistics for instances
     * {@code [m_start, m_end)} into task-private arrays, using the current
     * {@code m_weights} of the enclosing clusterer:
     * <ul>
     * <li>Nominal attributes: each value is added to a {@link DiscreteEstimator}.</li>
     * <li>Numeric attributes: index 0 collects weighted values, index 1 weighted
     * squared values, and index 2 the weights themselves.</li>
     * </ul>
     */
    private class MTask
    implements Callable<MTask> {
        protected int m_start;
        protected int m_end;
        protected Instances m_inst;
        protected DiscreteEstimator[][] m_taskModel;
        double[][][] m_taskModelNormal;

        public MTask(Instances inst, int start, int end, DiscreteEstimator[][] discEst, double[][][] numericEst) {
            this.m_inst = inst;
            this.m_start = start;
            this.m_end = end;
            this.m_taskModel = discEst;
            this.m_taskModelNormal = numericEst;
        }

        /**
         * Fills this task's accumulators for its instance range.
         *
         * @return this task, so the caller can merge its statistics
         */
        @Override
        public MTask call() {
            int clusters = EM.this.m_num_clusters;
            int attribs = EM.this.m_num_attribs;
            for (int row = this.m_start; row < this.m_end; ++row) {
                Instance current = this.m_inst.instance(row);
                double instWeight = current.weight();
                double[] membership = EM.this.m_weights[row];
                for (int c = 0; c < clusters; ++c) {
                    for (int att = 0; att < attribs; ++att) {
                        double value = current.value(att);
                        if (this.m_inst.attribute(att).isNominal()) {
                            this.m_taskModel[c][att].addValue(value, instWeight * membership[c]);
                        } else {
                            // Products keep the original left-to-right order so
                            // floating-point results are bit-identical.
                            double[] acc = this.m_taskModelNormal[c][att];
                            acc[0] += value * instWeight * membership[c];
                            acc[2] += instWeight * membership[c];
                            acc[1] += value * value * instWeight * membership[c];
                        }
                    }
                }
            }
            return this;
        }
    }

    /**
     * Parallel E-step worker.
     *
     * For instances {@code [m_lowNum, m_highNum)} it sums weight * log density
     * and the instance weights. When {@code m_changeWeights} is set, it also
     * stores each instance's cluster distribution into the enclosing
     * {@code m_weights}.
     */
    private class ETask
    implements Callable<double[]> {
        protected int m_lowNum;
        protected int m_highNum;
        protected boolean m_changeWeights;
        protected Instances m_eData;

        public ETask(Instances data, int lowInstNum, int highInstNum, boolean changeWeights) {
            this.m_eData = data;
            this.m_lowNum = lowInstNum;
            this.m_highNum = highInstNum;
            this.m_changeWeights = changeWeights;
        }

        /**
         * Evaluates this task's instance range.
         *
         * @return a two-element array: the sum of weight * log density, and
         *         the sum of weights
         * @throws Exception if a density or distribution computation fails.
         *         The exception is propagated rather than swallowed, so the
         *         caller's {@code Future.get()} fails as the sequential E step
         *         would. The old empty catch returned partial sums that silently
         *         skewed the log-likelihood and bypassed the restart logic.
         */
        @Override
        public double[] call() throws Exception {
            double loglk = 0.0;
            double sOW = 0.0;
            for (int i = this.m_lowNum; i < this.m_highNum; ++i) {
                Instance in = this.m_eData.instance(i);
                loglk += in.weight() * EM.this.logDensityForInstance(in);
                sOW += in.weight();
                if (this.m_changeWeights) {
                    EM.this.m_weights[i] = EM.this.distributionForInstance(in);
                }
            }
            return new double[]{loglk, sOW};
        }
    }
}

