// Copyright 2015-2022 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.mothra.tools

import org.cert.netsa.io.ipfix.InfoModel
import org.cert.netsa.mothra.packer.{
  PackerThreadFactory, Reader, Version, Writer}

import com.typesafe.scalalogging.StrictLogging
import java.io.{PrintWriter, StringWriter}
import java.lang.management.ManagementFactory
import java.util.concurrent.{Executors, LinkedBlockingQueue,
  ThreadPoolExecutor, TimeUnit}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{LocatedFileStatus, RemoteIterator}
import org.apache.hadoop.fs.{Path => HPath}
import org.apache.hadoop.io.compress.CompressionCodecFactory
import resource.managed // see http://jsuereth.com/scala-arm/index.html
import scala.collection.mutable.{HashMap, MultiMap, Set}
import scala.util.control.NonFatal
import scala.util.matching.Regex
import scala.util.{Failure, Success, Try}

/**
  * Object to implement the FileJoiner application.
  *
  * Typical Usage in a Spark environment:
  *
  * `spark-submit --class org.cert.netsa.mothra.tools.FileJoinerMain mothra-tools.jar <s1> [<s2> <s3> ...]`
  *
  * where:
  *
  * s1..sn:         Directories to process, as Hadoop URIs
  *
  *
  * FileJoiner reduces the number of data files in a Mothra repository.  It
  * may also be used to modify the files' compression.
  *
  * FileJoiner runs as a batch process, not as a daemon.
  *
  * FileJoiner makes a single recursive scan of the source directories <s1>,
  * <s2>, ... for files whose names match the pattern "YYYYMMDD.HH." or
  * "YYYYMMDD.HH-PTddH." (It looks for files matching the regular expression
  * `^\d{8}\.\d{2}(?:-PT\d\d?H)?\.`) Files whose names match that pattern are
  * processed by FileJoiner to create a single new file in the same directory
  * that has the same prefix as the originals, and then the original file(s)
  * are removed.
  *
  * By default, files that share the same prefix are only processed when there
  * are two or more files.  To force re-writing when there is a single file,
  * set the Java property `mothra.filejoiner.minCountToJoin` to a value less
  * than 2.  The property may also be used to create a new file only when an
  * "excessive" number of files share the same prefix.
  *
  * There is always a single thread that recursively scans the directories.
  * The number of threads that join the files may be set by specifying the
  * `mothra.filejoiner.maxThreads` Java property.  If not specified, the
  * default is 6.
  *
  * FileJoiner may be run so that either it spawns a thread for every
  * directory that contains files to be joined or it spawns a thread for each
  * set of files in a directory that have the same prefix.  The behavior is
  * controlled by whether the `mothra.filejoiner.spawnThread` Java property is
  * set to `by-prefix` or `by-directory`.  The default is `by-directory`.
  * (For backwards compatibility, `by-hour` is an alias for `by-prefix`.)
  *
  * By default, FileJoiner does not compress the files it writes.
  * (NOTE: It should support writing the output using the same compression as
  * the input.)  To specify the compression codec that it should use, specify
  * the `mothra.filejoiner.compression` Java property.  Values typically
  * supported by Hadoop include `bzip2`, `gzip`, `lz4`, `lzo`, `lzop`,
  * `snappy`, and `default`.  The empty string indicates no compression.
  *
  * FileJoiner joins files sharing the same prefix into a single file by
  * default.  The `mothra.filejoiner.maximumSize` Java property may be used to
  * limit the maximum file size.  The size is for the compressed file if
  * compression is active.  The value is approximate since it is only checked
  * after the data appears on disk which occurs in large blocks because of
  * buffering by the Java stream code and the compression algorithm.  (By
  * setting that property and `mothra.filejoiner.minCountToJoin` to 1, you can
  * force large files to be split into smaller ones, making the FileJoiner a
  * file-splitter.)
  *
  */
object FileJoinerMain extends App with StrictLogging {

  /**
    * Prints the usage message and exits the JVM.
    *
    * @param full When `true`, also prints the full description of the
    *   application and the Java properties it reads, then exits with status
    *   0.  When `false` (the default), exits with status 1.
    */
  def usage(full: Boolean = false): Unit = {
    print("""
Usage: spark-submit --class org.cert.netsa.mothra.tools.FileJoinerMain mothra-tools.jar <s1> [<s2> <s3> ...]

s1..sn:         Directories to process, as Hadoop URIs
""")
    if ( full ) {
      print(s"""
FileJoiner reduces the number of data files in a Mothra repository.  It
may also be used to modify the files' compression.

FileJoiner runs as a batch process, not as a daemon.

FileJoiner makes a single recursive scan of the source directories <s1>,
<s2>, ... for files whose names match the pattern "YYYYMMDD.HH." or
"YYYYMMDD.HH-PTddH." (It looks for files matching the regular expression
`^\\d{8}\\.\\d{2}(?:-PT\\d\\d?H)?\\.`) Files whose names match that pattern are
processed by FileJoiner to create a single new file in the same directory
that has the same prefix as the originals, and then the original file(s)
are removed.

By default, files that share the same prefix are only processed when there
are two or more files.  To force re-writing when there is a single file,
set the Java property `mothra.filejoiner.minCountToJoin` to a value less
than 2.  The property may also be used to create a new file only when an
"excessive" number of files share the same prefix.

There is always a single thread that recursively scans the directories.
The number of threads that joins the files may be set by specifying the
`mothra.filejoiner.maxThreads` Java property.  If not specified, the
default is ${DEFAULT_MAX_THREADS}.

FileJoiner may be run so that either it spawns a thread for every
directory that contains files to be joined or it spawns a thread for each
set of files in a directory that have the same prefix.  The behavior is
controlled whether the `mothra.filejoiner.spawnThread` Java property is
set to `by-prefix` or `by-directory`.  The default is `${DEFAULT_SPAWN_THREAD}`.
(For backwards compatibility, `by-hour` is an alias for `by-prefix`.)

By default, FileJoiner does not compress the files it writes.
(NOTE: It should support writing the output using the same compression as
the input.)  To specify the compression codec that it should use, specify
the `mothra.filejoiner.compression` Java property.  Values typically
supported by Hadoop include `bzip2`, `gzip`, `lz4`, `lzo`, `lzop`,
`snappy`, and `default`.  The empty string indicates no compression.

FileJoiner joins files sharing the same prefix into a single file by
default.  The `mothra.filejoiner.maximumSize` Java property may be used to
limit the maximum file size.  The size is for the compressed file if
compression is active.  The value is approximate since it is only checked
after the data appears on disk which occurs in large blocks because of
buffering by the Java stream code and the compression algorithm.  (By
setting that property and `mothra.filejoiner.minCountToJoin` to 1, you can
force large files to be split into smaller ones, making the FileJoiner a
file-splitter.)
""")
    }
    System.exit(if (full) { 0 } else { 1 })
  }

  /** Prints the FileJoiner version and exits the JVM with status 0. */
  def version(): Unit = {
    println("FileJoiner " + Version.get())
    System.exit(0)
  }

  //  <<<<<  DEFAULT VALUES FOR PROPERTY SETTINGS  >>>>>  //

  /**
    * The default compression codec to use for files written to HDFS.  This
    * may be modified by specifying the following property:
    * mothra.filejoiner.compression.
    *
    * Values typically supported by Hadoop include `bzip2`, `gzip`, `lz4`,
    * `lzo`, `lzop`, `snappy`, and `default`.  The empty string indicates no
    * compression.
    */
  val DEFAULT_COMPRESSION = ""

  /**
    * The default number of threads to run for joining files when the
    * `mothra.filejoiner.maxThreads` Java property is not set. (The scanning
    * task always runs in its own thread.)
    */
  val DEFAULT_MAX_THREADS = 6

  /**
    * The default value for `spawnThread` when the
    * `mothra.filejoiner.spawnThread` Java property is not specified.
    */
  val DEFAULT_SPAWN_THREAD = "by-directory"


  //  <<<<<  PROCESS THE COMMAND LINE ARGUMENTS  >>>>>  //

  // Arguments beginning with "-" are switches; all others are directories
  // to scan.  `startsWith` (unlike `substring(0, 1)`) does not throw on an
  // empty-string argument.
  val (switches, positionalArgs) = args.partition { _.startsWith("-") }

  // Every recognized switch ends the program after printing its output;
  // an unrecognized switch prints the short usage and exits with status 1.
  switches.foreach {
    case "-V" | "--version" => version()
    case "-h" | "--help" => usage(true)
    case unknown =>
      println(s"Unknown argument '${unknown}'")
      usage()
  }

  logger.info("\n=============================" +
    " FileJoiner is starting =============================\n")
  logger.info(s"This is FileJoiner ${Version.get()}")

  if ( positionalArgs.length == 0 ) {
    logger.error(s"Called with no args; at least 1 required")
    usage()
  }

  /**
    * Joins the files specified in `files` that all exist in `dir` and
    * all of whose names begin with the same `basename` which is of the form
    * "YYYYMMDD.HH." or "YYYYMMDD.HH-PTddH."
    *
    * Every record of every input file is copied into one or more new files
    * created by `Writer` in `dir`; when `maximumSize` is set, a new output
    * file is started each time the current one reaches that size.  The input
    * files are deleted only after every output file has been closed,
    * exported, and had the original permissions applied.  If any of those
    * steps fails, the new output files are removed and the inputs are left
    * in place.  Errors are logged rather than thrown.
    */
  private[this] def joinFilesBasename(
    dir: HPath, basename: String, files: Set[HPath]): Unit =
  {
    // file currently being written to; null when no writer is open
    var writer: Writer = null

    // list of newly created files
    var newPaths = List.empty[HPath]

    // list of files that were successfully processed
    var removeList = List.empty[HPath]

    // whether the output must be split into several files by size
    val splitBySize = maximumSize.nonEmpty

    logger.debug(s"Joining ${files.size} '${basename}*' files in ${dir}/")
    val t0 = System.currentTimeMillis()

    Try {
      writer = Writer(dir, basename, compressCodec, maximumSize)
      val originalPerm = writer.originalPermission
      // process all input files
      for ( f <- files ) {
        // process all records in the input
        for ( reader <- managed(Reader(f, codecFactory)) ) {
          for ( record <- reader ) {
            // `reachedMaxSize` is only consulted when a maximum is set
            if ( splitBySize && writer.reachedMaxSize ) {
              logger.trace(s"Closing file '${writer.getName}'")
              writer.close()
              newPaths = writer.exportFile +: newPaths
              writer = null
              logger.trace("Creating additional writer for" +
                s" '${basename}*' files in ${dir}")
              writer = Writer(dir, basename, compressCodec, maximumSize)
            }
            writer.add(record)
          }
          removeList = f +: removeList
        }
      }
      writer.close()
      newPaths = writer.exportFile +: newPaths
      writer = null
      // restore the original permission bits on the new files
      for ( perm <- originalPerm ; f <- newPaths ) {
        fileSystem.setPermission(f, perm)
      }
      // Remove the inputs.  Their records are already in the new files, so
      // a failed removal is only reported.
      for ( f <- removeList ) {
        Try { fileSystem.delete(f, false) } match {
          case Success(true) =>
          case Success(false) =>
            logger.warn(s"Failed to remove old file '${f}'")
          case Failure(e) =>
            logger.warn(
              s"Failed to remove old file '${f}': ${e.toString}")
        }
      }
    } match {
      case Success(_) =>
        logger.debug("Finished joining " +
          s"${files.size} '${basename}*' files into ${newPaths.size} files" +
          s" in ${dir}/ in " +
          f"${(System.currentTimeMillis() - t0).toDouble/1000.0}%.3f" +
          " seconds")
      case Failure(e) =>
        logger.error(
          s"Failed to join ${files.size} '${basename}*' files in ${dir}/" +
            s": ${e.toString}")
        for ( w <- Option(writer) ) {
          // Best-effort close of the partially-written output so whatever
          // the writer holds open is released before its file is removed.
          Try { w.close() } match {
            case Failure(ce) =>
              logger.debug(
                s"Error closing unfinished file in ${dir}/: ${ce.toString}")
            case _ =>
          }
          newPaths = w.exportFile +: newPaths
        }
        for ( f <- newPaths ) {
          Try { fileSystem.delete(f, false) } match {
            case Success(_) =>
            case Failure(de) =>
                logger.error("Failed to remove new file"
                  + s" '${f.getName}' in ${dir}/: ${de.toString}")
          }
        }
    }
  }



  //  <<<<<  DEFINE SOME HELPER CLASSES  >>>>>  //

  /**
    * A Runnable that joins the files that have the same basename.  It assumes
    * all entries in `files` begin with `basename` and live in `dir`.
    */
  private[this] case class BasenameFilesJob(
    dir: HPath, basename: String, files: Set[HPath])
      extends Runnable
  {
    def run(): Unit = joinFilesBasename(dir, basename, files)
  }


  /**
    * A Runnable that splits the files in `files` by unique basename and, for
    * each basename having at least `minCountToJoin` files, joins those files
    * into a new file.  It assumes all entries in `files` live in `dir`.
    */
  private[this] case class DirectoryJob(
    dir: HPath, files: HashMap[String, Set[HPath]] with MultiMap[String, HPath])
      extends Runnable
  {
    def run(): Unit = {
      // Process the basenames of this directory one after another in this
      // thread
      for ( (basename, set) <- files if set.size >= minCountToJoin ) {
        joinFilesBasename(dir, basename, set)
      }
    }
  }


  // /////  Constants & Values Determined from Properties  /////


  /** The information model */
  implicit val infoModel = InfoModel.getCERTStandardInfoModel()

  /** The Hadoop configuration */
  implicit val hadoopConf = new Configuration()

  /** A Compression Codec Factory */
  private[this] val codecFactory = new CompressionCodecFactory(hadoopConf)

  /**
    * The compression codec used for files written to HDFS.  This may be set
    * by setting the "mothra.filejoiner.compression" property.  If that
    * property is not set, DEFAULT_COMPRESSION is used.  If the named codec
    * cannot be initialized, the error is logged and no compression is used.
    */
  val compressCodec = {
    val compressName = sys.props.get("mothra.filejoiner.compression").
      getOrElse(DEFAULT_COMPRESSION)
    if ( compressName == "" ) {
      None
    } else {
      Try {
        val codec = codecFactory.getCodecByName(compressName)
        // Make sure we can create a compressor, not using it here.
        codec.createCompressor()
        codec
      } match {
        case Success(ok) =>
          // Option() maps a null codec (unknown name) to None
          Option(ok)
        case Failure(e) =>
          logger.error("Unable to initialize compressor" +
            s" '${compressName}': ${e.toString}")
          val sw = new StringWriter
          e.printStackTrace(new PrintWriter(sw))
          logger.info("Unable to initialize compressor" +
            s" '${compressName}': ${sw.toString}")
          logger.warn("Using no compression for IPFIX files")
          None
      }
    }
  }

  /**
    * The maximum number of filejoiner threads to start.  It defaults to the
    * value `DEFAULT_MAX_THREADS`.
    *
    * This run-time behavior may be modified by setting the
    * mothra.filejoiner.maxThreads property.
    */
  val maxThreads = (sys.props.get("mothra.filejoiner.maxThreads").
    map { _.toInt }).getOrElse(DEFAULT_MAX_THREADS)
  require(maxThreads >= 1)

  /**
    * The number of files that must exist having the same "YYYYMMDD.HH" or
    * "YYYYMMDD.HH-PTddH" prefix (in a single directory) for those files to be
    * joined into a larger file.  The default is 2.
    *
    * This may be modified by setting the mothra.filejoiner.minCountToJoin
    * property.  For example, when changing the compression, you may want to
    * modify all files by setting this to 1, even if multiple files do not
    * need to be joined.
    */
  val minCountToJoin = (sys.props.get("mothra.filejoiner.minCountToJoin").
    map { _.toInt }).getOrElse(2)

  /**
    * The (approximate) maximum size file to create.  The default is no
    * maximum.  When a file's size exceeds this value, the file is closed and
    * a new file is started.  Typically a file's size will not exceed this
    * value by more than the maximum size of an IPFIX message, 64k.
    */
  val maximumSize = (sys.props.get("mothra.filejoiner.maximumSize").
    map { _.toLong })

  /**
    *
    * The behavior as to whether a file-joining thread is spawned...
    *
    * `by-directory`: for every directory that contains files to be joined, or
    *
    * `by-prefix`: for every unique basename prefix (that is, the file name
    * without the UUID) (in a single directory) that contains files to be
    * joined.  `by-hour` is an alias for `by-prefix`.
    *
    * The default is specified by the `DEFAULT_SPAWN_THREAD` variable.  The
    * run-time behavior may be modified by setting the
    * `mothra.filejoiner.spawnThread` Java property to one of those values.
    */
  val spawnThread = sys.props.get("mothra.filejoiner.spawnThread").
    getOrElse(DEFAULT_SPAWN_THREAD)

  /** Mapping from `spawnThread` value to `threadPerDirectory`. */
  val spawnThreadMap = Map(
    "by-directory" -> true, "by-prefix" -> false, "by-hour" -> false)

  // true: one job per directory; false: one job per basename prefix.
  // An unrecognized `spawnThread` value aborts start-up.
  private[this] val threadPerDirectory: Boolean =
    spawnThreadMap.getOrElse(spawnThread,
      throw new Exception(
        spawnThreadMap.keys.mkString(
          "mothra.filejoiner.spawnThread must be one of: '", "', '", "'")))

  // /////  FileJoiner procedural code begins here  /////

  // the argument(s) is/are the directory(s) to scan; subdirectories found
  // during the scan are pushed onto the front of this list
  private[this] var dirList = positionalArgs.toList.map { new HPath(_) }

  // ensure all source directories use the same file system
  val fileSystem = dirList.head.getFileSystem(hadoopConf)
  if ( dirList.drop(1).exists{_.getFileSystem(hadoopConf) != fileSystem} )
  {
    logger.error("source directories use different file systems")
    throw new Exception("source directories use different file systems")
  }

  // log our settings
  logger.info("FileJoiner settings::")
  logger.info(s"Number of top-level directories to scan: ${dirList.size}")
  logger.info(s"Maximum number of file joining threads: ${maxThreads}")
  logger.info(s"Minimum number of files to join: ${minCountToJoin}")
  logger.info(s"Policy for starting threads: ${spawnThread}")
  logger.info("Approximate maximum output file size: " +
    maximumSize.map{ _.toString }.getOrElse("unlimited"))
  logger.info(s"Output file compression: ${compressCodec.getOrElse("none")}")
  logger.info(s"""JVM Parameters: ${ManagementFactory.getRuntimeMXBean.getInputArguments.toArray.mkString(",")}""")

  /** Fixed-size pool that runs the file-joining jobs. */
  private[this] val pool: ThreadPoolExecutor =
    new ThreadPoolExecutor(
      maxThreads, maxThreads, 0L, TimeUnit.SECONDS,
      new LinkedBlockingQueue[Runnable](),
      new PackerThreadFactory("FileJoinerThread-"))

  /**
    * How often to print log messages regarding the number of tasks, in
    * seconds.
    */
  val logTaskCountInterval = 5

  // print task count every `logTaskCountInterval` seconds
  private[this] val logTaskCountThread = Executors.newScheduledThreadPool(1,
    new PackerThreadFactory("LogTaskCounts-"))
  logTaskCountThread.scheduleAtFixedRate(
    new Runnable {
      override def run(): Unit = {
        val active = pool.getActiveCount()
        val completed = pool.getCompletedTaskCount()
        val total = pool.getTaskCount()
        logger.info(s"Directories to scan: ${dirList.size}," +
          s" Total tasks: ${total}," +
          s" Completed tasks: ${completed}," +
          s" Active tasks: ${active}," +
          s" Queued tasks: ${total - active - completed}")
      }
    },
    logTaskCountInterval, logTaskCountInterval, TimeUnit.SECONDS)

  /** Regular expression that matches expected repo file names */
  private[this] val repoFileRegex =
    new Regex("""\A(\d{8}\.\d{2}(?:-PT\d\d?H)?\.).*\Z""")

  logger.info(s"Starting recursive scan of ${dirList.size} director" +
    (if ( 1 == dirList.size ) { "y" } else { "ies" }))

  // Recursively process all directories
  while ( dirList.nonEmpty ) {
    val dir = dirList.head
    dirList = dirList.tail

    logger.trace(s"Scanning directory '${dir}/'")
    // basename prefix -> files in `dir` having that prefix
    val fileMap = new HashMap[String, Set[HPath]] with MultiMap[String, HPath]
    val iter = try {
      fileSystem.listLocatedStatus(dir)
    } catch {
      case NonFatal(e) =>
        // return an empty iterator
        logger.warn(s"Unable to get status of '${dir}/': ${e.getMessage}")
        new RemoteIterator[LocatedFileStatus](){
          def hasNext: Boolean = false
          def next(): LocatedFileStatus = throw new NoSuchElementException()
        }
    }

    // Whether the listing has another entry.  A failure here means the
    // listing itself cannot continue; retrying could fail the same way
    // forever, so the rest of this directory is skipped.
    def hasMoreEntries(): Boolean = Try { iter.hasNext } match {
      case Success(more) => more
      case Failure(e) =>
        logger.warn(s"Unable to continue listing '${dir}/': ${e.toString}")
        false
    }

    while ( hasMoreEntries() ) {
      Try {
        val entry = iter.next()
        if ( entry.isDirectory ) {
          dirList = entry.getPath +: dirList
        } else if ( entry.isFile ) {
          // it's a file, check if it matches the regex
          for ( m <- repoFileRegex.findFirstMatchIn(entry.getPath.getName) ) {
            fileMap.addBinding(m.group(1), entry.getPath)
          }
        }
      } match {
        case Failure(e) =>
          // ignore errors stat-ing individual files
          logger.debug(s"Unable to read directory entry: ${e.toString}")
        case _ =>
      }
    }

    if ( threadPerDirectory ) {
      // Create one job for all files in the directory, but only when some
      // basename has enough files to join (the threshold applies per
      // basename, as in DirectoryJob, not to the number of basenames)
      if ( fileMap.exists { case (_, set) => set.size >= minCountToJoin } ) {
        pool.execute(DirectoryJob(dir, fileMap))
      }
    } else {
      // Create threads to process the files for each unique basename in this
      // directory
      for (
        (basename, set) <- fileMap
        if set.size >= minCountToJoin
      ) {
        pool.execute(BasenameFilesJob(dir, basename, set))
      }
    }
  }

  // Finished scanning directories.  Wait for joining threads to finish.
  logger.info("Completed recursive directory scan")
  logger.info(
    s"Waiting for ${pool.getTaskCount() - pool.getCompletedTaskCount()}" +
      s" of ${pool.getTaskCount()} tasks to complete...")

  // all tasks are queued; shutdown the thread pool, allow the running and
  // queued tasks to complete, and wait for the pool to terminate
  pool.shutdown()
  while ( !pool.awaitTermination(5, TimeUnit.SECONDS) ) {
    // still working; progress is reported by logTaskCountThread
  }
  logger.debug("All tasks have completed")
  logTaskCountThread.shutdown()
  logTaskCountThread.awaitTermination(1, TimeUnit.SECONDS)

  logger.info("FileJoiner is done")

}

// @LICENSE_FOOTER@
//
// Copyright 2015-2022 Carnegie Mellon University. All Rights Reserved.
//
// This material is based upon work funded and supported by the
// Department of Defense and Department of Homeland Security under
// Contract No. FA8702-15-D-0002 with Carnegie Mellon University for the
// operation of the Software Engineering Institute, a federally funded
// research and development center sponsored by the United States
// Department of Defense. The U.S. Government has license rights in this
// software pursuant to DFARS 252.227.7014.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING
// INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON
// UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR
// IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF
// FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS
// OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT
// MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT,
// TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Released under a GNU GPL 2.0-style license, please see LICENSE.txt or
// contact permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public
// release and unlimited distribution. Please see Copyright notice for
// non-US Government use and distribution.
//
// Carnegie Mellon(R) and CERT(R) are registered in the U.S. Patent and
// Trademark Office by Carnegie Mellon University.
//
// This software includes and/or makes use of third party software each
// subject to its own license as detailed in LICENSE-thirdparty.tx
//
// DM20-1143
