/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.regression;

import java.io.IOException;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java-API smoke tests for {@link LinearRegression} and {@link LinearRegressionModel}.
 *
 * <p>Checks default parameter values, fluent setters, and fitting with extra
 * {@link ParamPair} overrides from Java code.
 */
public class JavaLinearRegressionSuite
extends SharedSparkSession {
    /** DataFrame built from {@link #datasetRDD}; also registered as the temp view "dataset". */
    private transient Dataset<Row> dataset;
    /** Source RDD of labeled points, parallelized into 2 partitions. */
    private transient JavaRDD<LabeledPoint> datasetRDD;

    /**
     * Builds the shared fixture after the parent session setup has run.
     *
     * <p>Points come from {@code LogisticRegressionSuite.generateLogisticInputAsList(1.0, 1.0, 100, 42)}.
     * The arguments are presumably offset, scale, point count, and seed; confirm against that helper.
     */
    @Override
    public void setUp() throws IOException {
        super.setUp();
        List<LabeledPoint> points = LogisticRegressionSuite.generateLogisticInputAsList(1.0, 1.0, 100, 42);
        this.datasetRDD = this.jsc.parallelize(points, 2);
        this.dataset = this.spark.createDataFrame(this.datasetRDD, LabeledPoint.class);
        this.dataset.createOrReplaceTempView("dataset");
    }

    /**
     * Fits a model with all-default parameters and checks the default column names and solver.
     *
     * <p>Also checks, through a SQL query over the transformed data, that a prediction is produced
     * for every input row.
     */
    @Test
    public void linearRegressionDefaultParams() {
        LinearRegression lr = new LinearRegression();
        Assert.assertEquals("label", lr.getLabelCol());
        Assert.assertEquals("auto", lr.getSolver());
        LinearRegressionModel model = lr.fit(this.dataset);
        model.transform(this.dataset).createOrReplaceTempView("prediction");
        Dataset<Row> predictions = this.spark.sql("SELECT label, prediction FROM prediction");
        // Previously the collected rows were discarded; assert that transform yields one row per input.
        Assert.assertEquals(this.dataset.count(), predictions.count());
        Assert.assertEquals("features", model.getFeaturesCol());
        Assert.assertEquals("prediction", model.getPredictionCol());
    }

    /**
     * Checks that fluent setters are reflected in the fitted model's parent estimator.
     *
     * <p>Also checks that {@link ParamPair} overrides passed to {@code fit} take precedence over
     * the values set on the estimator.
     */
    @Test
    public void linearRegressionWithSetters() {
        LinearRegression lr = new LinearRegression().setMaxIter(10).setRegParam(1.0).setSolver("l-bfgs");
        LinearRegressionModel model = lr.fit(this.dataset);
        // parent() is typed as the generic Estimator, so a cast is still required here.
        LinearRegression parent = (LinearRegression) model.parent();
        Assert.assertEquals(10, parent.getMaxIter());
        Assert.assertEquals(1.0, parent.getRegParam(), 0.0);

        // Use fit's varargs overload instead of building a raw generic ParamPair array.
        LinearRegressionModel model2 = lr.fit(
            this.dataset,
            lr.maxIter().w(5),
            lr.regParam().w(0.1),
            lr.predictionCol().w("thePred"));
        LinearRegression parent2 = (LinearRegression) model2.parent();
        Assert.assertEquals(5, parent2.getMaxIter());
        Assert.assertEquals(0.1, parent2.getRegParam(), 0.0);
        Assert.assertEquals("thePred", model2.getPredictionCol());
    }
}

