/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import com.google.common.primitives.Doubles;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.Array;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.SquaredEuclidean;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringModel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.clustering.ClusteringModelUtil;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.SparkMLEncoder;

public class KMeansModelConverter
extends ModelConverter<KMeansModel> {
    public KMeansModelConverter(KMeansModel model) {
        super(model);
    }

    @Override
    public MiningFunction getMiningFunction() {
        return MiningFunction.CLUSTERING;
    }

    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
        KMeansModel model = (KMeansModel)this.getTransformer();
        return Collections.emptyList();
    }

    public ClusteringModel encodeModel(Schema schema) {
        KMeansModel model = (KMeansModel)this.getTransformer();
        ArrayList<Cluster> clusters = new ArrayList<Cluster>();
        Vector[] clusterCenters = model.clusterCenters();
        for (int i = 0; i < clusterCenters.length; ++i) {
            Vector clusterCenter = clusterCenters[i];
            Array array = PMMLUtil.createRealArray((List)Doubles.asList((double[])clusterCenter.toArray()));
            Cluster cluster = new Cluster().setId(String.valueOf(i)).setArray(array);
            clusters.add(cluster);
        }
        List features = schema.getFeatures();
        List clusteringFields = ClusteringModelUtil.createClusteringFields((List)features);
        ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE).setCompareFunction(CompareFunction.ABS_DIFF).setMeasure((Measure)new SquaredEuclidean());
        ClusteringModel clusteringModel = new ClusteringModel(MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, clusters.size(), ModelUtil.createMiningSchema((Schema)schema), comparisonMeasure, clusteringFields, clusters).setOutput(ClusteringModelUtil.createOutput((FieldName)FieldName.create((String)"cluster"), Collections.emptyList()));
        return clusteringModel;
    }
}

