// Copyright 2015-2022 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.silk

import java.io.{OutputStream, DataOutputStream}
import org.apache.hadoop.conf.Configuration

import io.{BufferWriter, LzoOutputStreamBuffer, RawOutputStreamBuffer,
  SnappyOutputStreamBuffer, ZlibOutputStreamBuffer}
import io.BufferUtil.{putBytes, putInt32, putInt64}

import org.cert.netsa.data.net.IPAddress


/**
  * A writer of binary SiLK Bag files.
  *
  * To include a header in the Bag file that specifies the type of the
  * key and counter, call `setKeyType()` and/or `setCounterType()` prior
  * to writing the Bag.
  *
  * @example This example reads the contents of "example.bag" and
  * writes it to "copy.bag", where the keys are IP addresses:
  * {{{
  * val in = new java.io.FileInputStream("example.bag")
  * val out = new java.io.FileOutputStream("copy.bag")
  * val bagresult = BagReader.ofInputStream(in)
  * val bag = bagresult match {
  *   case BagResult.IPAddressBag(iter) => iter
  *   case _ => null
  * }
  * val writer = BagWriter.toOutputStream(out)
  * bag.keyType.foreach(t => writer.setKeyType(t))
  * bag.counterType.foreach(t => writer.setCounterType(t))
  * writer.appendIPAddresses(bag)
  * writer.close()
  * }}}
  *
  * @see [[BagWriter$ the companion object]] for more details
  */
class BagWriter private (
  val out: DataOutputStream,
  val compressionMethod: CompressionMethod,
  val hadoopConfig: Option[Configuration])
{
  /**
    * Set to true once the file's header has been written, either by
    * the first call to `appendIntegers()`/`appendIPAddresses()` or by
    * `close()`.
    */
  private[this] var headerWritten = false

  /**
    * The type of the key; set by the `setKeyType()` method (or set to
    * CUSTOM by `setCounterType()` when no key type was given).
    */
  private[this] var typeKey: Option[BagDataType] = None

  /**
    * The type of the counter; set by the `setCounterType()` method (or
    * set to CUSTOM by `setKeyType()` when no counter type was given).
    */
  private[this] var typeCounter: Option[BagDataType] = None

  /**
    * Whether the Bag contains IPAddresses or Integers.  Set by the
    * first call to either appendIPAddresses() or appendIntegers().
    */
  private[this] var containsIPs: Option[Boolean] = None

  /**
    * Length of the key in octets (4 or 16).  Set when the header is
    * written by appendIntegers() or appendIPAddresses().
    */
  private[this] var keyLength: Int = _

  /**
    * Length of a single record.  Equal to the keyLength plus the
    * counter length (which is always 8).
    */
  private[this] var recordLength: Int = _

  /**
    * The size of the buffer to hold data prior to compressing.
    */
  private[this] val bufferSize = 65536

  /**
    * The buffer to hold data prior to compressing.  The
    * appendIntegers() or appendIPAddresses() method fills this buffer.
    */
  private[this] val buffer = Array.ofDim[Byte](bufferSize)

  /**
    * The current offset into the output buffer.
    */
  private[this] var offset = 0

  /**
    * When writing an IPAddress Bag, the value of the previous key.
    * Used to ensure sorted input.  Remains `null` until the first IP
    * key has been written; in particular it is still `null` after a
    * call to appendIPAddresses() with an empty iterator.
    */
  private[this] var prevKeyIP: IPAddress = _

  /**
    * When writing an Int Bag, the value of the previous key.  Used to
    * ensure sorted input.  Only meaningful once `firstKey` is false.
    */
  private[this] var prevKeyInt: Int = _

  /**
    * When writing an Int Bag, true until the first integer key has
    * been written to the stream.  While true, `prevKeyInt` does not
    * hold a real key and must not be used for the ordering check.
    */
  private[this] var firstKey = true

  /**
    * Create a buffered writer that is capable of compressing the data
    * as it is written.  LZO1X and SNAPPY require `hadoopConfig`; a
    * missing configuration makes `hadoopConfig.get` throw
    * NoSuchElementException.
    */
  private[this] val writer: BufferWriter =
    compressionMethod match {
      case CompressionMethod.NONE => RawOutputStreamBuffer(out)
      case CompressionMethod.ZLIB => ZlibOutputStreamBuffer(out)
      case CompressionMethod.LZO1X => LzoOutputStreamBuffer(hadoopConfig.get, out)
      case CompressionMethod.SNAPPY => SnappyOutputStreamBuffer(hadoopConfig.get, out)
      case _ => throw new SilkDataFormatException("Unrecognized compression method")
    }


  /**
    * Whether the currently configured key type is one of the IPv6 key
    * types (SIPv6, DIPv6, NHIPv6, or ANY_IPv6).
    */
  private[this] def keyTypeIsIPv6: Boolean =
    typeKey match {
      case Some(BagDataType.SKBAG_FIELD_SIPv6 | BagDataType.SKBAG_FIELD_DIPv6 |
                BagDataType.SKBAG_FIELD_NHIPv6 | BagDataType.SKBAG_FIELD_ANY_IPv6) =>
        true
      case _ => false
    }

  /**
    * Hands any buffered bytes to the (possibly compressing) writer and
    * resets the buffer offset.  Does nothing when the buffer is empty.
    */
  private[this] def flushBuffer(): Unit = {
    if ( offset > 0 ) {
      writer.putBuffer(buffer, offset)
      offset = 0
    }
  }

  /**
    * A helper function for `setKeyType()` and `setCounterType()` to
    * set either the key or the counter.  If called to set the key
    * type and the current counter type is `None`, sets the counter
    * type to CUSTOM.  Likewise for the key type when the caller sets
    * the counter type.
    *
    * @throws SilkDataFormatException if the header has already been
    * written.
    */
  private[this] def setTypeKeyCounter(
    key:     Option[BagDataType],
    counter: Option[BagDataType]): Unit =
  {
    if ( headerWritten ) {
      // FIXME: Error class
      val kc: String = if ( key.isDefined ) "key" else "counter"
      throw new SilkDataFormatException("Must set " + kc + " type before writing bag")
    }

    // Set the key's type.  If the caller is setting the counter's
    // type and the key's type has not been set, set it to CUSTOM
    if ( key.isDefined ) {
      typeKey = key
    } else if ( typeKey.isEmpty ) {
      typeKey = Some(BagDataType.SKBAG_FIELD_CUSTOM)
    }

    // Set the counter's type.  Similar to setting key's type
    if ( counter.isDefined ) {
      typeCounter = counter
    } else if ( typeCounter.isEmpty ) {
      typeCounter = Some(BagDataType.SKBAG_FIELD_CUSTOM)
    }
  }

  /**
    * Verify that the caller does not use a mixture of calls to
    * appendIntegers() and appendIPAddresses().  If this is the first
    * call to this function, set the content type of the Bag.
    *
    * @param appendIPs `true` if called from appendIPAddresses() and
    * `false` otherwise
    * @throws NoSuchElementException if the Bag's content type was
    * already set to the other kind.
    */
  private[this] def checkSetContents(appendIPs: Boolean): Unit = {
    containsIPs match {
      case None => containsIPs = Some(appendIPs)
      case Some(ips) if ips != appendIPs =>
        throw new NoSuchElementException(
          "May not mix calls to appendIntegers() and appendIPAddresses()")
      case Some(_) => ()
    }
  }

  /**
    * Writes the SiLK file header to the output stream and marks the
    * header as written.
    *
    * A Bag header entry is included only when a key type was set or
    * the keys are IPv6; unset key/counter types are written as CUSTOM.
    *
    * @param isIPv6 whether the keys are 16-octet IPv6 addresses
    * (record version 4) rather than 4-octet values (record version 3)
    */
  private[this] def writeHeader(isIPv6: Boolean): Unit = {
    // convert the key and counter types to numbers
    val keyID: Short =
      typeKey.getOrElse(BagDataType.SKBAG_FIELD_CUSTOM).value
    val counterID: Short =
      typeCounter.getOrElse(BagDataType.SKBAG_FIELD_CUSTOM).value
    // determine the key length; counter length is always 8
    val keyLen: Short = if ( isIPv6 ) 16 else 4
    // determine the record version: 4 if IPv6; 3 otherwise
    val recVers: Short = if ( isIPv6 ) 4 else 3

    val entries: Vector[HeaderEntry] =
      if ( typeKey.isDefined || isIPv6 ) {
        Vector(HeaderEntry.Bag(keyID, keyLen, counterID, 8), HeaderEntry.EndOfHeaders)
      } else {
        Vector(HeaderEntry.EndOfHeaders)
      }

    val header = new Header(Header.BigEndian, FileFormat.FT_RWBAG,
      Header.FileVersion, compressionMethod, SilkVersion(0),
      (keyLen + 8).toShort, recVers, entries)

    header.writeTo(out)
    out.flush()
    headerWritten = true
  }


  /**
    * Whether the Bag file's header has been written--that is, whether
    * `appendIntegers()`, `appendIPAddresses()`, or `close()` has been
    * called.
    *
    * (The method name contains a historical typo; it is kept for
    * compatibility.  See also `wasHeaderWritten`.)
    *
    * @return `true` once the header has been written
    */
  def wasHheaderWritten: Boolean = headerWritten

  /**
    * Whether the Bag file's header has been written--that is, whether
    * `appendIntegers()`, `appendIPAddresses()`, or `close()` has been
    * called.  Correctly spelled alias of `wasHheaderWritten`.
    *
    * @return `true` once the header has been written
    */
  def wasHeaderWritten: Boolean = headerWritten

  /**
    * Sets the type of the key.  The value is written into the output
    * stream's header.
    *
    * Throws `SilkDataFormatException`//FIXME if called after the
    * file's header has been written.
    */
  def setKeyType(keyType: BagDataType): Unit = {
    setTypeKeyCounter(Option(keyType), None)
  }

  /**
    * Sets the type of the counter.  The value is written into the
    * output stream's header.
    *
    * Throws `SilkDataFormatException`//FIXME if called after the
    * file's header has been written.
    */
  def setCounterType(counterType: BagDataType): Unit = {
    setTypeKeyCounter(None, Option(counterType))
  }


  /**
    * Iterates over the (key, counter) pairs where each key is an
    * [[scala.Int Int]] and writes the values to the output stream as
    * a SiLK Bag.
    *
    * Expects the [[scala.Int Ints]] in the
    * [[scala.collection.Iterator Iterator]] to be in strictly
    * ascending order (sorted and unique).
    *
    * Writes the file's header if it has not been written yet.  The
    * type of the key and counter may no longer be changed once this
    * function is called.
    *
    * This function may be called successfully multiple times as long
    * as the keys across the various calls are in sorted order.
    *
    * Calls to this function may not be mixed with calls to
    * appendIPAddresses().
    *
    * @throws NoSuchElementException//FIXME if the keys are not in
    * strictly ascending order, if the key type is an IPv6 type, or if
    * appendIPAddresses() was called previously.
    */
  def appendIntegers(iter: Iterator[(Int, Long)]): Unit = {
    checkSetContents(false)

    if ( !headerWritten ) {
      // Ensure key type is not an IPv6 type
      if ( keyTypeIsIPv6 ) {
        // FIXME: better error class
        throw new NoSuchElementException(
          "May not use an IPv6 key type when the key is an integer")
      }

      // write the file's header
      writeHeader(false)

      keyLength = 4
      recordLength = keyLength + 8
    }

    // each integer record is a 4-octet key followed by an 8-octet counter
    for ((key, counter) <- iter) {
      // Every key after the very first one (across all calls) must be
      // strictly greater than its predecessor.
      if ( !firstKey && key <= prevKeyInt ) {
        // FIXME: better error code
        throw new NoSuchElementException(
          "Integer keys are unsorted or not unique")
      }
      firstKey = false
      prevKeyInt = key
      if ( bufferSize - offset < 12 ) {
        flushBuffer()
      }
      putInt32(buffer, offset, key)
      putInt64(buffer, offset + 4, counter)
      offset += 12
    }

    flushBuffer()
    out.flush()
  }

  /**
    * Iterates over the (key, counter) pairs where each key is an
    * [[org.cert.netsa.data.net.IPAddress IPAddresses]] and writes the
    * values to the output stream as a SiLK Bag.
    *
    * Expects the [[org.cert.netsa.data.net.IPAddress IPAddresses]] in
    * the [[scala.collection.Iterator Iterator]] to be in strictly
    * ascending order (sorted and unique).
    *
    * Expects all [[org.cert.netsa.data.net.IPAddress IPAddresses]] in
    * the [[scala.collection.Iterator Iterator]] to be of the same
    * size; that is, either all are
    * [[org.cert.netsa.data.net.IPv4Address IPv4Address]] or all are
    * [[org.cert.netsa.data.net.IPv6Address IPv6Address]].
    *
    * Writes the file's header if it has not been written yet.  When
    * the first call receives an empty iterator, the key size is taken
    * from the key type (16 octets for an IPv6 key type, 4 otherwise).
    * The type of the key and counter may no longer be changed once
    * this function is called.
    *
    * This function may be called successfully multiple times as long
    * as the IPAddresses across the various calls are the same size
    * and are in sorted order.
    *
    * Calls to this function may not be mixed with calls to
    * appendIntegers().
    *
    * @throws NoSuchElementException//FIXME if the keys are not in
    * strictly ascending order, if the keys are of different sizes, if
    * an IPv6 key type is used with IPv4 keys, or if appendIntegers()
    * was called previously.
    */
  def appendIPAddresses[T <: IPAddress](iter: Iterator[(T, Long)]): Unit = {
    checkSetContents(true)

    if ( !headerWritten ) {
      val typeIsIPv6 = keyTypeIsIPv6

      if ( iter.hasNext ) {
        // get the octet-length of the first IP so we know whether to
        // write an IPv4 or an IPv6 value into the file's header
        val (ip, counter) = iter.next()
        val firstBytes = ip.toBytes
        keyLength = firstBytes.length
        recordLength = keyLength + 8

        val valueIsIPv6 = keyLength match {
          case 4  => false
          case 16 => true
          case _  => throw new NoSuchElementException("Unexpected IP Address length " + keyLength) // FIXME: better error
        }

        if ( typeIsIPv6 && !valueIsIPv6 ) {
          // FIXME: better error
          throw new NoSuchElementException(
            "May not use an IPv6 key type when the key is an IPv4 address")
        }

        // write the file's header
        writeHeader(valueIsIPv6)

        // append the first IP address and counter
        assert(0 == offset)
        assert(bufferSize > recordLength)

        putBytes(buffer, 0, firstBytes, keyLength)
        putInt64(buffer, keyLength, counter)
        offset += recordLength

        // ensure IP addresses are properly ordered
        prevKeyIP = ip
      } else {
        // No data: write the bag's header based on the key type alone
        writeHeader(typeIsIPv6)

        keyLength = if ( typeIsIPv6 ) 16 else 4
        recordLength = keyLength + 8
      }
    }

    // handle the key,counter pairs
    for ((ip, counter) <- iter) {
      // prevKeyIP is null when no IP key has been written yet (e.g. an
      // earlier call received an empty iterator); skip the check then.
      if ( prevKeyIP != null && ip <= prevKeyIP ) {
        // FIXME: better error code
        throw new NoSuchElementException(
          "IPAddress keys are unsorted or not unique")
      }
      prevKeyIP = ip
      val arr = ip.toBytes
      if ( arr.length != keyLength ) {
        // FIXME: better error code
        throw new NoSuchElementException("IPAddresses are of different sizes")
      }
      if ( bufferSize - offset < recordLength ) {
        flushBuffer()
      }
      putBytes(buffer, offset, arr, keyLength)
      putInt64(buffer, offset + keyLength, counter)
      offset += recordLength
    }

    flushBuffer()
    out.flush()
  }

  /**
    * Closes the output stream.
    *
    * Writes the SiLK file header to the output stream if it has not
    * been written, writes any buffered records, finishes the
    * (possibly compressing) writer, and closes the output stream.
    * The output stream is closed even if one of the earlier steps
    * throws.
    */
  def close(): Unit = {
    try {
      if ( !headerWritten ) {
        writeHeader(keyTypeIsIPv6)
      }
      flushBuffer()
      writer.end()
    } finally {
      out.close()
    }
  }

}


/**
  * The BagWriter companion object provides support for creating an
  * [[BagWriter]].
  */
object BagWriter {
  /**
    * Builds a [[BagWriter]] that writes (key, counter) pairs as a
    * binary SiLK Bag stream to the output stream `s`, compressing the
    * data with `compressionMethod` (uncompressed by default).
    *
    * When `s` is already a `DataOutputStream` it is used as-is;
    * any other `OutputStream` is wrapped in a new `DataOutputStream`.
    * Some compression methods (LZO1X, SNAPPY) need a hadoop
    * configuration supplied through `hadoopConfig`.
    *
    * @throws [[java.util.NoSuchElementException]] when a hadoop
    * configuration is required and none is provided.
    */
  def toOutputStream(
    s: OutputStream,
    compressionMethod: CompressionMethod = CompressionMethod.NONE,
    hadoopConfig: Option[Configuration] = None)
      : BagWriter =
  {
    // Reuse an existing DataOutputStream rather than double-wrapping it
    val dataStream: DataOutputStream = s match {
      case alreadyData: DataOutputStream => alreadyData
      case plain: OutputStream           => new DataOutputStream(plain)
    }
    new BagWriter(dataStream, compressionMethod, hadoopConfig)
  }

}

// @LICENSE_FOOTER@
//
// Copyright 2015-2022 Carnegie Mellon University. All Rights Reserved.
//
// This material is based upon work funded and supported by the
// Department of Defense and Department of Homeland Security under
// Contract No. FA8702-15-D-0002 with Carnegie Mellon University for the
// operation of the Software Engineering Institute, a federally funded
// research and development center sponsored by the United States
// Department of Defense. The U.S. Government has license rights in this
// software pursuant to DFARS 252.227.7014.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING
// INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON
// UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR
// IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF
// FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS
// OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT
// MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT,
// TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Released under a GNU GPL 2.0-style license, please see LICENSE.txt or
// contact permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public
// release and unlimited distribution. Please see Copyright notice for
// non-US Government use and distribution.
//
// Carnegie Mellon(R) and CERT(R) are registered in the U.S. Patent and
// Trademark Office by Carnegie Mellon University.
//
// This software includes and/or makes use of third party software each
// subject to its own license as detailed in LICENSE-thirdparty.tx
//
// DM20-1143
