/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.util;

import java.util.Arrays;
import java.util.Collections;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.MatrixUDT;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java-side tests for the {@link MLUtils} helpers that convert vector and matrix
 * columns of a {@code Dataset<Row>} between the {@code mllib} types and their
 * {@code asML()} counterparts.
 */
public class JavaMLUtilsSuite extends SharedSparkSession {

    /**
     * Checks {@code convertVectorColumnsToML} and {@code convertVectorColumnsFromML}
     * on a single-row dataset with a {@code label} and a {@code features} column.
     *
     * <p>Asserted behavior:
     * <ul>
     *   <li>converting with no column names yields a row holding {@code x.asML()};</li>
     *   <li>converting with {@code "features"} named explicitly yields an equal row;</li>
     *   <li>converting the result back with no column names yields the original {@code x}.</li>
     * </ul>
     */
    @Test
    public void testConvertVectorColumnsToAndFromML() {
        Vector x = Vectors.dense(2.0);
        Dataset<Row> dataset = spark
            .createDataFrame(Collections.singletonList(new LabeledPoint(1.0, x)), LabeledPoint.class)
            .select("label", "features");

        // No explicit column names passed.
        Dataset<Row> newDataset1 = MLUtils.convertVectorColumnsToML(dataset);
        Row new1 = newDataset1.first();
        Assert.assertEquals(RowFactory.create(1.0, x.asML()), new1);

        // Naming the column explicitly must give the same row.
        Row new2 = MLUtils.convertVectorColumnsToML(dataset, "features").first();
        Assert.assertEquals(new1, new2);

        // Round trip back from the converted dataset.
        Row old1 = MLUtils.convertVectorColumnsFromML(newDataset1).first();
        Assert.assertEquals(RowFactory.create(1.0, x), old1);
    }

    /**
     * Checks {@code convertMatrixColumnsToML} and {@code convertMatrixColumnsFromML}
     * on a single-row dataset built from an explicit schema whose {@code features}
     * column is typed with {@link MatrixUDT}.
     *
     * <p>Asserted behavior:
     * <ul>
     *   <li>converting with no column names yields a row holding {@code x.asML()};</li>
     *   <li>converting with {@code "features"} named explicitly yields an equal row;</li>
     *   <li>converting the result back with no column names yields the original {@code x}.</li>
     * </ul>
     */
    @Test
    public void testConvertMatrixColumnsToAndFromML() {
        Matrix x = Matrices.dense(2, 1, new double[]{1.0, 2.0});
        StructType schema = new StructType(new StructField[]{
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new MatrixUDT(), false, Metadata.empty())
        });
        Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(RowFactory.create(1.0, x)), schema);

        // No explicit column names passed.
        Dataset<Row> newDataset1 = MLUtils.convertMatrixColumnsToML(dataset);
        Row new1 = newDataset1.first();
        Assert.assertEquals(RowFactory.create(1.0, x.asML()), new1);

        // Naming the column explicitly must give the same row.
        Row new2 = MLUtils.convertMatrixColumnsToML(dataset, "features").first();
        Assert.assertEquals(new1, new2);

        // Round trip back from the converted dataset.
        Row old1 = MLUtils.convertMatrixColumnsFromML(newDataset1).first();
        Assert.assertEquals(RowFactory.create(1.0, x), old1);
    }
}

