// Copyright 2025 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.silk

import java.io.{DataOutputStream, OutputStream}
import java.nio.ByteBuffer

import org.cert.netsa.data.net.{IPAddress, IPBlock}

import io.{
  BufferWriter, LzoOutputStreamBuffer, RawOutputStreamBuffer, SnappyOutputStreamBuffer,
  ZlibOutputStreamBuffer
}
import io.BufferUtil.*

/** A writer of binary SiLK IPset files.
  *
  * @example
  *   This example reads the contents of "example.set" and writes it to "copy.set":
  *   {{{
  * val in = new java.io.FileInputStream("example.set")
  * val out = new java.io.FileOutputStream("copy.set")
  * val ipset = IPSetReader.ofInputStream(in)
  * val writer = IPSetWriter.toOutputStream(out)
  * writer.append(ipset)
  * writer.close()
  *   }}}
  *
  * @see [[IPSetWriter$ the companion object]] for more details
  */
class IPSetWriter private (val out: DataOutputStream, val compressionMethod: CompressionMethod) {

  /** Set to true once the file's header has been written */
  private var headerWritten = false

  /** Set to true once `close()` has started, so that repeated calls do nothing. */
  private var closed = false

  /** Length of a single IP Address. Determined by the iterator provided to the first call to
    * append().
    */
  private var ipLength: Int = _

  /** Length of a single record. One more than the ipLength to hold the netblock length */
  private var recordLength: Int = _

  /** The size of the buffer to hold data prior to compressing. */
  private val bufferSize = 65536

  /** The buffer to hold data prior to compressing. The append() method fills this buffer. */
  private val buffer = ByteBuffer.allocate(bufferSize)

  /** The current offset into the output buffer. */
  private var offset = 0

  /** The maximum IP address of the previous IPBlock written to the stream. Used to ensure blocks
    * are sorted and do not overlap.
    */
  private var prevBlockMax: IPAddress = _

  /** Object used to write/compress the output */
  private val writer: BufferWriter = compressionMethod match {
    case CompressionMethod.NONE   => RawOutputStreamBuffer(out)
    case CompressionMethod.ZLIB   => ZlibOutputStreamBuffer(out)
    case CompressionMethod.LZO1X  => LzoOutputStreamBuffer(out)
    case CompressionMethod.SNAPPY => SnappyOutputStreamBuffer(out)
    case _ => throw new SilkDataFormatException("Unrecognized compression method")
  }

  /** Writes the SiLK file header to the output stream, flushes it, and records that the header
    * has been written.
    *
    * @param isIPv6 `true` when the file holds 16-byte (IPv6) addresses, `false` for 4-byte (IPv4)
    */
  private def writeHeader(isIPv6: Boolean): Unit = {
    // The IPset header entry carries 16 for IPv6 content and 4 for IPv4 content (matching the
    // address length used for records); every other field is written as zero.
    val h_entries: Vector[HeaderEntry] = Vector(
      HeaderEntry.IPSet(0, 0, if (isIPv6) 16 else 4, 0, 0, 0),
      HeaderEntry.EndOfHeaders
    )

    // then create and write the header object
    val header = new Header(
      Header.BigEndian, FileFormat.FT_IPSET, Header.FileVersion, compressionMethod, SilkVersion(0),
      1, 4, h_entries
    )
    header.writeTo(out)
    out.flush()
    headerWritten = true
  }

  /** Whether the SiLK file header has been written to the stream. This becomes `true` once the
    * `append()` method has been called with a non-empty Iterator, or once `close()` has run.
    *
    * @return `true` once the header has been written
    */
  def wasHeaderWritten: Boolean = headerWritten

  /** Iterates over the [[org.cert.netsa.data.net.IPBlock IPBlocks]] and appends them to the
    * destination stream.
    *
    * Expects the [[org.cert.netsa.data.net.IPBlock IPBlocks]] in the
    * [[scala.collection.Iterator Iterator]] to be in sorted order (numerically ascending) and not
    * to overlap.
    *
    * Expects all [[org.cert.netsa.data.net.IPBlock IPBlocks]] in the
    * [[scala.collection.Iterator Iterator]] to be of the same size; that is, either all are
    * [[org.cert.netsa.data.net.IPv4Block IPv4Block]] or all are
    * [[org.cert.netsa.data.net.IPv6Block IPv6Block]]. The first block ever appended fixes the size
    * for the whole file. Calling this with an empty Iterator before any block has been written does
    * nothing.
    *
    * This function may be called successfully multiple times as long as all IPBlocks have the same
    * size and the IPBlocks across the various calls are in sorted order. Buffered records are
    * handed to the compressor and the stream is flushed before each call returns.
    *
    * @throws java.lang.IllegalArgumentException if the IPBlock Iterator contains a mix of IPv4
    *   and IPv6 addresses, if the IPBlocks are not in sorted order or overlap, or if the first
    *   block's address is neither 4 nor 16 bytes long.
    * @throws java.io.IOException if writing to the output stream fails.
    */
  def append[T <: IPBlock](iter: Iterator[T]): Unit = {
    if (!headerWritten && iter.hasNext) {
      startFile(iter.next())
    }
    // When the header still has not been written, the Iterator was empty: nothing to do.
    if (headerWritten) {
      for (block <- iter) {
        putRecord(block)
      }
      flushBuffer()
      out.flush()
    }
  }

  /** Uses the first IPBlock of the file to fix the record layout, writes the file header, and
    * places that block's record at the start of the buffer.
    */
  private def startFile(first: IPBlock): Unit = {
    val minBytes = first.min.toBytes
    ipLength = minBytes.length
    recordLength = ipLength + 1

    val isIPv6: Boolean = ipLength match {
      case 4  => false
      case 16 => true
      case _  => throw new IllegalArgumentException(s"Unexpected IP Address length $ipLength")
    }

    writeHeader(isIPv6)

    assert(0 == offset)
    assert(bufferSize > recordLength)

    buffer.putBytes(0, minBytes, ipLength)
    buffer.put(ipLength, first.prefixLength.toByte)
    offset = recordLength

    prevBlockMax = first.max
  }

  /** Validates `block` against the file's address size and the previous block, then copies its
    * record (address bytes followed by prefix length) into the buffer. If the buffer lacks room,
    * it is drained to the compressor first.
    */
  private def putRecord(block: IPBlock): Unit = {
    val arr = block.min.toBytes
    // Check the address size before comparing addresses, so that a v4/v6 mix is reported as such
    // rather than surfacing as a cross-family comparison.
    require(arr.length == ipLength, "Cannot mix IPv4Blocks and IPv6Blocks in IPSetWriter")
    require(prevBlockMax < block.min, "IPBlocks are unsorted or overlap")
    prevBlockMax = block.max

    if (bufferSize - offset < recordLength) {
      flushBuffer()
    }
    buffer.putBytes(offset, arr, ipLength)
    buffer.put(offset + ipLength, block.prefixLength.toByte)
    offset += recordLength
  }

  /** Hands any buffered records to the compressing writer and resets the buffer offset. */
  private def flushBuffer(): Unit = {
    if (offset > 0) {
      writer.putBuffer(buffer.array, offset)
      offset = 0
    }
  }

  /** Closes the output stream.
    *
    * Writes the SiLK file header to the output stream if it has not been written, writes any
    * buffered records, finishes the compressed output, and closes the output stream. The output
    * stream is closed even if finishing the file fails. Calling this method more than once has no
    * further effect.
    */
  def close(): Unit = {
    if (!closed) {
      closed = true
      try {
        if (!headerWritten) {
          writeHeader(false)
        }
        flushBuffer()
        writer.end()
      } finally {
        out.close()
      }
    }
  }

}

/** The IPSetWriter companion object provides support for creating an [[IPSetWriter]]. */
object IPSetWriter {

  /** Creates and returns a writer that iterates over [[org.cert.netsa.data.net.IPBlock IPBlocks]]
    * and writes them as a binary SiLK IPset stream (compatible with SiLK 3.7.0 and later) to the
    * output stream `s`. Compresses the output using `compressionMethod`.
    *
    * If `s` is already a `DataOutputStream` it is used directly; otherwise it is wrapped in a new
    * `DataOutputStream`.
    *
    * @param s the destination stream; closed by the returned writer's `close()`
    * @param compressionMethod how the record data is compressed; defaults to
    *   `CompressionMethod.NONE`
    * @return a new writer targeting `s`
    * @throws java.lang.IllegalArgumentException if `s` is null.
    * @throws SilkDataFormatException if `compressionMethod` is not one of NONE, ZLIB, LZO1X, or
    *   SNAPPY.
    */
  def toOutputStream(
    s: OutputStream,
    compressionMethod: CompressionMethod = CompressionMethod.NONE
  ): IPSetWriter = {
    require(s != null, "output stream must not be null")
    val out = s match {
      case dos: DataOutputStream => dos
      case other                 => new DataOutputStream(other)
    }

    new IPSetWriter(out, compressionMethod)
  }

}

// @LICENSE_FOOTER@
//
// Mothra 1.7
//
// Copyright 2025 Carnegie Mellon University.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING INSTITUTE MATERIAL IS
// FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY KIND,
// EITHER EXPRESSED OR IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF FITNESS
// FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS OBTAINED FROM USE OF THE MATERIAL.
// CARNEGIE MELLON UNIVERSITY DOES NOT MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM
// PATENT, TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Licensed under a GNU GPL 2.0-style license, please see LICENSE.txt or contact
// permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public release and unlimited
// distribution.  Please see Copyright notice for non-US Government use and distribution.
//
// This Software includes and/or makes use of Third-Party Software each subject to its own license.
//
// DM24-1649
