/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.scala.example.spark;

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Predef;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.Map;
import scala.runtime.BoxesRunTime;
import scala.sys.package$;

/**
 * Java form of the Scala {@code SparkTraining} module.
 *
 * <p>It reads a headerless CSV with four numeric measurement columns and a string {@code class}
 * label. It then indexes the label and assembles the measurements into a feature vector. Finally
 * it trains an XGBoost multi-class classifier and prints the model's predictions on the training
 * data.
 *
 * <p>Mirrors Scala's module encoding: the single instance is exposed as {@link #MODULE$}.
 */
public final class SparkTraining$ {
    /** The singleton instance of this Scala module. */
    public static final SparkTraining$ MODULE$;

    static {
        // A static final field must be definitely assigned exactly once in static context.
        // Assigning it from the (instance) constructor, as the decompiled code did, does not compile.
        MODULE$ = new SparkTraining$();
    }

    /**
     * Program entry point.
     *
     * <p>Prints a usage message and exits with status 1 when no input path is given. Otherwise it
     * obtains (or creates) a {@code SparkSession}, runs the training pipeline on {@code args[0]},
     * and stops the session afterwards, including when the pipeline throws.
     *
     * @param args command-line arguments; {@code args[0]} is the input CSV path
     */
    public void main(String[] args) {
        if (args.length < 1) {
            Predef$.MODULE$.println("Usage: program input_path");
            throw package$.MODULE$.exit(1);
        }
        String inputPath = args[0];
        SparkSession spark = SparkSession$.MODULE$.builder().getOrCreate();
        try {
            train(spark, inputPath);
        } finally {
            // Release the session's resources whether or not training succeeded.
            spark.stop();
        }
    }

    /**
     * Runs the load, index, assemble, fit and predict pipeline, then shows the predictions.
     *
     * @param spark active session used to read the input
     * @param inputPath path of the headerless CSV file to train on
     */
    private void train(SparkSession spark, String inputPath) {
        Dataset rawInput = spark.read().schema(irisSchema()).csv(inputPath);

        // Map the string label to a numeric "classIndex" column and drop the original label.
        StringIndexerModel stringIndexer = new StringIndexer()
                .setInputCol("class")
                .setOutputCol("classIndex")
                .fit(rawInput);
        Dataset labelTransformed = stringIndexer.transform(rawInput).drop("class");

        // Combine the four numeric measurements into a single "features" vector column.
        VectorAssembler vectorAssembler = new VectorAssembler()
                .setInputCols(new String[]{"sepal length", "sepal width", "petal length", "petal width"})
                .setOutputCol("features");
        Dataset xgbInput = vectorAssembler.transform(labelTransformed).select("features", "classIndex");

        XGBoostClassifier xgbClassifier = (XGBoostClassifier) new XGBoostClassifier(xgbParams())
                .setFeaturesCol("features")
                .setLabelCol("classIndex");
        XGBoostClassificationModel xgbClassificationModel =
                (XGBoostClassificationModel) xgbClassifier.fit(xgbInput);

        // Predictions are made on the same data the model was trained on.
        Dataset results = xgbClassificationModel.transform(xgbInput);
        results.show();
    }

    /** Builds the input schema: four nullable double measurements followed by a nullable string label. */
    private StructType irisSchema() {
        return new StructType(new StructField[]{
            nullableField("sepal length", DoubleType$.MODULE$),
            nullableField("sepal width", DoubleType$.MODULE$),
            nullableField("petal length", DoubleType$.MODULE$),
            nullableField("petal width", DoubleType$.MODULE$),
            nullableField("class", StringType$.MODULE$)
        });
    }

    /** Creates a nullable {@code StructField} with the Scala-default metadata argument. */
    private StructField nullableField(String name, DataType type) {
        // apply$default$4 is the default value of StructField's fourth (metadata) parameter.
        return new StructField(name, type, true, StructField$.MODULE$.apply$default$4());
    }

    /**
     * Builds the XGBoost parameter map as an immutable Scala {@code Map[String, Any]}.
     *
     * <p>Each entry is a {@code Tuple2}, which is what Scala's {@code key -> value} syntax produces.
     */
    private Map xgbParams() {
        Tuple2<?, ?>[] entries = {
            new Tuple2<String, Object>("eta", BoxesRunTime.boxToFloat(0.1f)),
            new Tuple2<String, Object>("max_depth", BoxesRunTime.boxToInteger(2)),
            new Tuple2<String, Object>("objective", "multi:softprob"),
            new Tuple2<String, Object>("num_class", BoxesRunTime.boxToInteger(3)),
            new Tuple2<String, Object>("num_round", BoxesRunTime.boxToInteger(100)),
            new Tuple2<String, Object>("num_workers", BoxesRunTime.boxToInteger(2))
        };
        return (Map) Predef$.MODULE$.Map().apply((Seq) Predef$.MODULE$.wrapRefArray((Object[]) entries));
    }

    private SparkTraining$() {
    }
}

