/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.scala.example.spark;

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineModel$;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.MatchError;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.SeqLike;
import scala.collection.immutable.Map;
import scala.collection.mutable.StringBuilder;
import scala.runtime.BoxesRunTime;
import scala.sys.package$;

public final class SparkMLlibPipeline$ {
    public static final SparkMLlibPipeline$ MODULE$;

    static {
        new SparkMLlibPipeline$();
    }

    public void main(String[] args) {
        if (args.length != 3) {
            Predef$.MODULE$.println((Object)"Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path");
            throw package$.MODULE$.exit(1);
        }
        String inputPath = args[0];
        String nativeModelPath = args[1];
        String pipelineModelPath = args[2];
        SparkSession spark = SparkSession$.MODULE$.builder().appName("XGBoost4J-Spark Pipeline Example").getOrCreate();
        StructType schema = new StructType((StructField[])((Object[])new StructField[]{new StructField("sepal length", (DataType)DoubleType$.MODULE$, true, StructField$.MODULE$.apply$default$4()), new StructField("sepal width", (DataType)DoubleType$.MODULE$, true, StructField$.MODULE$.apply$default$4()), new StructField("petal length", (DataType)DoubleType$.MODULE$, true, StructField$.MODULE$.apply$default$4()), new StructField("petal width", (DataType)DoubleType$.MODULE$, true, StructField$.MODULE$.apply$default$4()), new StructField("class", (DataType)StringType$.MODULE$, true, StructField$.MODULE$.apply$default$4())}));
        Dataset rawInput = spark.read().schema(schema).csv(inputPath);
        Dataset[] datasetArray = rawInput.randomSplit(new double[]{0.8, 0.2}, 123L);
        Option option = Array$.MODULE$.unapplySeq((Object)datasetArray);
        if (!option.isEmpty() && option.get() != null && ((SeqLike)option.get()).lengthCompare(2) == 0) {
            Tuple2 tuple2;
            Dataset training = (Dataset)((SeqLike)option.get()).apply(0);
            Dataset test = (Dataset)((SeqLike)option.get()).apply(1);
            Tuple2 tuple22 = tuple2 = new Tuple2((Object)training, (Object)test);
            Dataset training2 = (Dataset)tuple22._1();
            Dataset test2 = (Dataset)tuple22._2();
            VectorAssembler assembler = new VectorAssembler().setInputCols((String[])((Object[])new String[]{"sepal length", "sepal width", "petal length", "petal width"})).setOutputCol("features");
            StringIndexerModel labelIndexer = new StringIndexer().setInputCol("class").setOutputCol("classIndex").fit(training2);
            XGBoostClassifier booster = new XGBoostClassifier((Map)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"eta"), (Object)BoxesRunTime.boxToFloat((float)0.1f)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"max_depth"), (Object)BoxesRunTime.boxToInteger((int)2)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"objective"), (Object)"multi:softprob"), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"num_class"), (Object)BoxesRunTime.boxToInteger((int)3)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"num_round"), (Object)BoxesRunTime.boxToInteger((int)100)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"num_workers"), (Object)BoxesRunTime.boxToInteger((int)2))})));
            booster.setFeaturesCol("features");
            booster.setLabelCol("classIndex");
            IndexToString labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("realLabel").setLabels(labelIndexer.labels());
            Pipeline pipeline = new Pipeline().setStages((PipelineStage[])((Object[])new PipelineStage[]{assembler, labelIndexer, booster, labelConverter}));
            PipelineModel model = pipeline.fit(training2);
            Dataset prediction = model.transform(test2);
            prediction.show(false);
            MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator();
            evaluator.setLabelCol("classIndex");
            evaluator.setPredictionCol("prediction");
            double accuracy = evaluator.evaluate(prediction);
            Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"The model accuracy is : ").append((Object)BoxesRunTime.boxToDouble((double)accuracy)).toString());
            ParamMap[] paramGrid = new ParamGridBuilder().addGrid(booster.maxDepth(), new int[]{3, 8}).addGrid(booster.eta(), new double[]{0.2, 0.6}).build();
            CrossValidator cv = new CrossValidator().setEstimator((Estimator)pipeline).setEvaluator((Evaluator)evaluator).setEstimatorParamMaps(paramGrid).setNumFolds(3);
            CrossValidatorModel cvModel = cv.fit(training2);
            XGBoostClassificationModel bestModel = (XGBoostClassificationModel)((PipelineModel)cvModel.bestModel()).stages()[2];
            Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"The params of best XGBoostClassification model : ").append((Object)bestModel.extractParamMap()).toString());
            Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"The training summary of best XGBoostClassificationModel : ").append((Object)bestModel.summary()).toString());
            bestModel.nativeBooster().saveModel(nativeModelPath);
            model.write().overwrite().save(pipelineModelPath);
            PipelineModel model2 = PipelineModel$.MODULE$.load(pipelineModelPath);
            model2.transform(test2).show(false);
            return;
        }
        throw new MatchError((Object)datasetArray);
    }

    private SparkMLlibPipeline$() {
        MODULE$ = this;
    }
}

