/*
 * Decompiled with CFR 0.152.
 */
package sklearn.cluster;

import com.google.common.collect.HashMultiset;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.SquaredEuclidean;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.clustering.ClusteringModelUtil;
import sklearn.Clusterer;

/**
 * K-means clustering estimator.
 *
 * <p>Reads the fitted {@code cluster_centers_} array and, when present, the {@code labels_}
 * array, and encodes them as a center-based PMML {@link ClusteringModel} that compares
 * records to cluster centers using squared Euclidean distance.
 */
public class KMeans
extends Clusterer {

    /**
     * Creates a new instance.
     *
     * @param module forwarded unchanged to the {@link Clusterer} constructor
     * @param name forwarded unchanged to the {@link Clusterer} constructor
     */
    public KMeans(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features, taken as the second dimension of
     * {@code cluster_centers_}.
     */
    @Override
    public int getNumberOfFeatures() {
        int[] shape = this.getClusterCentersShape();
        return shape[1];
    }

    /**
     * Encodes the fitted cluster centers as a PMML center-based clustering model.
     *
     * @param schema supplies the label used for the mining schema and the features used as
     *     clustering fields
     * @return a {@link ClusteringModel} with one {@link Cluster} per row of
     *     {@code cluster_centers_}; cluster {@code i} gets id {@code "i"} and, when
     *     {@code labels_} is present and non-empty, a size equal to the number of entries in
     *     {@code labels_} equal to {@code i} (otherwise the size is left unset)
     */
    public ClusteringModel encodeModel(Schema schema) {
        int[] shape = this.getClusterCentersShape();
        int numberOfClusters = shape[0];
        int numberOfFeatures = shape[1];

        List<? extends Number> clusterCenters = this.getClusterCenters();
        List<Integer> labels = this.getLabels();

        // Count how many times each cluster index occurs in labels_ (if available).
        HashMultiset<Integer> labelCounts = HashMultiset.create();
        if (labels != null) {
            labelCounts.addAll(labels);
        }

        // Loop-invariant: sizes are only emitted when at least one label was counted.
        boolean hasLabelCounts = !labelCounts.isEmpty();

        List<Cluster> clusters = new ArrayList<>(numberOfClusters);
        for (int i = 0; i < numberOfClusters; i++) {
            // clusterCenters is read as a flattened numberOfClusters x numberOfFeatures matrix
            // (presumably C/row-major order, per CMatrixUtil); row i is this cluster's center.
            List<? extends Number> center = CMatrixUtil.getRow(clusterCenters, numberOfClusters, numberOfFeatures, i);

            Integer size = hasLabelCounts ? Integer.valueOf(labelCounts.count(i)) : null;

            Cluster cluster = new Cluster(PMMLUtil.createRealArray(center))
                .setId(String.valueOf(i))
                .setSize(size);

            clusters.add(cluster);
        }

        Measure measure = new SquaredEuclidean();
        ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE, measure)
            .setCompareFunction(CompareFunction.ABS_DIFF);

        Label label = schema.getLabel();

        ClusteringModel clusteringModel = new ClusteringModel(MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, numberOfClusters, ModelUtil.createMiningSchema(label), comparisonMeasure, ClusteringModelUtil.createClusteringFields(schema.getFeatures()), clusters)
            .setOutput(ClusteringModelUtil.createOutput(FieldName.create("Cluster"), DataType.DOUBLE, clusters));

        return clusteringModel;
    }

    /**
     * Returns the {@code cluster_centers_} attribute as a flat list of numbers.
     */
    public List<? extends Number> getClusterCenters() {
        return this.getNumberArray("cluster_centers_");
    }

    /**
     * Returns the shape of the {@code cluster_centers_} attribute; index 0 is used as the
     * number of clusters and index 1 as the number of features.
     */
    public int[] getClusterCentersShape() {
        // NOTE(review): the second argument presumably requires a 2-D array — confirm in Clusterer.
        return this.getArrayShape("cluster_centers_", 2);
    }

    /**
     * Returns the {@code labels_} attribute, or {@code null} if it is unavailable (callers must
     * handle {@code null}; presumably the attribute is absent in that case).
     */
    public List<Integer> getLabels() {
        return this.getIntegerArray("labels_");
    }
}

