package icu.wuhufly.features

import icu.wuhufly.SparkHandler
import org.apache.spark.SparkContext
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{Normalizer, OneHotEncoder, StandardScaler, VectorAssembler}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Batch job that trains a cross-validated random-forest classifier on the
 * `fact_machine_learning_data` table and writes predictions for the held-out
 * split to `ml_result` through `SparkHandler.writeIntoMysql`.
 *
 * Flow: read the table, split it 80/20 with a fixed seed, run 3-fold
 * cross-validation over the pipeline on the 80% part, score the 20% part with
 * the fitted model, and persist `machine_record_id` together with the
 * predicted `machine_record_state`.
 */
object Pred02 {

  /** Label column the classifier is trained on and the evaluator scores against. */
  private val LabelCol: String = "machine_record_state"

  /** Source columns assembled into the raw feature vector, in assembly order. */
  private val FeatureInputCols: Array[String] = Array(
    "machine_record_mainshaft_speed",
    "machine_record_mainshaft_multiplerate",
    "machine_record_mainshaft_load",
    "machine_record_feed_speed",
    "machine_record_feed_multiplerate",
    "machine_record_pmc_code",
    "machine_record_circle_time",
    "machine_record_run_time",
    "machine_record_effective_shaft",
    "machine_record_amount_process",
    "machine_record_use_memory",
    "machine_record_free_memory",
    "machine_record_amount_use_code",
    "machine_record_amount_free_code"
  )

  /**
   * Job entry point.
   *
   * @param args command-line arguments; not used by this job
   */
  def main(args: Array[String]): Unit = {
    val handler: SparkHandler = SparkHandler.of()
    val spark: SparkSession = handler.getSpark()
    val sc: SparkContext = spark.sparkContext

    // Everything after the context exists runs inside try/finally so the
    // SparkContext is stopped even when reading, fitting or writing fails.
    try {
      spark.sql("use hudi_gy_dwd")

      // Fixed seed keeps the 80/20 split reproducible between runs.
      val Array(train, test) = handler.readFromHDFS("fact_machine_learning_data", spark, "hudi_gy_dwd")
        .randomSplit(Array(0.8, 0.2), 42)

      val rfc: RandomForestClassifier = new RandomForestClassifier()
        .setSeed(42)
        .setLabelCol(LabelCol)
        .setFeaturesCol("features")

      val evaluator: BinaryClassificationEvaluator = new BinaryClassificationEvaluator()
        .setLabelCol(LabelCol)

      val resDF: DataFrame = new CrossValidator()
        .setSeed(42)
        .setNumFolds(3)
        .setEvaluator(evaluator)
        .setEstimator(buildPipeline(rfc))
        .setParallelism(6)
        .setEstimatorParamMaps(buildParamGrid(rfc))
        .fit(train)
        .transform(test)
        // Expose the model's prediction under the label's column name for the result table.
        .selectExpr("machine_record_id", "prediction as machine_record_state")

      handler.writeIntoMysql(
        "ml_result", resDF
      )
    } finally {
      sc.stop()
    }
  }

  /**
   * Builds the preprocessing + classification pipeline:
   * assemble -> one-hot encode -> standard-scale -> normalize -> classifier.
   *
   * @param classifier final pipeline stage; must read its input from the `features` column
   * @return the unfitted pipeline
   */
  private def buildPipeline(classifier: RandomForestClassifier): Pipeline =
    new Pipeline()
      .setStages(Array(
        new VectorAssembler()
          .setInputCols(FeatureInputCols)
          .setOutputCol("ass_features"),
        // NOTE(review): `ohe1` is produced here but no later stage consumes it,
        // so it does not reach `features`. Confirm whether `machine_id` was meant
        // to be a model input or whether this stage can be dropped.
        new OneHotEncoder()
          .setInputCol("machine_id")
          .setOutputCol("ohe1"),
        new StandardScaler()
          .setInputCol("ass_features")
          .setOutputCol("std_features"),
        new Normalizer()
          .setInputCol("std_features")
          .setOutputCol("features"),
        classifier
      ))

  /**
   * Builds the hyper-parameter grid searched by cross-validation:
   * 2 `maxBins` x 2 `numTrees` x 2 `maxDepth` values = 8 candidates.
   *
   * @param classifier the classifier instance whose params are being tuned; must be
   *                   the same instance that is placed in the pipeline
   * @return one `ParamMap` per candidate combination
   */
  private def buildParamGrid(classifier: RandomForestClassifier): Array[ParamMap] =
    new ParamGridBuilder()
      .addGrid(classifier.maxBins, Array(32, 48))
      .addGrid(classifier.numTrees, Array(20, 40))
      .addGrid(classifier.maxDepth, Array(5, 10))
      .build()
}
