// Copyright 2025 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.silk
package io

import java.lang.{Byte => JByte, Integer => JInt, Short => JShort}
import java.nio.{ByteBuffer, ByteOrder}
import java.time.{Duration, Instant}

import org.cert.netsa.data.net.{
  IPAddress, IPv4Address, IPv6Address, Port, Protocol, SNMPInterface, TCPFlags
}

private[silk] object BufferUtil {

  /** Additional utility routines for reading and writing binary data held in a
    * [[java.nio.ByteBuffer]].
    */
  implicit class BufferOps(val buffer: ByteBuffer) {

    /** Reads the three bytes at `offset` as an unsigned 24-bit integer, most significant byte
      * first, regardless of the buffer's configured byte order.
      */
    def getInt24BE(offset: Int): Int = {
      val high = buffer.get(offset) & 0xff
      val middle = buffer.get(offset + 1) & 0xff
      val low = buffer.get(offset + 2) & 0xff
      (high << 16) | (middle << 8) | low
    }

    /** Reads the three bytes at `offset` as an unsigned 24-bit integer in the buffer's configured
      * byte order: least significant byte first for little-endian buffers, most significant byte
      * first otherwise.
      */
    def getInt24(offset: Int): Int =
      if (buffer.order != ByteOrder.LITTLE_ENDIAN) {
        getInt24BE(offset)
      } else {
        val low = buffer.get(offset) & 0xff
        val middle = buffer.get(offset + 1) & 0xff
        val high = buffer.get(offset + 2) & 0xff
        (high << 16) | (middle << 8) | low
      }

    /** Copies `length` bytes starting at `offset` into a newly allocated array and returns it.
      *
      * This uses the buffer's relative bulk `get`, so as a side effect the buffer's position is
      * left at `offset + length`.
      */
    def getBytes(offset: Int, length: Int): Array[Byte] = {
      val out = new Array[Byte](length)
      buffer.position(offset)
      buffer.get(out)
      out
    }

    /** Reads a 64-bit count of milliseconds since the UNIX epoch at `offset`, in the buffer's byte
      * order, and converts it to a [[java.time.Instant]].
      */
    def getTimeMillis(offset: Int): Instant = {
      val epochMillis = buffer.getLong(offset)
      Instant.ofEpochMilli(epochMillis)
    }

    /** Reads a 64-bit count of nanoseconds since the UNIX epoch at `offset`, in the buffer's byte
      * order, and converts it to a [[java.time.Instant]].
      */
    def getTimeNanos(offset: Int): Instant = {
      val epochNanos = buffer.getLong(offset)
      Instant.ofEpochSecond(0L, epochNanos)
    }

    /** Reads an unsigned 32-bit count of seconds since the UNIX epoch at `offset`, in the buffer's
      * byte order, and converts it to a [[java.time.Instant]].
      */
    def getTimeSecs(offset: Int): Instant = {
      val epochSeconds = buffer.getInt(offset) & 0xffffffffL
      Instant.ofEpochSecond(epochSeconds)
    }

    /** Reads an unsigned 32-bit duration, in milliseconds, at `offset` in the buffer's byte
      * order.
      */
    def getElapsedMillis(offset: Int): Duration = {
      val millis = buffer.getInt(offset) & 0xffffffffL
      Duration.ofMillis(millis)
    }

    /** Reads an unsigned 16-bit duration, in seconds, at `offset` in the buffer's byte order. */
    def getElapsedSecs16(offset: Int): Duration = {
      val seconds = buffer.getShort(offset) & 0xffffL
      Duration.ofSeconds(seconds)
    }

    /** Reads an unsigned 32-bit duration, in seconds, at `offset` in the buffer's byte order. */
    def getElapsedSecs32(offset: Int): Duration = {
      val seconds = buffer.getInt(offset) & 0xffffffffL
      Duration.ofSeconds(seconds)
    }

    /** Reads a 16-bit port number at `offset`, in the buffer's byte order. */
    def getPort(offset: Int): Port = {
      val raw = buffer.getShort(offset)
      Port(raw)
    }

    /** Reads an 8-bit IP protocol number at `offset`. */
    def getProtocol(offset: Int): Protocol = {
      val raw = buffer.get(offset)
      Protocol(raw)
    }

    /** Reads an 8-bit SiLK flow type value at `offset`. */
    def getFlowType(offset: Int): FlowType = {
      val raw = buffer.get(offset)
      FlowType(raw)
    }

    /** Reads an 8-bit SiLK sensor value at `offset`, treating the byte as unsigned. */
    def getSensor8(offset: Int): Sensor = {
      val unsigned = (buffer.get(offset) & 0xff).toShort
      Sensor(unsigned)
    }

    /** Reads a 16-bit SiLK sensor value at `offset`, in the buffer's byte order. */
    def getSensor16(offset: Int): Sensor = {
      val raw = buffer.getShort(offset)
      Sensor(raw)
    }

    /** Reads an 8-bit set of TCP flags at `offset`. */
    def getTCPFlags(offset: Int): TCPFlags = {
      val raw = buffer.get(offset)
      TCPFlags(raw)
    }

    /** Reads an 8-bit SiLK TCP state (attribute) byte at `offset`. */
    def getTCPState(offset: Int): TCPState = {
      val raw = buffer.get(offset)
      TCPState(raw)
    }

    /** Reads a 16-bit application value at `offset`, in the buffer's byte order. Applications are
      * encoded exactly like ports, so this defers to [[getPort]].
      */
    def getApplication(offset: Int): Port = getPort(offset)

    /** Reads a 16-bit memo value at `offset`, in the buffer's byte order. */
    def getMemo(offset: Int): Short = buffer.getShort(offset)

    /** Reads an 8-bit SNMP interface value at `offset`, treating the byte as unsigned. */
    def getSNMPInterface8(offset: Int): SNMPInterface = {
      val unsigned: Int = buffer.get(offset) & 0xff
      SNMPInterface(unsigned)
    }

    /** Reads a 16-bit SNMP interface value at `offset`, in the buffer's byte order, treating it as
      * unsigned.
      */
    def getSNMPInterface16(offset: Int): SNMPInterface = {
      val unsigned: Int = buffer.getShort(offset) & 0xffff
      SNMPInterface(unsigned)
    }

    /** Reads a 32-bit SNMP interface value at `offset`, in the buffer's byte order. The raw bits
      * are passed through as a signed `Int`.
      */
    def getSNMPInterface32(offset: Int): SNMPInterface = {
      val raw = buffer.getInt(offset)
      SNMPInterface(raw)
    }

    /** Reads an unsigned 24-bit packet count at `offset`, in the buffer's byte order. */
    def getPacketCount24(offset: Int): Long = getInt24(offset).toLong

    /** Reads an unsigned 32-bit packet count at `offset`, in the buffer's byte order. */
    def getPacketCount32(offset: Int): Long = buffer.getInt(offset) & 0xffffffffL

    /** Reads a 64-bit packet count at `offset`, in the buffer's byte order. */
    def getPacketCount64(offset: Int): Long = buffer.getLong(offset)

    /** Reads an unsigned 32-bit byte count at `offset`, in the buffer's byte order. */
    def getByteCount32(offset: Int): Long = buffer.getInt(offset) & 0xffffffffL

    /** Reads a 64-bit byte count at `offset`, in the buffer's byte order. */
    def getByteCount64(offset: Int): Long = buffer.getLong(offset)

    /** Reads an address from the 16-byte field at `offset`. When `isIPv6` is true, the whole field
      * is returned as an IPv6 address. Otherwise only the field's last 4 bytes are read, as an IPv4
      * address. The bytes are copied as stored (network byte order), whatever the buffer's byte
      * order.
      */
    def getIPAddress(offset: Int, isIPv6: Boolean): IPAddress =
      if (!isIPv6) IPv4Address(getBytes(offset + 12, 4))
      else getIPv6Address(offset)

    /** Reads a 16-byte IPv6 address at `offset`, copying the bytes as stored (network byte
      * order).
      */
    def getIPv6Address(offset: Int): IPv6Address = {
      val octets = getBytes(offset, 16)
      IPv6Address(octets)
    }

    /** Reads an IPv4 address at `offset` as a 32-bit integer in the buffer's byte order. */
    def getIPv4Address(offset: Int): IPv4Address = {
      val bits = buffer.getInt(offset)
      IPv4Address(bits)
    }

    // Decode the nanosecond startTime, endTime, protocol, TCP state, and flags for formats that
    // encode these using state_flag_stime and rflag_etime.

    //  uint64_t      state_flag_stime;//  0- 7
    //  // uint64_t     tcp_state : 8; //        TCP state machine info
    //  // uint64_t     pro_iflags: 8; //        is_tcp==0: Protocol; else:
    //                                 //          EXPANDED==0:TCPflags/ALL pkts
    //                                 //          EXPANDED==1:TCPflags/1st pkt
    //  // uint64_t     unused    : 5; //        Reserved
    //  // uint64_t     is_tcp    : 1; //        1 if FLOW is TCP; 0 otherwise
    //  // uint64_t     stime     :42; //        Start time:nsec offset from hour
    //
    //  uint64_t      rflag_etime;     //  8-15
    //  // uint64_t     rest_flags: 8; //        is_tcp==0: Empty; else
    //                                 //          EXPANDED==0:Empty
    //                                 //          EXPANDED==1:TCPflags/!1st pkt
    //  // uint64_t     etime     :56; //        End time:nsec offset from hour

    /** Decodes the 16 bytes at `offset` holding the packed `state_flag_stime` and `rflag_etime`
      * words, for formats that record start and end times with nanosecond resolution.
      *
      * The first word holds the TCP state (top 8 bits), a protocol-or-initial-flags byte (next 8
      * bits), the is-TCP bit (bit 42), and the start-time offset (low 42 bits). The second word
      * holds the rest-of-session flags (top 8 bits) and the end-time offset (low 56 bits). Both
      * offsets are nanoseconds after `header.packedStartTime`.
      *
      * @return
      *   (start time, end time, protocol, TCP state, flags, initial flags, rest flags). Non-TCP
      *   records get empty flags. TCP records get separate initial/rest flags only when the TCP
      *   state has its expanded-flags bit set.
      */
    def decodeNanoTimesFlagsProto(
      offset: Int,
      header: Header
    ): (Instant, Instant, Protocol, TCPState, TCPFlags, TCPFlags, TCPFlags) = {
      val stateFlagStime = buffer.getLong(offset)
      val rflagEtime = buffer.getLong(offset + 8)

      val base = header.packedStartTime
      val startTime = base.plusNanos(stateFlagStime & 0x000003ffffffffffL)
      val endTime = base.plusNanos(rflagEtime & 0x00ffffffffffffffL)

      val tcpState = TCPState((stateFlagStime >>> 56).toByte)
      val protoOrFlags = (stateFlagStime >>> 48).toByte
      val tcpBit = (stateFlagStime & (1L << 42)) != 0

      if (tcpBit && tcpState.expandedFlags) {
        val initFlags = TCPFlags(protoOrFlags)
        val restFlags = TCPFlags((rflagEtime >>> 56).toByte)
        (startTime, endTime, Protocol.TCP, tcpState, initFlags | restFlags, initFlags, restFlags)
      } else if (tcpBit) {
        (
          startTime, endTime, Protocol.TCP, tcpState, TCPFlags(protoOrFlags), TCPFlags(0),
          TCPFlags(0)
        )
      } else {
        (
          startTime, endTime, Protocol(protoOrFlags), tcpState, TCPFlags(0), TCPFlags(0),
          TCPFlags(0)
        )
      }
    }

    // Decode the startTime, protocol, TCP state, and flags for formats that encode these using
    // rflag_stime, proto_iflags, and tcp_state:

    //  uint32_t      rflag_stime;     //  0- 3
    //  // uint32_t     rest_flags: 8; //        is_tcp==0: Empty; else
    //                                 //          EXPANDED==0:Empty
    //                                 //          EXPANDED==1:TCPflags/!1st pkt
    //  // uint32_t     is_tcp    : 1; //        1 if FLOW is TCP; 0 otherwise
    //  // uint32_t     unused    : 1; //        Reserved
    //  // uint32_t     stime     :22; //        Start time:msec offset from hour
    //
    //  uint8_t       proto_iflags;    //  4     is_tcp==0: Protocol; else:
    //                                 //          EXPANDED==0:TCPflags/ALL pkts
    //                                 //          EXPANDED==1:TCPflags/1st pkt
    //  uint8_t       tcp_state;       //  5     TCP state machine info

    /** Decodes the packed start time, protocol, TCP state, and TCP flags for formats that store
      * them in three fields: `rflag_stime` (4 bytes at `offset`), `proto_iflags` (the byte at
      * `offset + 4`), and `tcp_state` (the byte at `offset + 5`).
      *
      * `rflag_stime` holds the rest-of-session flags in its top 8 bits and the is-TCP bit at bit
      * 23. Its low 22 bits are the start-time offset, in milliseconds after
      * `header.packedStartTime`.
      *
      * @return
      *   (start time, protocol, TCP state, flags, initial flags, rest flags). Non-TCP records get
      *   empty flags. TCP records get separate initial/rest flags only when the TCP state has its
      *   expanded-flags bit set.
      */
    def decodeTimesFlagsProto(
      offset: Int,
      header: Header
    ): (Instant, Protocol, TCPState, TCPFlags, TCPFlags, TCPFlags) = {
      val rflagStime = buffer.getInt(offset)
      val tcpState = getTCPState(offset + 5)
      val startTime = header.packedStartTime.plusMillis((rflagStime & 0x003fffff).toLong)
      val tcpBit = (rflagStime & 0x00800000) != 0

      if (!tcpBit) {
        (startTime, getProtocol(offset + 4), tcpState, TCPFlags(0), TCPFlags(0), TCPFlags(0))
      } else if (!tcpState.expandedFlags) {
        (startTime, Protocol.TCP, tcpState, getTCPFlags(offset + 4), TCPFlags(0), TCPFlags(0))
      } else {
        val initByte = buffer.get(offset + 4)
        val restByte = (rflagStime >> 24).toByte
        val allByte = (restByte | initByte).toByte
        (
          startTime, Protocol.TCP, tcpState, TCPFlags(allByte), TCPFlags(initByte),
          TCPFlags(restByte)
        )
      }
    }

    // Decode the sTime, elapsed, packets, bytes, protocol, TCP flags, TCP state, and application
    // fields for formats that encode these using stime_bb1, bb2_elapsed, pro_flg_pkts, tcp_state,
    // rest_flags, and application:

    // in the following: EXPANDED == ((tcp_state & SK_TCPSTATE_EXPANDED) ? 1 : 0)
    //
    //  uint32_t      stime_bb1;       //  0- 3
    //  // uint32_t     stime     :22  //        Start time:msec offset from hour
    //  // uint32_t     bPPkt1    :10; //        Whole bytes-per-packet (hi 10)
    //
    //  uint32_t      bb2_elapsed;     //  4- 7
    //  // uint32_t     bPPkt2    : 4; //        Whole bytes-per-packet (low 4)
    //  // uint32_t     bPPFrac   : 6; //        Fractional bytes-per-packet
    //  // uint32_t     elapsed   :22; //        Duration of flow in msec
    //
    //  uint32_t      pro_flg_pkts;    //  8-11
    //  // uint32_t     prot_flags: 8; //        is_tcp==0: IP protocol
    //                                 //        is_tcp==1 &&
    //                                 //          EXPANDED==0:TCPflags/All pkts
    //                                 //          EXPANDED==1:TCPflags/1st pkt
    //  // uint32_t     pflag     : 1; //        'pkts' requires multiplier?
    //  // uint32_t     is_tcp    : 1; //        1 if flow is TCP; 0 otherwise
    //  // uint32_t     padding   : 2; //
    //  // uint32_t     pkts      :20; //        Count of packets
    //
    //  uint8_t       tcp_state;       // 12     TCP state machine info
    //  uint8_t       rest_flags;      // 13     is_tcp==0: Flow's reported flags
    //                                 //        is_tcp==1 &&
    //                                 //          EXPANDED==0:Empty
    //                                 //          EXPANDED==1:TCPflags/!1st pkt
    //  uint16_t      application;     // 14-15  Type of traffic

    /** Decodes the packed start time, duration, packet and byte counts, protocol, TCP flags, TCP
      * state, and application. The encoding uses `stime_bb1`, `bb2_elapsed`, and `pro_flg_pkts`
      * (12 bytes at `offset`), optionally followed by `tcp_state`, `rest_flags`, and `application`
      * (4 more bytes).
      *
      * @param offset
      *   where the packed fields start
      * @param len
      *   the encoded length; when it is 12 the trailing state/flags/application bytes are absent
      *   and default to zero
      * @param inIsTCP
      *   when true the record is treated as TCP even if its is-TCP bit is clear
      * @param header
      *   supplies `packedStartTime`, the base for the millisecond start-time offset
      * @return
      *   (start time, duration, packets, bytes, protocol, flags, initial flags, rest flags, TCP
      *   state, application)
      */
    def decodeFlagsTimesVolumes(
      offset: Int,
      len: Int,
      inIsTCP: Boolean,
      header: Header
    ): (Instant, Duration, Long, Long, Protocol, TCPFlags, TCPFlags, TCPFlags, TCPState, Port) = {
      // Short (12-byte) encodings omit the trailing tcp_state/rest_flags/application bytes.
      val hasTrailer = len != 12
      val tcpState = if (hasTrailer) getTCPState(offset + 12) else TCPState(0)
      val storedRestFlags = if (hasTrailer) getTCPFlags(offset + 13) else TCPFlags(0)
      val application = if (hasTrailer) getApplication(offset + 14) else Port(0)

      val stimeBb1 = buffer.getInt(offset)
      val bb2Elapsed = buffer.getInt(offset + 4)
      val proFlgPkts = buffer.getInt(offset + 8)

      val tcp = inIsTCP || (proFlgPkts & 0x00400000) != 0
      val protoOrInitByte = (proFlgPkts >> 24).toByte
      val (protocol, flags, initFlags, restFlags) =
        if (!tcp) {
          (Protocol(protoOrInitByte), storedRestFlags, TCPFlags(0), TCPFlags(0))
        } else {
          val first = TCPFlags(protoOrInitByte)
          val combined = TCPFlags((first.toByte | storedRestFlags.toByte).toByte)
          if (tcpState.expandedFlags) (Protocol.TCP, combined, first, storedRestFlags)
          else (Protocol.TCP, combined, TCPFlags(0), TCPFlags(0))
        }

      // Bytes-per-packet spans both words: its 10 high bits end stime_bb1, and its 4 low whole
      // bits plus 6 fractional bits start bb2_elapsed.
      val bpp = ((stimeBb1 & 0x000003ff) << 10) | ((bb2Elapsed >>> 22) & 0x000003ff)
      val startTime = header.packedStartTime.plusMillis(((stimeBb1 >>> 10) & 0x003fffff).toLong)
      val elapsed = Duration.ofMillis((bb2Elapsed & 0x003fffff).toLong)
      val (bytes, packets) =
        decodeBytesPackets(bpp, proFlgPkts & 0x000fffff, (proFlgPkts & 0x00800000) != 0)

      (
        startTime, elapsed, packets, bytes, protocol, flags, initFlags, restFlags, tcpState,
        application
      )
    }

    // Decode the startTime, packets, bytes, elapsed, protocol, and flags fields for formats that
    // encode these using pkts_stime, bbe, and msec_flags:

    //  uint32_t      pkts_stime
    //  // uint32_t     pkts      :20; //        Count of packets
    //  // uint32_t     sTime     :12; //        Start time--offset from hour
    //
    //  uint32_t      bbe
    //  // uint32_t     bPPkt     :14; //        Whole bytes-per-packet
    //  // uint32_t     bPPFrac   : 6; //        Fractional bytes-per-packet
    //  // uint32_t     elapsed   :12; //        Duration of flow
    //
    //  uint32_t      msec_flags
    //  // uint32_t     sTime_msec:10; //        Fractional sTime (millisec)
    //  // uint32_t     elaps_msec:10; //        Fractional elapsed (millisec)
    //  // uint32_t     pflag     : 1; //        'pkts' requires multiplier?
    //  // uint32_t     is_tcp    : 1; //        1 if flow is TCP; 0 otherwise
    //  // uint32_t     padding   : 2; //        padding/reserved
    //  // uint32_t     prot_flags: 8; //        is_tcp==0: IP protocol
    //                                 //        is_tcp==1 &&
    //                                 //          EXPANDED==0:TCPflags/All pkts
    //                                 //          EXPANDED==1:TCPflags/1st pkt

    /** Decodes the start time, packet and byte counts, duration, protocol, and TCP flags for
      * formats that spread them over three 32-bit words: `pkts_stime` at `offset1`, `bbe` at
      * `offset2`, and `msec_flags` at `offset3`.
      *
      * The start time is `header.packedStartTime` plus a 12-bit seconds offset and a 10-bit
      * milliseconds offset. The duration is likewise 12 bits of seconds plus 10 bits of
      * milliseconds.
      *
      * @param forceTCP
      *   when true the record is treated as TCP even if its is-TCP bit is clear
      * @return
      *   (start time, packets, bytes, duration, protocol, TCP flags); the flags are empty for
      *   non-TCP records
      */
    def getTimeBytesPacketsFlags(
      offset1: Int,
      offset2: Int,
      offset3: Int,
      forceTCP: Boolean,
      header: Header
    ): (Instant, Long, Long, Duration, Protocol, TCPFlags) = {
      val pktsStime = buffer.getInt(offset1)
      val bbe = buffer.getInt(offset2)
      val msecFlags = buffer.getInt(offset3)

      // Extracts the bit field selected by `mask` after shifting `word` right by `shift` bits.
      def bits(word: Int, shift: Int, mask: Int): Int = (word >>> shift) & mask

      val startTime = header
        .packedStartTime
        .plusSeconds(bits(pktsStime, 0, 0x00000fff).toLong)
        .plusMillis(bits(msecFlags, 22, 0x000003ff).toLong)

      val elapsedMillis = bits(bbe, 0, 0x00000fff) * 1000 + bits(msecFlags, 12, 0x000003ff)
      val elapsed = Duration.ofMillis(elapsedMillis.toLong)

      val (bytes, packets) = decodeBytesPackets(
        bits(bbe, 12, 0x000fffff),
        bits(pktsStime, 12, 0x000fffff),
        (msecFlags & 0x00000800) != 0
      )

      val protFlags = msecFlags.toByte
      if (forceTCP || (msecFlags & 0x00000400) != 0) {
        (startTime, packets, bytes, elapsed, Protocol.TCP, TCPFlags(protFlags))
      } else {
        (startTime, packets, bytes, elapsed, Protocol(protFlags), TCPFlags(0))
      }
    }

    /** Converts a fixed-point bytes-per-packet value and a stored packet count into total byte and
      * packet counts, for record formats that encode volume this way.
      *
      * @param bpp
      *   fixed-point bytes per packet: bits 0-5 are the fractional part in 64ths of a byte and bits
      *   6-19 are the whole part; higher bits are ignored
      * @param inPackets
      *   the stored packet count (the callers in this file pass a 20-bit field)
      * @param pflag
      *   when true the stored count is in units of 64 packets and is scaled up accordingly
      * @return
      *   (bytes, packets), where bytes is `packets * bpp` rounded to the nearest whole byte (a
      *   remainder of exactly half a byte rounds up)
      */
    def decodeBytesPackets(bpp: Int, inPackets: Int, pflag: Boolean): (Long, Long) = {
      val packets: Long = if (pflag) inPackets.toLong * 64L else inPackets.toLong
      val wholePerPacket: Long = ((bpp >> 6) & 0x00003fff).toLong
      val fracPerPacket: Long = (bpp & 0x0000003f).toLong
      // Every product is computed in Long. For example, 1,000,000 packets at 3,000 bytes each is
      // already past Int.MaxValue, and Int arithmetic would wrap to a negative byte count.
      val fracTotal = fracPerPacket * packets // measured in 64ths of a byte
      val roundUp = if ((fracTotal & 0x3fL) >= 32L) 1L else 0L
      val bytes = wholePerPacket * packets + (fracTotal >> 6) + roundUp
      (bytes, packets)
    }

    /** Works out the protocol and TCP flags for bit-bashed formats in which one byte
      * (`protFlags`) holds either the IP protocol (non-TCP records) or TCP flags (TCP records),
      * with a separate is-TCP bit saying which.
      *
      * @param isTCP
      *   whether the record is TCP
      * @param protFlags
      *   the shared protocol-or-flags byte
      * @param tcpState
      *   the record's TCP state. It is consulted only for TCP records, where its expanded-flags
      *   bit means `protFlags` holds just the first packet's flags.
      * @param inRestFlags
      *   the separately stored flags byte
      * @return
      *   (protocol, flags, initial flags, rest flags). For non-TCP records the flags come from
      *   `inRestFlags` and the initial/rest flags are empty. For TCP records without expanded
      *   flags, only the combined flags are set.
      */
    def decodeProtoFlags(
      isTCP: Boolean,
      protFlags: Byte,
      tcpState: TCPState,
      inRestFlags: Byte
    ): (Protocol, TCPFlags, TCPFlags, TCPFlags) =
      if (!isTCP) {
        (Protocol(protFlags), TCPFlags(inRestFlags), TCPFlags(0), TCPFlags(0))
      } else if (!tcpState.expandedFlags) {
        (Protocol.TCP, TCPFlags(protFlags), TCPFlags(0), TCPFlags(0))
      } else {
        val initFlags = TCPFlags(protFlags)
        val restFlags = TCPFlags(inRestFlags)
        (Protocol.TCP, TCPFlags((protFlags | inRestFlags).toByte), initFlags, restFlags)
      }

    /** Decodes the start time, duration, byte count, and packet count for formats that store them
      * in the `sbb` word (at `sbbOffset`) and the `pef` word (at `pefOffset`).
      *
      * `sbb` holds a 12-bit start-time offset in seconds (top bits) and a 20-bit fixed-point
      * bytes-per-packet value. `pef` holds a 20-bit packet count (top bits), an 11-bit duration in
      * seconds, and the packet-multiplier flag (lowest bit).
      *
      * @return
      *   (start time, duration, bytes, packets)
      */
    def decodeSbbPef(
      sbbOffset: Int,
      pefOffset: Int,
      header: Header
    ): (Instant, Duration, Long, Long) = {
      val sbb = buffer.getInt(sbbOffset)
      val pef = buffer.getInt(pefOffset)

      val startTime = header.packedStartTime.plusSeconds(((sbb >>> 20) & 0x00000fff).toLong)
      val elapsed = Duration.ofSeconds(((pef >>> 1) & 0x000007ff).toLong)
      val (bytes, packets) =
        decodeBytesPackets(sbb & 0x000fffff, (pef >>> 12) & 0x000fffff, (pef & 0x00000001) != 0)

      (startTime, elapsed, bytes, packets)
    }

    /** Expands the condensed "web port" code used by SiLK web record formats: 0 is port 80, 1 is
      * port 443, 2 is port 8080, and any other code yields port 0.
      */
    def decodeWWWPort(bits: Int): Port =
      if (bits == 0) Port(80)
      else if (bits == 1) Port(443)
      else if (bits == 2) Port(8080)
      else Port(0)

    /** Copies every byte of `value` into the buffer starting at `offset` and returns the buffer.
      *
      * This uses the buffer's relative bulk `put`, so the buffer's position is left just past the
      * last byte written.
      */
    def putBytes(offset: Int, value: Array[Byte]): ByteBuffer = {
      buffer.position(offset)
      buffer.put(value)
    }

    /** Copies at most the first `length` bytes of `value` (starting at index 0 of that array) into
      * the buffer starting at `offset`, and returns the buffer.
      *
      * Like the two-argument overload, this leaves the buffer's position just past the last byte
      * written.
      */
    def putBytes(offset: Int, value: Array[Byte], length: Int): ByteBuffer = {
      buffer.position(offset)
      buffer.put(value, 0, length)
    }

    /** Writes `value` at `offset` as a 64-bit count of milliseconds since the UNIX epoch, in the
      * buffer's byte order. Any sub-millisecond part of `value` is dropped.
      */
    def putTimeMillis(offset: Int, value: Instant): ByteBuffer = {
      val epochMillis = value.toEpochMilli()
      buffer.putLong(offset, epochMillis)
    }

    /** Writes `value` at `offset` as a 64-bit count of nanoseconds since the UNIX epoch, in the
      * buffer's byte order.
      *
      * Only instants whose nanosecond count fits in a signed 64-bit integer (roughly the years 1677
      * through 2262) can be written. Instants outside that range fail an `assert` instead of
      * silently wrapping around.
      */
    def putTimeNanos(offset: Int, value: Instant): ByteBuffer = {
      val epochSecond = value.getEpochSecond()
      assert(
        epochSecond < Long.MaxValue / 1000000000,
        s"Time $value is too large to be represented as a 64-bit signed number of nanoseconds"
      )
      // Without this lower bound, the multiplication below silently overflows for early instants.
      assert(
        epochSecond >= Long.MinValue / 1000000000,
        s"Time $value is too small to be represented as a 64-bit signed number of nanoseconds"
      )
      buffer.putLong(offset, epochSecond * 1000000000 + value.getNano())
    }

    /** Writes `value` at `offset` as a 32-bit count of milliseconds, in the buffer's byte order.
      * Only the low 32 bits of the millisecond count are kept.
      */
    def putElapsedMillis(offset: Int, value: Duration): ByteBuffer = {
      val millis = value.toMillis()
      buffer.putInt(offset, millis.toInt)
    }

    /** Writes a 16-bit port number at `offset`, in the buffer's byte order. */
    def putPort(offset: Int, value: Port): ByteBuffer = {
      val raw: Short = value.toShort
      buffer.putShort(offset, raw)
    }

    /** Writes an 8-bit IP protocol number at `offset`. */
    def putProtocol(offset: Int, value: Protocol): ByteBuffer = {
      val raw: Byte = value.toByte
      buffer.put(offset, raw)
    }

    /** Writes an 8-bit SiLK flow type value at `offset`. */
    def putFlowType(offset: Int, value: FlowType): ByteBuffer = {
      val raw: Byte = value.toByte
      buffer.put(offset, raw)
    }

    /** Writes a 16-bit SiLK sensor value at `offset`, in the buffer's byte order. */
    def putSensor16(offset: Int, value: Sensor): ByteBuffer = {
      val raw: Short = value.toShort
      buffer.putShort(offset, raw)
    }

    /** Writes an 8-bit set of TCP flags at `offset`. */
    def putTCPFlags(offset: Int, value: TCPFlags): ByteBuffer = {
      val raw: Byte = value.toByte
      buffer.put(offset, raw)
    }

    /** Writes an 8-bit SiLK TCP state (attribute) byte at `offset`. */
    def putTCPState(offset: Int, value: TCPState): ByteBuffer = {
      val raw: Byte = value.toByte
      buffer.put(offset, raw)
    }

    /** Writes a 16-bit application value at `offset`, in the buffer's byte order. Applications
      * are encoded exactly like ports, so this defers to [[putPort]].
      */
    def putApplication(offset: Int, value: Port): ByteBuffer = putPort(offset, value)

    /** Writes a 16-bit memo value at `offset`, in the buffer's byte order. */
    def putMemo(offset: Int, value: Short): ByteBuffer = buffer.putShort(offset, value)

    /** Writes a 16-bit SNMP interface value at `offset`, in the buffer's byte order. The value
      * written is `value.toShort`, which is intended to be the interface's lower two bytes.
      */
    def putSNMPInterface16(offset: Int, value: SNMPInterface): ByteBuffer = {
      val low16: Short = value.toShort
      buffer.putShort(offset, low16)
    }

    /** Writes a 32-bit SNMP interface value at `offset`, in the buffer's byte order. */
    def putSNMPInterface32(offset: Int, value: SNMPInterface): ByteBuffer = {
      val raw: Int = value.toInt
      buffer.putInt(offset, raw)
    }

    /** Writes a packet count at `offset` as a 32-bit integer, in the buffer's byte order. Only the
      * low 32 bits of `value` are kept.
      */
    def putPacketCount32(offset: Int, value: Long): ByteBuffer = {
      val low32 = value.toInt
      buffer.putInt(offset, low32)
    }

    /** Writes a packet count at `offset` as a 64-bit integer, in the buffer's byte order. */
    def putPacketCount64(offset: Int, value: Long): ByteBuffer = buffer.putLong(offset, value)

    /** Writes a byte count at `offset` as a 32-bit integer, in the buffer's byte order. Only the
      * low 32 bits of `value` are kept.
      */
    def putByteCount32(offset: Int, value: Long): ByteBuffer = {
      val low32 = value.toInt
      buffer.putInt(offset, low32)
    }

    /** Writes a byte count at `offset` as a 64-bit integer, in the buffer's byte order. */
    def putByteCount64(offset: Int, value: Long): ByteBuffer = buffer.putLong(offset, value)

    /** The first 12 bytes of the IPv4-mapped IPv6 form (`::ffff:a.b.c.d`), used to fill a 16-byte
      * IPv6 address field when it holds an IPv4 address.
      *
      * This is a `def`, not a `val`. Each extension-method call on a buffer creates a new instance
      * of this implicit class, so a `val` would allocate this array on every call, even though
      * only the IPv4 case of `putIPAddress` uses it. As a `def`, it is allocated only when needed.
      */
    private def ipv4InV6Prefix: Array[Byte] = Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1)

    /** Writes `value` into the 16-byte address field at `offset`. An IPv6 address is copied as
      * is. An IPv4 address is written in IPv4-mapped form (`::ffff:a.b.c.d`): the 12-byte
      * `ipv4InV6Prefix` followed by the 4 address bytes. Bytes are copied as stored, independent
      * of the buffer's byte order.
      */
    def putIPAddress(offset: Int, value: IPAddress): ByteBuffer =
      value match {
        case v6: IPv6Address =>
          putIPv6Address(offset, v6)
        case v4: IPv4Address =>
          putBytes(offset, ipv4InV6Prefix)
          putBytes(offset + 12, v4.toBytes)
      }

    /** Copies the 16 bytes of an IPv6 address into the buffer at `offset`, as stored (network
      * byte order).
      */
    def putIPv6Address(offset: Int, value: IPv6Address): ByteBuffer = {
      val octets = value.toBytes
      putBytes(offset, octets, 16)
    }

    /** Writes an IPv4 address at `offset` as a 32-bit integer in the buffer's byte order, so the
      * stored byte sequence depends on how the buffer is configured.
      */
    def putIPv4Address(offset: Int, value: IPv4Address): ByteBuffer = {
      val bits = value.toInt
      buffer.putInt(offset, bits)
    }

  }

  /** Adds a consistency check for [[RWRec]] values decoded from some record formats.
    *
    * `maybeClearTCPStateExpanded` looks for records whose TCP state claims expanded flags when
    * that cannot be right: the record is not TCP, or it is TCP but has neither initial nor rest
    * flags set. Such a record is rebuilt with bit 0x01 of its TCP state cleared (the bit this
    * check treats as the expanded-flags marker) and with its initial and rest flags set to zero.
    * All other constructor arguments are taken from the original record. Records that pass the
    * check are returned unchanged.
    */
  implicit class TCPStateFixHelper(val self: RWRec) {
    def maybeClearTCPStateExpanded: RWRec =
      if (!self.tcpState.expandedFlags) {
        self
      } else if (
        self.protocol == Protocol.TCP &&
        (self.initFlags.toByte != 0 || self.restFlags.toByte != 0)
      ) {
        // Expanded flags are consistent here: a TCP record that carries split flags.
        self
      } else {
        val clearedState = TCPState((self.tcpState.toByte & ~0x01).toByte)
        RWRec(
          self.startTime,
          self.elapsed,
          self.sPort,
          self.dPort,
          self.protocol,
          self.flowType,
          self.sensor,
          self.flags,
          TCPFlags(0),
          TCPFlags(0),
          clearedState,
          self.application,
          self.memo,
          self.input,
          self.output,
          self.packets,
          self.bytes,
          self.sIP,
          self.dIP,
          self.nhIP
        )
      }
  }

}

// @LICENSE_FOOTER@
//
// Mothra 1.7
//
// Copyright 2025 Carnegie Mellon University.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING INSTITUTE MATERIAL IS
// FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY KIND,
// EITHER EXPRESSED OR IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF FITNESS
// FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS OBTAINED FROM USE OF THE MATERIAL.
// CARNEGIE MELLON UNIVERSITY DOES NOT MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM
// PATENT, TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Licensed under a GNU GPL 2.0-style license, please see LICENSE.txt or contac
// permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public release and unlimited
// distribution.  Please see Copyright notice for non-US Government use and distribution.
//
// This Software includes and/or makes use of Third-Party Software each subject to its own license.
//
// DM24-1649
