/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.regression;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.RidgeRegressionModel;
import org.apache.spark.mllib.regression.RidgeRegressionWithSGD;
import org.apache.spark.mllib.util.LinearDataGenerator;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java-API tests for {@link RidgeRegressionWithSGD}.
 *
 * <p>Both tests fit a ridge model on noisy synthetic linear data twice. The first fit has no L2
 * penalty ({@code regParam = 0.0}) and the second has {@link #REG_PARAM}. Each test then requires
 * the penalized model to have the strictly lower mean squared error on a held-out validation set.
 */
public class JavaRidgeRegressionSuite extends SharedSparkSession {

    /** Number of training examples; the same number of points is held out for validation. */
    private static final int NUM_EXAMPLES = 50;
    /** Dimensionality of the generated feature vectors. */
    private static final int NUM_FEATURES = 20;
    /** Passed as the generator's {@code std} argument (presumably the label-noise level). */
    private static final double NOISE_STD = 10.0;
    /** SGD step size (first constructor argument). */
    private static final double STEP_SIZE = 1.0;
    /** Number of SGD iterations (second constructor argument). */
    private static final int NUM_ITERATIONS = 200;
    /** Mini-batch fraction (fourth constructor argument); 1.0 uses the whole training set. */
    private static final double MINI_BATCH_FRACTION = 1.0;
    /** L2 regularization strength used for the penalized fit. */
    private static final double REG_PARAM = 0.1;

    /**
     * Computes the mean squared error of {@code model} over {@code validationData}.
     *
     * @param validationData labelled points to score; an empty list yields {@code NaN}
     * @param model the model whose predictions are compared with the labels
     * @return the mean of the squared residuals {@code (prediction - label)^2}
     */
    private static double predictionError(List<LabeledPoint> validationData, RidgeRegressionModel model) {
        double errorSum = 0.0;
        for (LabeledPoint point : validationData) {
            double residual = model.predict(point.features()) - point.label();
            errorSum += residual * residual;
        }
        return errorSum / validationData.size();
    }

    /**
     * Generates {@code numPoints} labelled points from a random linear model with zero intercept.
     *
     * <p>The true weights are drawn uniformly from [-0.5, 0.5). Both the weight draw and the data
     * generator use the fixed seed 42, so the test data is reproducible.
     *
     * @param numPoints number of points to generate
     * @param numFeatures number of features (length of the weight vector)
     * @param std value forwarded as the generator's noise argument
     * @return the generated points
     */
    private static List<LabeledPoint> generateRidgeData(int numPoints, int numFeatures, double std) {
        Random random = new Random(42L);
        double[] w = new double[numFeatures];
        for (int i = 0; i < w.length; ++i) {
            w[i] = random.nextDouble() - 0.5;
        }
        return LinearDataGenerator.generateLinearInputAsList(0.0, w, numPoints, 42, std);
    }

    /**
     * Parallelizes the first {@link #NUM_EXAMPLES} points of {@code data} as the training set.
     *
     * <p>The slice is copied into a standalone list so the RDD does not hold a {@code subList}
     * view backed by {@code data}.
     */
    private JavaRDD<LabeledPoint> trainingRDD(List<LabeledPoint> data) {
        return this.jsc.parallelize(new ArrayList<>(data.subList(0, NUM_EXAMPLES)));
    }

    /** Returns the points after the training slice, used only for scoring. */
    private static List<LabeledPoint> validationData(List<LabeledPoint> data) {
        return data.subList(NUM_EXAMPLES, 2 * NUM_EXAMPLES);
    }

    /** Fails, reporting both errors, unless the regularized error is strictly lower. */
    private static void assertRegularizationHelps(double regularizedErr, double unRegularizedErr) {
        Assert.assertTrue(
            "expected regularized error " + regularizedErr
                + " to be below unregularized error " + unRegularizedErr,
            regularizedErr < unRegularizedErr);
    }

    /**
     * Fits with one {@link RidgeRegressionWithSGD} instance. Between the two fits, the instance's
     * optimizer is mutated to enable the L2 penalty.
     */
    @Test
    public void runRidgeRegressionUsingConstructor() {
        List<LabeledPoint> data = generateRidgeData(2 * NUM_EXAMPLES, NUM_FEATURES, NOISE_STD);
        JavaRDD<LabeledPoint> testRDD = trainingRDD(data);
        List<LabeledPoint> validationData = validationData(data);

        RidgeRegressionWithSGD ridgeSGDImpl =
            new RidgeRegressionWithSGD(STEP_SIZE, NUM_ITERATIONS, 0.0, MINI_BATCH_FRACTION);
        RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
        double unRegularizedErr = predictionError(validationData, model);

        // Reuse the same algorithm instance; only the regularization strength changes.
        ridgeSGDImpl.optimizer().setRegParam(REG_PARAM);
        model = ridgeSGDImpl.run(testRDD.rdd());
        double regularizedErr = predictionError(validationData, model);

        assertRegularizationHelps(regularizedErr, unRegularizedErr);
    }

    /**
     * Fits with two separately constructed {@link RidgeRegressionWithSGD} instances. The second
     * instance is given the L2 penalty through its constructor.
     *
     * <p>Despite its name, this test calls no static factory method. Unlike
     * {@link #runRidgeRegressionUsingConstructor}, it never mutates an optimizer.
     */
    @Test
    public void runRidgeRegressionUsingStaticMethods() {
        List<LabeledPoint> data = generateRidgeData(2 * NUM_EXAMPLES, NUM_FEATURES, NOISE_STD);
        JavaRDD<LabeledPoint> testRDD = trainingRDD(data);
        List<LabeledPoint> validationData = validationData(data);

        RidgeRegressionModel model =
            new RidgeRegressionWithSGD(STEP_SIZE, NUM_ITERATIONS, 0.0, MINI_BATCH_FRACTION)
                .run(testRDD.rdd());
        double unRegularizedErr = predictionError(validationData, model);

        model = new RidgeRegressionWithSGD(STEP_SIZE, NUM_ITERATIONS, REG_PARAM, MINI_BATCH_FRACTION)
            .run(testRDD.rdd());
        double regularizedErr = predictionError(validationData, model);

        assertRegularizationHelps(regularizedErr, unRegularizedErr);
    }
}

