// Copyright 2025 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.silk

import java.io.{DataOutputStream, OutputStream}
import java.nio.ByteBuffer

import org.cert.netsa.data.net.IPAddress

import io.{
  BufferWriter, LzoOutputStreamBuffer, RawOutputStreamBuffer, SnappyOutputStreamBuffer,
  ZlibOutputStreamBuffer
}
import io.BufferUtil.*

/** A writer of binary SiLK Bag files.
  *
  * To include a header in the Bag file that specifies the type of the key and counter, call
  * `setKeyType()` and/or `setCounterType()` prior to writing the Bag.
  *
  * @example
  *   This example reads the contents of "example.bag" and writes it to "copy.bag", where the keys
  *   are IP addresses:
  *   {{{
  * val in = new java.io.FileInputStream("example.bag")
  * val out = new java.io.FileOutputStream("copy.bag")
  * val bagresult = BagReader.ofInputStream(in)
  * val bag = bagresult match {
  *   case BagResult.IPAddressBag(iter) => iter
  *   case _ => null
  * }
  * val writer = BagWriter.toOutputStream(out)
  * if ( None != bag.keyType ) {
  *   writer.setKeyType(bag.keyType)
  * }
  * if ( None != bag.counterType ) {
  *   writer.setCounterType(bag.counterType)
  * }
  * writer.appendIPAddresses(bag)
  * writer.close()
  *   }}}
  *
  * @see [[BagWriter$ the companion object]] for more details
  */
class BagWriter private (val out: DataOutputStream, val compressionMethod: CompressionMethod) {

  /** Set to true once the file's header has been written */
  private var headerWritten = false

  /** The type of the key; set by the `setKeyType()` method. */
  private var typeKey: Option[BagDataType] = None

  /** The type of the counter; set by the `setCounterType()` method. */
  private var typeCounter: Option[BagDataType] = None

  /** Whether the Bag contains IPAddresses (`Some(true)`) or Integers (`Some(false)`). Set by the
    * first call to either appendIPAddresses() or appendIntegers().
    */
  private var containsIPs: Option[Boolean] = None

  /** Length in octets of the key: 4 for integer and IPv4 Bags, 16 for IPv6 Bags. Set when
    * appendIntegers() or appendIPAddresses() writes the header.
    */
  private var keyLength: Int = _

  /** Length of a single record. Equal to the keyLength plus the counter length (which is always 8).
    */
  private var recordLength: Int = _

  /** Length of one record in an integer-keyed Bag: a 4-octet key plus an 8-octet counter. */
  private val intRecordLength = 4 + 8

  /** The size of the buffer to hold data prior to compressing. */
  private val bufferSize = 65536

  /** The buffer to hold data prior to compressing. The appendIntegers or appendIPAddresses() method
    * fills this buffer.
    */
  private val buffer = ByteBuffer.allocate(bufferSize)

  /** The number of bytes currently staged in `buffer`. */
  private var offset = 0

  /** When writing an IPAddress Bag, the most recently written key, or `None` before any key has
    * been written. Used to ensure sorted input across all calls to appendIPAddresses().
    */
  private var prevKeyIP: Option[IPAddress] = None

  /** When writing an Int Bag, the most recently written key. Only meaningful once `firstKey` is
    * false.
    */
  private var prevKeyInt: Int = _

  /** When writing an Int Bag, true until the first integer key has been written to the stream.
    * Used with prevKeyInt.
    */
  private var firstKey = true

  /** Create a buffered writer that is capable of compressing the data as it is written. */
  private val writer: BufferWriter = compressionMethod match {
    case CompressionMethod.NONE   => RawOutputStreamBuffer(out)
    case CompressionMethod.ZLIB   => ZlibOutputStreamBuffer(out)
    case CompressionMethod.LZO1X  => LzoOutputStreamBuffer(out)
    case CompressionMethod.SNAPPY => SnappyOutputStreamBuffer(out)
    case _ => throw new SilkDataFormatException("Unrecognized compression method")
  }

  /** Whether the key type set via `setKeyType()` is one of the IPv6 address types. */
  private def keyTypeIsIPv6: Boolean = typeKey.exists {
    case BagDataType.SKBAG_FIELD_SIPv6 | BagDataType.SKBAG_FIELD_DIPv6 |
        BagDataType.SKBAG_FIELD_NHIPv6 | BagDataType.SKBAG_FIELD_ANY_IPv6 => true
    case _ => false
  }

  /** Hands any staged bytes in `buffer` to `writer` and resets `offset`. */
  private def flushBuffer(): Unit = {
    if (offset > 0) {
      writer.putBuffer(buffer.array, offset)
      offset = 0
    }
  }

  /** Flushes `buffer` when fewer than `needed` bytes of space remain in it. */
  private def reserve(needed: Int): Unit = {
    if (bufferSize - offset < needed) {
      flushBuffer()
    }
  }

  /** A helper function for `setKeyType()` and `setCounterType()` to set either the key or the
    * counter. If called to set the key type and the current counter type is `None`, sets the
    * counter type to CUSTOM. Likewise for the key type when the caller sets the counter type.
    *
    * @throws java.lang.IllegalArgumentException if the header has already been written
    */
  private def setTypeKeyCounter(key: Option[BagDataType], counter: Option[BagDataType]): Unit = {
    require(
      !headerWritten,
      s"Must set ${if (key.isDefined) "key" else "counter"} type before writing bag"
    )

    // A newly supplied type wins. Otherwise keep the current one, falling back to CUSTOM so both
    // types are defined once either one is.
    val custom: Option[BagDataType] = Some(BagDataType.SKBAG_FIELD_CUSTOM)
    typeKey = key.orElse(typeKey).orElse(custom)
    typeCounter = counter.orElse(typeCounter).orElse(custom)
  }

  /** Verify that the caller does not use a mixture of calls to appendIntegers() and
    * appendIPAddresses(). If this is the first call to this function, set the content type of the
    * Bag.
    *
    * `appendIPs` should be `true` if called from appendIPAddresses() and `false` otherwise
    *
    * @throws java.util.NoSuchElementException if the calls are mixed (this exception type is kept
    *   for compatibility with existing callers)
    */
  private def checkSetContents(appendIPs: Boolean): Unit = {
    containsIPs match {
      case None => containsIPs = Some(appendIPs)
      case Some(ips) if ips != appendIPs =>
        throw new NoSuchElementException(
          "May not mix calls to appendIntegers() and appendIPAddresses()"
        )
      case Some(_) => ()
    }
  }

  /** Writes the SiLK file header to the output stream and marks the header as written.
    *
    * @param isIPv6 whether the keys are 16-octet IPv6 addresses (otherwise 4-octet keys)
    */
  private def writeHeader(isIPv6: Boolean): Unit = {
    // Unset key/counter types are recorded as CUSTOM.
    val keyID: Short = typeKey.getOrElse(BagDataType.SKBAG_FIELD_CUSTOM).value
    val counterID: Short = typeCounter.getOrElse(BagDataType.SKBAG_FIELD_CUSTOM).value
    // key length in octets; the counter length is always 8
    val keyLen: Short = if (isIPv6) 16 else 4
    // record version: 4 if IPv6; 3 otherwise
    val recVers: Short = if (isIPv6) 4 else 3

    // The Bag header entry is only emitted when a key type was set or the keys are IPv6.
    val entries: Vector[HeaderEntry] =
      if (typeKey.isDefined || isIPv6) {
        Vector(HeaderEntry.Bag(keyID, keyLen, counterID, 8), HeaderEntry.EndOfHeaders)
      } else {
        Vector(HeaderEntry.EndOfHeaders)
      }

    val header = new Header(
      Header.BigEndian,
      FileFormat.FT_RWBAG,
      Header.FileVersion,
      compressionMethod,
      SilkVersion(0),
      (keyLen + 8).toShort,
      recVers,
      entries
    )

    header.writeTo(out)
    out.flush()
    headerWritten = true
  }

  /** Whether the file's header has been written.
    *
    * @return `true` once appendIntegers(), appendIPAddresses(), or close() has written the header
    */
  def wasHheaderWritten: Boolean = headerWritten

  /** Whether the file's header has been written. Correctly spelled alias of `wasHheaderWritten`.
    *
    * @return `true` once appendIntegers(), appendIPAddresses(), or close() has written the header
    */
  def wasHeaderWritten: Boolean = headerWritten

  /** Sets the type of the key. The value is written into the output stream's header.
    *
    * @throws java.lang.IllegalArgumentException if called after the file's header has been written
    */
  def setKeyType(keyType: BagDataType): Unit = {
    setTypeKeyCounter(Option(keyType), None)
  }

  /** Sets the type of the counter. The value is written into the output stream's header.
    *
    * @throws java.lang.IllegalArgumentException if called after the file's header has been written.
    */
  def setCounterType(counterType: BagDataType): Unit = {
    setTypeKeyCounter(None, Option(counterType))
  }

  /** Iterates over the (key, counter) pairs where each key is an [[scala.Int Int]] and writes the
    * values to the output stream as a SiLK Bag.
    *
    * Expects the [[scala.Int Ints]] in the [[scala.collection.Iterator Iterator]] to be in sorted
    * order (numerically ascending).
    *
    * Writes the file's header if it has not been written yet. The type of the key and counter may
    * no longer be changed once this function is called.
    *
    * This function may be called successfully multiple times as long as the keys across the various
    * calls are in sorted order.
    *
    * Calls to this function may not be mixed with calls to appendIPAddresses().
    *
    * @throws java.lang.IllegalArgumentException if the key type is an IPv6 type, or if a key is not
    *   strictly greater than the key written before it
    */
  def appendIntegers(iter: Iterator[(Int, Long)]): Unit = {
    checkSetContents(false)

    if (!headerWritten) {
      require(!keyTypeIsIPv6, "May not use an IPv6 key type when the key is an integer")

      // write the file's header
      writeHeader(false)

      keyLength = 4
      recordLength = keyLength + 8
    }

    for ((key, counter) <- iter) {
      // Every key after the very first one (across all calls) must be strictly greater than its
      // predecessor.
      // NOTE(review): keys are compared as signed Ints; confirm how callers encode unsigned
      // 32-bit keys >= 2^31.
      if (!firstKey) {
        require(key > prevKeyInt, "Integer keys are unsorted or not unique")
      }
      firstKey = false
      prevKeyInt = key

      reserve(intRecordLength)
      buffer.putInt(offset, key)
      buffer.putLong(offset + 4, counter)
      offset += intRecordLength
    }

    flushBuffer()
    out.flush()
  }

  /** Checks ordering and size of one IP key, then stages the (key, counter) record in `buffer`.
    */
  private def appendIPRecord(ip: IPAddress, counter: Long): Unit = {
    // No ordering check before the first key has been written.
    prevKeyIP.foreach(prev => require(ip > prev, "IPAddress keys are unsorted or not unique"))
    prevKeyIP = Some(ip)
    val bytes = ip.toBytes
    require(bytes.length == keyLength, "IPAddresses are of different sizes")
    reserve(recordLength)
    buffer.putBytes(offset, bytes, keyLength)
    buffer.putLong(offset + keyLength, counter)
    offset += recordLength
  }

  /** Iterates over the (key, counter) pairs where each key is an
    * [[org.cert.netsa.data.net.IPAddress IPAddresses]] and writes the values to the output stream
    * as a SiLK Bag.
    *
    * Expects the [[org.cert.netsa.data.net.IPAddress IPAddresses]] in the
    * [[scala.collection.Iterator Iterator]] to be in sorted order (numerically ascending).
    *
    * Expects all [[org.cert.netsa.data.net.IPAddress IPAddresses]] in the
    * [[scala.collection.Iterator Iterator]] to be of the same size; that is, either all are
    * [[org.cert.netsa.data.net.IPv4Address IPv4Address]] or all are
    * [[org.cert.netsa.data.net.IPv6Address IPv6Address]].
    *
    * Writes the file's header if it has not been written yet. The size of the first address
    * decides whether an IPv4 or IPv6 Bag is written. If the iterator is empty, the key type set via
    * `setKeyType()` decides instead. The type of the key and counter may no longer be changed once
    * this function is called.
    *
    * This function may be called successfully multiple times as long as the IPAddresses across the
    * various calls are the same size and are in sorted order.
    *
    * Calls to this function may not be mixed with calls to appendIntegers().
    *
    * @throws java.lang.IllegalArgumentException if keys are unsorted, of mixed sizes, or IPv4 keys
    *   are given with an IPv6 key type
    */
  def appendIPAddresses[T <: IPAddress](iter: Iterator[(T, Long)]): Unit = {
    checkSetContents(true)

    if (!headerWritten) {
      val typeIsIPv6 = keyTypeIsIPv6

      if (iter.hasNext) {
        // get the octet-length of the first IP so we know whether to
        // write an IPv4 or an IPv6 value into the file's header
        val (ip, counter) = iter.next()
        keyLength = ip.toBytes.length
        recordLength = keyLength + 8

        val valueIsIPv6 = keyLength match {
          case 4  => false
          case 16 => true
          case _  => throw new IllegalArgumentException(s"Unexpected IP Address length $keyLength")
        }

        require(
          valueIsIPv6 || !typeIsIPv6,
          "May not use an IPv6 key type when the key is an IPv4 address"
        )

        writeHeader(valueIsIPv6)
        appendIPRecord(ip, counter)
      } else {
        // No data to inspect: the header and record size follow the declared key type.
        writeHeader(typeIsIPv6)
        keyLength = if (typeIsIPv6) 16 else 4
        recordLength = keyLength + 8
      }
    }

    // handle the remaining key,counter pairs
    for ((ip, counter) <- iter) {
      appendIPRecord(ip, counter)
    }

    flushBuffer()
    out.flush()
  }

  /** Closes the output stream.
    *
    * Writes the SiLK file header to the output stream if it has not been written, writes any
    * buffered records, ends the compressing writer, and closes the output stream.
    */
  def close(): Unit = {
    if (!headerWritten) {
      writeHeader(keyTypeIsIPv6)
    }
    flushBuffer()
    writer.end()
    out.close()
  }

}

/** The BagWriter companion object provides support for creating a [[BagWriter]]. */
object BagWriter {

  /** Builds a [[BagWriter]] that emits a binary SiLK Bag stream onto `s`, compressing the records
    * with `compressionMethod` (no compression by default).
    *
    * A `DataOutputStream` is used as-is; any other `OutputStream` is wrapped in one.
    */
  def toOutputStream(
    s: OutputStream,
    compressionMethod: CompressionMethod = CompressionMethod.NONE
  ): BagWriter =
    new BagWriter(asDataOutput(s), compressionMethod)

  /** Returns `stream` itself when it is already a `DataOutputStream`, else a wrapper around it. */
  private def asDataOutput(stream: OutputStream): DataOutputStream =
    stream match {
      case existing: DataOutputStream => existing
      case plain: OutputStream        => new DataOutputStream(plain)
    }

}

// @LICENSE_FOOTER@
//
// Mothra 1.7
//
// Copyright 2025 Carnegie Mellon University.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING INSTITUTE MATERIAL IS
// FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY KIND,
// EITHER EXPRESSED OR IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF FITNESS
// FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS OBTAINED FROM USE OF THE MATERIAL.
// CARNEGIE MELLON UNIVERSITY DOES NOT MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM
// PATENT, TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Licensed under a GNU GPL 2.0-style license, please see LICENSE.txt or contac
// permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public release and unlimited
// distribution.  Please see Copyright notice for non-US Government use and distribution.
//
// This Software includes and/or makes use of Third-Party Software each subject to its own license.
//
// DM24-1649
