// Copyright 2025 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.silk

import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream}
import java.time.Instant

/*
GENERIC HEADER + HEADER START
0       1       2       3       4       5       6       7
|magic1 |magic2 |magic3 |magic4 |bigEndi|type   |version|compMet|
(check magic number now, and if version < 16)
(header ends with generic header if version < 16, otherwise header start)
|silk_version                   |rec_size       |rec_version    |
(above are in network byte order, regardless of endianness in byte 4)

(header entries follow, beginning with:)
0       1       2       3       4       5       6       7
|hes_id                         |hes_len                        |
(in network byte order)
(hes_len includes the length of the first 8 bytes read)

hes_id == 0 => end of headers
 */

/** A SiLK file header, including contained header entries. Supports only "new-style" header format
  * (SiLK versions 1.0+).
  *
  * Constructing a header validates its fixed fields and throws a `SilkDataFormatException` when
  * `fileVersion` differs from `Header.FileVersion`, when `fileFlags` or `recordVersion` is
  * negative, or when `recordSize` is not strictly positive.
  *
  * @param fileFlags The bits encoding file flags. Currently only whether the file is big-endian.
  * @param fileFormat The SiLK file format contained within this file.
  * @param fileVersion The SiLK file version--specifically the version of the header format.
  * @param compressionMethod The compression method used by data in this file.
  * @param silkVersion The version of SiLK used to create this file.
  * @param recordSize The size of individual (uncompressed) records in this file.
  * @param recordVersion The record version of the file format.
  * @param headerEntries Sequence of additional extensible header records of various types.
  * @see [[Header#isBigEndian]]
  */
case class Header(
  fileFlags: Byte,
  fileFormat: FileFormat,
  fileVersion: Byte,
  compressionMethod: CompressionMethod,
  // old-style "generic header" ends here
  // not going to support for now. see skheader-legacy.c:skHeaderLegacyDispatch
  silkVersion: SilkVersion,
  recordSize: Short,
  recordVersion: Short,
  headerEntries: IndexedSeq[HeaderEntry]
) {
  // Constructor-time validation of the fixed header fields.
  if (fileVersion != Header.FileVersion) {
    throw new SilkDataFormatException(s"Unsupported fileVersion $fileVersion")
  }
  if (fileFlags < 0) {
    throw new SilkDataFormatException(s"Unsupported fileFlags $fileFlags")
  }
  // recordSize is used as a modulus when padding the end-of-headers entry,
  // so it must be strictly positive.
  if (recordSize <= 0) {
    throw new SilkDataFormatException(s"Unsupported recordSize $recordSize")
  }
  if (recordVersion < 0) {
    throw new SilkDataFormatException(s"Unsupported recordVersion $recordVersion")
  }

  /** Optional base start time, in milliseconds since the UNIX epoch. Times in this packed file are
    * expressed as a delta from this base start time. Taken from the first packed-file header
    * entry, if there is one.
    */
  def startTimeOffset: Option[Long] =
    headerEntries.collectFirst {
      case entry: HeaderEntry.PackedFile => entry.startTime
    }

  /** Optional SiLK flow type for all flows in this packed file. Taken from the first packed-file
    * header entry, if there is one.
    */
  def flowTypeId: Option[FlowType] =
    headerEntries.collectFirst {
      case entry: HeaderEntry.PackedFile => entry.flowtypeId
    }

  /** Optional SiLK sensor ID for all flows in this packed file. Taken from the first packed-file
    * header entry, if there is one.
    */
  def sensorId: Option[Sensor] =
    headerEntries.collectFirst {
      case entry: HeaderEntry.PackedFile => entry.sensorId
    }

  /** Optional probe name recorded with SiLK file (the first probe-name entry, if any). */
  def probeName: Option[String] =
    headerEntries.collectFirst {
      case entry: HeaderEntry.ProbeName => entry.probeName
    }

  /** Command-lines used to produce this SiLK file, if any, in header-entry order. */
  def invocations: Seq[String] =
    headerEntries.collect {
      case entry: HeaderEntry.Invocation => entry.commandLine
    }

  /** Annotations made on this SiLK file, if any, in header-entry order. */
  def annotations: Seq[String] =
    headerEntries.collect {
      case entry: HeaderEntry.Annotation => entry.annotation
    }

  // Cache versions of these for fast access by the decoding code
  private[silk] val packedStartTime = Instant.ofEpochMilli(startTimeOffset.getOrElse(0L))
  private[silk] val packedFlowTypeId = flowTypeId.getOrElse(FlowType(0))
  private[silk] val packedSensorId = sensorId.getOrElse(Sensor(0))

  /** True if data within the records of this file are stored in big-endian (MSB first) format. */
  def isBigEndian: Boolean = (fileFlags & 0x01) == Header.BigEndian

  /** Writes the header to the provided output stream.
    *
    * A stream that already is a `DataOutputStream` is written to directly; any other stream is
    * wrapped in a new `DataOutputStream`. Header entries are written in the order they appear in
    * `headerEntries`; an end-of-headers marker is written only where `headerEntries` contains
    * `HeaderEntry.EndOfHeaders`.
    *
    * @param outputStream the stream that receives the encoded header
    */
  def writeTo(outputStream: OutputStream): Unit = {
    val out = outputStream match {
      case out: DataOutputStream => out
      case outputStream          => new DataOutputStream(outputStream)
    }
    writeHeader(out)
  }

  /** Writes the 16-byte fixed header start followed by all header entries. */
  private def writeHeader(out: DataOutputStream): Unit = {
    out.writeInt(Header.magicNumber)
    out.writeByte(fileFlags.toInt)
    out.writeByte(fileFormat.toInt)
    out.writeByte(fileVersion.toInt)
    out.writeByte(compressionMethod.toInt)
    out.writeInt(silkVersion.toInt)
    out.writeShort(recordSize.toInt)
    out.writeShort(recordVersion.toInt)
    writeHeaderEntries(out)
  }

  /** Write the header entries to the specified output stream.
    *
    * Every entry is encoded as a 4-byte id, a 4-byte length (which counts those 8 prefix bytes) and
    * the entry payload. The running offset `len` is checked against the stream's byte counter
    * after each entry.
    */
  private def writeHeaderEntries(out: DataOutputStream): Unit = {
    // offset into the file; the fixed header start must already have been written
    var len: Int = out.size
    assert(16 == len)

    for (h_entry <- headerEntries) {
      h_entry match {
        case HeaderEntry.EndOfHeaders => {
          // pad this header entry to make the length of the entire
          // header an integer multiple of the record size
          val extra = (len + 8) % recordSize
          val pad = if (0 == extra) 0 else recordSize - extra
          val sz = 8 + pad
          len = len + sz
          out.writeInt(0)
          out.writeInt(sz)
          if (pad > 0) {
            // a freshly allocated array is zero-filled
            out.write(new Array[Byte](pad), 0, pad)
          }
        }
        case h: HeaderEntry.PackedFile => {
          val sz = 8 + 8 + 4 + 4
          len = len + sz
          out.writeInt(1)
          out.writeInt(sz)
          out.writeLong(h.startTime)
          out.writeInt(h.flowtypeId.toInt)
          out.writeInt(h.sensorId.toInt)
        }
        case h: HeaderEntry.Invocation => {
          // writeBytes emits one byte per char, then a NUL terminator
          val sz = 8 + h.commandLine.length + 1
          len = len + sz
          out.writeInt(2)
          out.writeInt(sz)
          out.writeBytes(h.commandLine)
          out.writeByte(0)
        }
        case h: HeaderEntry.Annotation => {
          val sz = 8 + h.annotation.length + 1
          len = len + sz
          out.writeInt(3)
          out.writeInt(sz)
          out.writeBytes(h.annotation)
          out.writeByte(0)
        }
        case h: HeaderEntry.ProbeName => {
          val sz = 8 + h.probeName.length + 1
          len = len + sz
          out.writeInt(4)
          out.writeInt(sz)
          out.writeBytes(h.probeName)
          out.writeByte(0)
        }
        case h: HeaderEntry.PrefixMap => {
          // sz already counts the 8-byte id/length prefix, so the offset
          // advances by exactly sz (it must not add the prefix a second time)
          val sz = 8 + 4 + h.mapName.length + 1
          len = len + sz
          out.writeInt(5)
          out.writeInt(sz)
          out.writeInt(h.version)
          out.writeBytes(h.mapName)
          out.writeByte(0)
        }
        case h: HeaderEntry.Bag => {
          // four 2-byte fields
          val sz = 8 + 4 * 2
          len = len + sz
          out.writeInt(6)
          out.writeInt(sz)
          out.writeShort(h.keyType.toInt)
          out.writeShort(h.keyLength.toInt)
          out.writeShort(h.counterType.toInt)
          out.writeShort(h.counterLength.toInt)
        }
        case h: HeaderEntry.IPSet => {
          // six 4-byte fields
          val sz = 8 + 6 * 4
          len = len + sz
          out.writeInt(7)
          out.writeInt(sz)
          out.writeInt(h.childNode)
          out.writeInt(h.leafCount)
          out.writeInt(h.leafSize)
          out.writeInt(h.nodeCount)
          out.writeInt(h.nodeSize)
          out.writeInt(h.rootIndex)
        }
        case h: HeaderEntry.Unknown => {
          // unrecognized entries are written back verbatim under their own id
          val sz = 8 + h.data.length
          len = len + sz
          out.writeInt(h.id)
          out.writeInt(sz)
          out.write(h.data, 0, h.data.length)
        }
      }

      // the bookkeeping above must agree with what was actually written
      assert(out.size == len)
    }
  }
}

object Header {
  import java.nio.charset.StandardCharsets

  /** The value of fileVersion for all SiLK files supported by Scala. */
  val FileVersion: Byte = 16

  /** The value for fileFlags that indicates the values in the file are represented in network byte
    * order (big endian).
    */
  val BigEndian: Byte = 1

  /** The value for fileFlags that indicates the values in the file are represented in VAX/Intel
    * byte order (little endian).
    */
  val LittleEndian: Byte = 0

  /** The first four bytes of all binary files created by SiLK */
  private[Header] val magicNumber: Int = 0xdeadbeef

  /** Number of bytes taken by the id and length fields that start every header entry. The length
    * field of an entry counts these bytes too.
    */
  private val EntryPrefixLength: Int = 8

  /** Reads and returns a SiLK file header from the provided input stream.
    *
    * A stream that already is a `DataInputStream` is read directly; any other stream is wrapped in
    * a new `DataInputStream`. The terminating end-of-headers entry is consumed from the stream but
    * is not part of the returned header's `headerEntries`.
    *
    * @param inputStream the stream positioned at the first byte of the header
    * @throws SilkDataFormatException if the header is malformed
    */
  def readFrom(inputStream: InputStream): Header = {
    val in = inputStream match {
      case in: DataInputStream => in
      case inputStream         => new DataInputStream(inputStream)
    }
    readHeader(in)
  }

  /** Reads and returns the header, including header entries, from the provided input stream.
    * @throws SilkDataFormatException if the header is malformed
    */
  private def readHeader(in: DataInputStream): Header = {
    val magic = in.readInt()
    if (magic != magicNumber) {
      throw new SilkDataFormatException("Magic number does not match")
    }
    val fileFlags = in.readByte()
    val fileFormat = FileFormat(in.readByte())
    val fileVersion = in.readByte()
    val compressionMethod = CompressionMethod(in.readByte())
    // Checked here, before reading further, so the message can say whether the
    // file is too old or too new; the Header constructor rejects it as well.
    if (fileVersion != FileVersion) {
      val age = if (fileVersion < FileVersion) "old" else "new"
      throw new SilkDataFormatException(s"File version too $age: $fileVersion")
    }
    val silkVersion = SilkVersion(in.readInt())
    val recordSize = in.readShort()
    val recordVersion = in.readShort()
    val headerEntries = readHeaderEntries(in)
    new Header(
      fileFlags, fileFormat, fileVersion, compressionMethod, silkVersion, recordSize, recordVersion,
      headerEntries
    )
  }

  /** Reads header entries from the provided input stream up to and including the end-of-headers
    * entry, and returns every entry that preceded it.
    */
  private def readHeaderEntries(in: DataInputStream): IndexedSeq[HeaderEntry] =
    Iterator
      .continually(readHeaderEntry(in))
      .takeWhile(_ != HeaderEntry.EndOfHeaders)
      .toIndexedSeq

  /** Skips `count` bytes of input, retrying when `skipBytes` skips fewer bytes than requested.
    * Stops quietly if the end of the stream is reached first. Does nothing when `count <= 0`.
    */
  private def skipFully(in: DataInputStream, count: Int): Unit = {
    var remaining = count
    while (remaining > 0) {
      val skipped = in.skipBytes(remaining)
      if (skipped > 0) {
        remaining -= skipped
      } else if (in.read() >= 0) {
        // skipBytes made no progress; consuming one byte tells "short skip" from end of stream
        remaining -= 1
      } else {
        remaining = 0
      }
    }
  }

  /** Reads the rest of a header entry's payload.
    *
    * @param in the stream positioned just after the bytes already consumed from this entry
    * @param entryId the entry's id, used only for the error message
    * @param entryLength the entry's declared length, including the 8-byte id/length prefix
    * @param alreadyRead number of payload bytes the caller has already consumed
    * @throws SilkDataFormatException if `entryLength` is too small to cover the bytes already read
    */
  private def readPayload(
    in: DataInputStream,
    entryId: Int,
    entryLength: Int,
    alreadyRead: Int = 0
  ): Array[Byte] = {
    val consumed = EntryPrefixLength + alreadyRead
    // Compare before subtracting so a negative length is rejected rather than
    // turned into a negative (or wrapped-around) array size.
    if (entryLength < consumed) {
      throw new SilkDataFormatException(
        s"Invalid length $entryLength for header entry with id $entryId"
      )
    }
    val buf = new Array[Byte](entryLength - consumed)
    in.readFully(buf)
    buf
  }

  /** Reads an entry's whole payload and decodes it one byte per character (ISO-8859-1).
    *
    * Every payload byte is kept, including any trailing NUL terminator.
    * NOTE(review): `Header#writeTo` appends a NUL after each string, so a read-then-write round
    * trip grows these strings by one byte -- confirm whether the terminator should be stripped.
    */
  private def readLatin1(in: DataInputStream, entryId: Int, entryLength: Int): String =
    new String(readPayload(in, entryId, entryLength), StandardCharsets.ISO_8859_1)

  /** Reads and returns a single individual header entry from the provided input stream.
    * @throws SilkDataFormatException if the entry's declared length is too small
    */
  private def readHeaderEntry(in: DataInputStream): HeaderEntry = {
    val headerEntryId = in.readInt()
    val headerEntryLength = in.readInt()
    headerEntryId match {
      case 0 =>
        // End of header marker; anything after the prefix is padding
        if (headerEntryLength > EntryPrefixLength) {
          skipFully(in, headerEntryLength - EntryPrefixLength)
        }
        HeaderEntry.EndOfHeaders
      case 1 =>
        // Packed file (Data repository file) info
        val startTime = in.readLong()
        val flowtypeId = FlowType(in.readInt().toByte)
        val sensorId = Sensor(in.readInt().toShort)
        HeaderEntry.PackedFile(startTime, flowtypeId, sensorId)
      case 2 =>
        // A single command line invocation
        HeaderEntry.Invocation(readLatin1(in, headerEntryId, headerEntryLength))
      case 3 =>
        // A single note (or annotation)
        HeaderEntry.Annotation(readLatin1(in, headerEntryId, headerEntryLength))
      case 4 =>
        // The probe name from files created by flowcap
        HeaderEntry.ProbeName(readLatin1(in, headerEntryId, headerEntryLength))
      case 5 =>
        // A prefix map entry containing a version and the map's name
        val version = in.readInt()
        val buf = readPayload(in, headerEntryId, headerEntryLength, alreadyRead = 4)
        // The entry has extra characters after the map name; when
        // building the string, stop at first \0. If there is no \0 at
        // all, the whole buffer is the name.
        val mapName = new String(buf.takeWhile(_ != 0), StandardCharsets.ISO_8859_1)
        HeaderEntry.PrefixMap(version, mapName)
      case 6 =>
        // A Bag header entry
        val keyType = in.readShort()
        val keyLength = in.readShort()
        val counterType = in.readShort()
        val counterLength = in.readShort()
        HeaderEntry.Bag(keyType, keyLength, counterType, counterLength)
      case 7 =>
        // An IPSet header entry
        val childNode = in.readInt()
        val leafCount = in.readInt()
        val leafSize = in.readInt()
        val nodeCount = in.readInt()
        val nodeSize = in.readInt()
        val rootIndex = in.readInt()
        HeaderEntry.IPSet(childNode, leafCount, leafSize, nodeCount, nodeSize, rootIndex)
      case _ =>
        // Unrecognized entry type: keep the raw payload so it can be written back
        HeaderEntry.Unknown(headerEntryId, readPayload(in, headerEntryId, headerEntryLength))
    }
  }
}

// @LICENSE_FOOTER@
//
// Mothra 1.7
//
// Copyright 2025 Carnegie Mellon University.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING INSTITUTE MATERIAL IS
// FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY KIND,
// EITHER EXPRESSED OR IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF FITNESS
// FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS OBTAINED FROM USE OF THE MATERIAL.
// CARNEGIE MELLON UNIVERSITY DOES NOT MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM
// PATENT, TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Licensed under a GNU GPL 2.0-style license, please see LICENSE.txt or contact
// permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public release and unlimited
// distribution.  Please see Copyright notice for non-US Government use and distribution.
//
// This Software includes and/or makes use of Third-Party Software each subject to its own license.
//
// DM24-1649
