/*
 * Decompiled with CFR 0.152.
 */
package sklearn.cluster;

import com.google.common.collect.HashMultiset;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.Array;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.SquaredEuclidean;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.clustering.ClusteringModelUtil;
import org.jpmml.sklearn.ClassDictUtil;
import sklearn.Clusterer;

/**
 * Converter for scikit-learn's {@code KMeans} estimator.
 *
 * <p>Cluster centers are read from the pickled {@code cluster_centers_} attribute, and
 * training-set cluster assignments are read from {@code labels_}.
 */
public class KMeans extends Clusterer {

    public KMeans(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features, taken from the second dimension of
     * {@code cluster_centers_}.
     */
    @Override
    public int getNumberOfFeatures() {
        int[] shape = getClusterCentersShape();

        return shape[1];
    }

    /**
     * Encodes this estimator as a center-based PMML {@link ClusteringModel}.
     *
     * <p>One {@link Cluster} is emitted per row of {@code cluster_centers_}. Each cluster's id is
     * its row index rendered as a string ({@code "0"}, {@code "1"}, ...). When at least one label
     * is available, each cluster's size is the number of labels equal to its index. Otherwise
     * the size is left unset ({@code null}).
     *
     * <p>Distances are compared using squared Euclidean distance with absolute-difference
     * comparison. The model output is a field named {@code "Cluster"}.
     *
     * @param schema the schema whose features become the clustering fields
     * @return the encoded clustering model
     */
    public ClusteringModel encodeModel(Schema schema) {
        int[] shape = getClusterCentersShape();

        int numberOfClusters = shape[0];
        int numberOfFeatures = shape[1];

        List<? extends Number> clusterCenters = getClusterCenters();
        List<Integer> labels = getLabels();

        HashMultiset<Integer> labelCounts = HashMultiset.create();
        if (labels != null) {
            labelCounts.addAll(labels);
        }

        // Loop-invariant: cluster sizes are only reported when any labels were recorded.
        boolean hasLabels = !labelCounts.isEmpty();

        List<Cluster> clusters = new ArrayList<>(numberOfClusters);
        for (int i = 0; i < numberOfClusters; i++) {
            Array array = PMMLUtil.createRealArray(
                CMatrixUtil.getRow(clusterCenters, numberOfClusters, numberOfFeatures, i));

            // Explicit boxing keeps the conditional typed as Integer, so null is never unboxed.
            Integer size = hasLabels ? Integer.valueOf(labelCounts.count(i)) : null;

            Cluster cluster = new Cluster()
                .setId(String.valueOf(i))
                .setSize(size)
                .setArray(array);

            clusters.add(cluster);
        }

        ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE)
            .setCompareFunction(CompareFunction.ABS_DIFF)
            .setMeasure(new SquaredEuclidean());

        ClusteringModel clusteringModel = new ClusteringModel(
                MiningFunction.CLUSTERING,
                ClusteringModel.ModelClass.CENTER_BASED,
                numberOfClusters,
                ModelUtil.createMiningSchema(schema),
                comparisonMeasure,
                ClusteringModelUtil.createClusteringFields(schema.getFeatures()),
                clusters)
            .setOutput(ClusteringModelUtil.createOutput(FieldName.create("Cluster"), clusters));

        return clusteringModel;
    }

    /**
     * Returns the flattened contents of the {@code cluster_centers_} attribute.
     *
     * <p>The values are interpreted row by row, using the shape from
     * {@code cluster_centers_}, when they are split into per-cluster rows.
     */
    public List<? extends Number> getClusterCenters() {
        return ClassDictUtil.getArray(this, "cluster_centers_");
    }

    /**
     * Returns the {@code labels_} attribute converted to integers via
     * {@link ValueUtil#asIntegers}.
     *
     * <p>NOTE(review): callers null-check the result, so this may return {@code null} when the
     * attribute is absent. Confirm against {@code ClassDictUtil.getArray}.
     */
    public List<Integer> getLabels() {
        return ValueUtil.asIntegers(ClassDictUtil.getArray(this, "labels_"));
    }

    /** Returns the two-dimensional shape of {@code cluster_centers_}. */
    private int[] getClusterCentersShape() {
        return ClassDictUtil.getShape(this, "cluster_centers_", 2);
    }
}

