// Copyright 2015-2022 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.ipfix
package datatype

import util.ListView

import scala.math.BigInt

import java.lang.{String => JString}
import java.nio.ByteBuffer
import java.time.Instant


/**
  * The AbstractDataType and its subclasses provide an implementation
  * of each possible [[DataTypes]] value, where those values represent
  * the IPFIX Information Element Data Types as defined in RFC5102.
  *
  * The constructor validates its arguments with `require`, so an
  * inconsistent combination of lengths or value bounds fails with an
  * IllegalArgumentException when the data type is created.
  *
  * @param id A reference to the DataTypes value that this class implements.
  * @param name The name of the DataType.
  * @param defaultLength The default length of the DataType.
  * @param minimumLength The minimum length of the DataType.
  * @param maximumLength The maximum length of the DataType.
  * @param minimumValue An optional minimum value for the DataType.
  * @param maximumValue An optional maximum value for the DataType.
  */
abstract class AbstractDataType(
  final val id: DataTypes,
  final val name: JString,
  final val defaultLength: Int,
  final protected val minimumLength: Int,
  final protected val maximumLength: Int,
  final val minimumValue: Option[Long] = None,
  final val maximumValue: Option[Long] = None)
    extends DataType
{
  // auxiliary constructors

  /** Creates a variable-length data type: the default and maximum
    * lengths are VARLEN and the minimum length is 0. */
  def this(id: DataTypes, name: JString) =
    this(id, name, VARLEN, 0, VARLEN)

  /** Creates a data type whose default, minimum, and maximum lengths
    * are all `defLen`. */
  def this(id: DataTypes, name: JString, defLen: Int) =
    this(id, name, defLen, defLen, defLen)

  /** Creates a data type having explicit length limits and explicit
    * (non-optional) minimum and maximum values. */
  def this(id: DataTypes, name: JString, defLen: Int, minLen: Int, maxLen: Int,
    minVal:Long, maxVal: Long) =
    this(id, name, defLen, minLen, maxLen, Option(minVal), Option(maxVal))


  require(!name.isEmpty, "Data type name may not be empty")

  require(defaultLength >= 0 && defaultLength <= VARLEN,
    s"Data type defaultLength $defaultLength is out of range (0--65535)")

  require(minimumLength >= 0 && minimumLength <= VARLEN,
    s"Data type minimumLength $minimumLength is out of range (0--65535)")

  require(maximumLength >= 0 && maximumLength <= VARLEN,
    s"Data type maximumLength $maximumLength is out of range (0--65535)")

  require(minimumLength <= maximumLength,
    s"Data type minimumLength $minimumLength is greater than" +
      s" maximumLength $maximumLength")

  require(defaultLength >= minimumLength,
    s"Data type defaultLength $defaultLength is less than" +
      s" minimumLength $minimumLength")

  require(defaultLength <= maximumLength,
    s"Data type defaultLength $defaultLength is greater than" +
      s" maximumLength $maximumLength")

  // the value bounds must be given as a pair or not at all
  require(minimumValue.isDefined == maximumValue.isDefined, {
    if ( minimumValue.isDefined ) {
      "Data type minimumValue is specified but maximumValue is not"
    } else {
      "Data type maximumValue is specified but minimumValue is not"
    }
  })

  // data types whose id value is 4 or less have their bounds compared
  // as unsigned numbers; all others are compared as signed numbers
  for (minval <- minimumValue ; maxval <- maximumValue) {
    if ( id.value <= 4 ) {
      require(!unsignedLessThan(maxval, minval),
        s"Data type minimumValue $minval is greater than maximumValue $maxval")
    } else {
      require(maxval >= minval,
        s"Data type minimumValue $minval is greater than maximumValue $maxval")
    }
  }


  /** Return true if a is less than b where a and b should be treated as
    * unsigned numbers. */
  private[this] def unsignedLessThan(a: Long, b: Long): Boolean =
    ((a < b) ^ (a < 0) ^ (b < 0))


  // the UNIX epoch as a number of seconds since NTP epoch
  // (1900-01-01)
  private[this] final val JAN_1970: Long = 2208988800L

  /**
    * Converts a 64-bit NTP-style timestamp to an Instant.  The upper
    * 32 bits hold the seconds since the NTP epoch and the lower 32
    * bits hold the fractional second.
    *
    * @param ntp The encoded timestamp.
    * @param isMicro When true, the low 11 bits of the fraction are
    *     ignored and the result is a whole number of microseconds;
    *     when false the fraction is converted to nanoseconds.
    * @return The decoded time.
    */
  protected final def decodeNTP(ntp: Long, isMicro: scala.Boolean): Instant = {
    /* When decoding the fractional seconds, we add (1 << 31) to the value
     * before shifting right by 32 in order to "round up" the
     * fractional seconds.  Without this, we found the decoded value
     * was 1 less than the encoded value. */
    if (isMicro) {
      // compute microseconds then multiply by 1000 to ensure the
      // nanosecond value in Instant is a multiple of 1000
      Instant.ofEpochSecond(((ntp >>> 32) - JAN_1970),
        ((((ntp & 0xFFFFF800L) * 1000000L) + 0x80000000L) >>> 32) * 1000L)
    } else {
      Instant.ofEpochSecond(((ntp >>> 32) - JAN_1970),
        ((((ntp & 0xFFFFFFFFL) * 1000000000L) + 0x80000000L) >>> 32))
    }
  }

  /**
    * Converts an Instant to a 64-bit NTP-style timestamp; this is the
    * inverse of decodeNTP().
    *
    * @param t The time to encode.
    * @param isMicro When true, the fraction is computed from whole
    *     microseconds and its low 11 bits are cleared; when false the
    *     fraction is computed from nanoseconds.
    * @return The encoded timestamp.
    */
  protected final def encodeNTP(t: Instant, isMicro: scala.Boolean): Long = {
    if (isMicro) {
      (((t.getEpochSecond + JAN_1970) << 32) |
        ((((t.getNano.toLong / 1000L) << 32) / 1000000L) & 0xFFFFF800L))
    } else {
      (((t.getEpochSecond + JAN_1970) << 32) |
        (((t.getNano.toLong << 32) / 1000000000L) & 0xFFFFFFFFL))
    }
  }


  /**
    * Appends an integer value to the buffer using exactly `len`
    * octets in big-endian order.
    *
    * When the data type has minimum and maximum values, the value is
    * checked against them (as signed or unsigned numbers according
    * to `signed`) before it is written.  When `len` is smaller than
    * the value requires, only the low-order `len` octets are written.
    *
    * @param b The buffer to append to.
    * @param len The number of octets to write.
    * @param signed Whether the value is to be treated as signed.
    * @param obj The value; must be a Byte, Short, Int, or Long.
    * @return The buffer `b`.
    * @throws IllegalFieldSpecifierException if `obj` is not one of
    *     the supported integer types or is outside the range of
    *     values permitted by this data type.
    */
  protected final def encodeNumber(
    b: ByteBuffer, len: Int, signed: scala.Boolean, obj: Any): ByteBuffer =
  {
    val v: Long =
      if (signed) {
        obj match {
          case x: Byte => x.toLong
          case x: Short => x.toLong
          case x: Int => x.toLong
          case x: Long => x
          case _ => throw new IllegalFieldSpecifierException
        }
      } else {
        // mask to prevent sign extension of the narrower types
        obj match {
          case x: Byte => 0xffL & x
          case x: Short => 0xffffL & x
          case x: Int => 0xffffffffL & x
          case x: Long => x
          case _ => throw new IllegalFieldSpecifierException
        }
      }

    // the constructor guarantees the bounds are either both present
    // or both absent
    for (minval <- minimumValue ; maxval <- maximumValue) {
      val outOfRange =
        if (signed) {
          v < minval || v > maxval
        } else {
          unsignedLessThan(v, minval) || unsignedLessThan(maxval, v)
        }
      if (outOfRange) {
        throw new IllegalFieldSpecifierException
      }
    }

    len match {
      case 1 => b.put(v.toByte)
      case 2 => b.putShort(v.toShort)
      case 4 => b.putInt(v.toInt)
      case 8 => b.putLong(v)
      case _ =>
        // toByteArray gives the minimal two's-complement form, which
        // may be shorter OR longer than `len` (for example, 0xFFFFFF
        // needs a leading zero octet and so occupies 4 octets)
        val arr = BigInt(v).toByteArray
        if (arr.length > len) {
          // keep only the low-order octets so that exactly `len`
          // octets are written
          b.put(arr, arr.length - len, len)
        } else {
          // pad on the left with the sign octet
          val fill: Byte = (if (v < 0) { 0xff } else { 0 }).toByte
          for (_ <- arr.length until len) {
            b.put(fill)
          }
          b.put(arr)
        }
    }
  }

  /** Returns true if `obj` is a Byte, Short, Int, or Long. */
  protected final def checkTypeNumber(obj: Any): Boolean =
    obj match {
      case (_ : Byte | _ : Short | _ : Int | _ : Long) => true
      case _ => false
    }

  /** Returns true if `length` is between the minimum and maximum
    * lengths (inclusive) of this data type. */
  protected def isValidLength(length: Int): scala.Boolean =
    length >= minimumLength && length <= maximumLength

  /**
    * Converts the bytes in the buffer to an object.  This is a helper
    * function for getValue().
    */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Any
  // abstract method

  /**
    * Converts the object into a stream of bytes and appends to the
    * buffer.  This is a helper function for toBuffer().
    */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer
  // abstract method

  /**
    * Computes the length of an object.  This is a helper function for
    * octetLength(), and it is only called when `len` is VARLEN.
    * Objects that support VARLEN values must override this method.
    */
  protected def getLength(obj: Any, len: Int): Int = len

  // implements DataType.getValue()
  final def getValue(b: ByteBuffer, s: Session, ie: InfoElement): Any = {
    // decoding always starts at the beginning of the buffer
    b.rewind
    decode(b, s, ie)
  }

  // implements DataType.toBuffer()
  final def toBuffer(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer = {
    if (!isValidLength(len)) {
      // FIXME: Need different error
      throw new TruncatedReadException(name, minimumLength, len)
    }
    encode(b, s, len, obj)
  }

  // implements DataType.octetLength()
  final def octetLength(obj: Any, len: Int): Int =
    if ( !checkType(obj) ) {
      VARLEN
    } else if ( len != VARLEN ) {
      len
    } else {
      getLength(obj, len)
    }
}


/**
  * Encoder and decoder for the basicList DataType,
  * [[DataTypes.BasicList]].  Values are exchanged as
  * [[org.cert.netsa.io.ipfix.BasicList]] objects.
  */
object BasicList extends AbstractDataType(
  DataTypes.BasicList, "basicList")
{
  /** Wraps the buffer in a CollectedBasicList after checking that the
    * buffer's limit is an acceptable length. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): BasicList = {
    val available = b.limit()
    if ( isValidLength(available) ) {
      new org.cert.netsa.io.ipfix.CollectedBasicList(b, s)
    } else {
      throw new TruncatedReadException(name, minimumLength, available)
    }
  }

  /** Appends the list to the buffer.  For VARLEN, a three-octet
    * length prefix is written first and its length field is filled in
    * once the list's size is known; otherwise the list is padded with
    * zero octets up to `len`. */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      case list: org.cert.netsa.io.ipfix.BasicList =>
        val start = b.position()
        if (len != VARLEN) {
          list.toBuffer(b, s)
          val written = b.position() - start
          if (written > len) {
            // FIXME ERROR CODE
            throw new IllegalFieldSpecifierException
          }
          // FIXME: also error if too short?
          var remaining = len - written
          while (remaining > 0) {
            b.put(0.toByte)
            remaining -= 1
          }
          b
        } else {
          // placeholder prefix: 0xff marker then a 16-bit length
          b.put(0xff.toByte)
          b.putShort(0.toShort)
          list.toBuffer(b, s)
          val written = b.position() - start - 3
          assert(written >= 0 && written < 0xFFFF)
          // go back and store the real length after the marker
          b.putShort(start + 1, written.toShort)
        }
      case _ =>
        throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is a BasicList. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _: org.cert.netsa.io.ipfix.BasicList => true
      case _ => false
    }

  /** Returns the list's octet length plus the three-octet prefix. */
  protected override def getLength(obj: Any, len: Int): Int =
    obj match {
      case list: org.cert.netsa.io.ipfix.BasicList =>
        val payload = list.octetLength
        3 + payload
      case _ =>
        throw new RuntimeException("Programmer error")
    }
}


/**
  * Implements encoding and decoding of the boolean DataType,
  * [[DataTypes.Boolean]].
  *
  * A boolean occupies a single octet whose value is 1 for true and 2
  * for false; every other value is rejected when decoding.
  */
object Boolean extends AbstractDataType(
  DataTypes.Boolean, "boolean", 1)
{
  /** Reads one octet and maps 1 to true and 2 to false.
    *
    * @throws TruncatedReadException if the buffer's limit is not 1.
    * @throws IllegalFieldSpecifierException if the octet is neither
    *     1 nor 2.
    */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Boolean =
    if ( b.limit() != defaultLength ) {
      throw new TruncatedReadException(name, defaultLength, b.limit())
    } else {
      b.get() match {
        case 1 => true
        case 2 => false
        case _ => throw new IllegalFieldSpecifierException
      }
    }

  /** Appends one octet: 1 for true, 2 for false.
    *
    * @throws IllegalFieldSpecifierException if `obj` is not a
    *     scala.Boolean.
    */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      // false must be written as 2 (not 0) so that decode() accepts it
      case x: scala.Boolean => b.put( (if (x) { 1 } else { 2 }).toByte )
      case _ => throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is a scala.Boolean value. */
  def checkType(obj: Any): Boolean =
    obj match {
      // a type pattern is required here; the bare name `scala.Boolean`
      // would compare against the companion object instead
      case _: scala.Boolean => true
      case _ => false
    }
}


/**
  * Encoder and decoder for the dateTimeMicroseconds DataType,
  * [[DataTypes.DateTimeMicroseconds]].  Values occupy eight octets
  * and are converted with decodeNTP()/encodeNTP() in microsecond mode.
  */
object DateTimeMicroseconds extends AbstractDataType(
  DataTypes.DateTimeMicroseconds, "dateTimeMicroseconds", 8)
{
  /** Reads an eight-octet timestamp and converts it to an Instant. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Instant = {
    val size = b.limit()
    if ( size == defaultLength ) {
      decodeNTP(b.getLong(), isMicro = true)
    } else {
      throw new TruncatedReadException(name, defaultLength, size)
    }
  }

  /** Appends an Instant (after conversion) or a Long (verbatim). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      // a Long is assumed to be already encoded
      case raw: Long => b.putLong(raw)
      case when: Instant =>
        val encoded = encodeNTP(when, isMicro = true)
        b.putLong(encoded)
      case _ => throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is an Instant or a Long. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _ : Instant => true
      case _ : Long => true
      case _ => false
    }
}


/**
  * Encoder and decoder for the dateTimeMilliseconds DataType,
  * [[DataTypes.DateTimeMilliseconds]].  Values occupy eight octets
  * holding a count of milliseconds since the UNIX epoch.
  */
object DateTimeMilliseconds extends AbstractDataType(
  DataTypes.DateTimeMilliseconds, "dateTimeMilliseconds", 8)
{
  /** Reads an eight-octet millisecond count as an Instant. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Instant = {
    val size = b.limit()
    if ( size == defaultLength ) {
      val millis = b.getLong()
      Instant.ofEpochMilli(millis)
    } else {
      throw new TruncatedReadException(name, defaultLength, size)
    }
  }

  /** Appends an Instant (as epoch milliseconds) or a Long (verbatim). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      // a Long is assumed to be already encoded
      case raw: Long => b.putLong(raw)
      case when: Instant => b.putLong(when.toEpochMilli())
      case _ => throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is an Instant or a Long. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _ : Instant => true
      case _ : Long => true
      case _ => false
    }
}


/**
  * Encoder and decoder for the dateTimeNanoseconds DataType,
  * [[DataTypes.DateTimeNanoseconds]].  Values occupy eight octets
  * and are converted with decodeNTP()/encodeNTP() in nanosecond mode.
  */
object DateTimeNanoseconds extends AbstractDataType(
  DataTypes.DateTimeNanoseconds, "dateTimeNanoseconds", 8)
{
  /** Reads an eight-octet timestamp and converts it to an Instant. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Instant = {
    val size = b.limit()
    if ( size == defaultLength ) {
      decodeNTP(b.getLong(), isMicro = false)
    } else {
      throw new TruncatedReadException(name, defaultLength, size)
    }
  }

  /** Appends an Instant (after conversion) or a Long (verbatim). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      // a Long is assumed to be already encoded
      case raw: Long => b.putLong(raw)
      case when: Instant =>
        val encoded = encodeNTP(when, isMicro = false)
        b.putLong(encoded)
      case _ => throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is an Instant or a Long. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _ : Instant => true
      case _ : Long => true
      case _ => false
    }
}


/**
  * Encoder and decoder for the dateTimeSeconds DataType,
  * [[DataTypes.DateTimeSeconds]].  Values occupy four octets holding
  * an unsigned count of seconds since the UNIX epoch.
  */
object DateTimeSeconds extends AbstractDataType(
  DataTypes.DateTimeSeconds, "dateTimeSeconds", 4)
{
  /** Reads a four-octet unsigned second count as an Instant. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Instant = {
    val size = b.limit()
    if ( size == defaultLength ) {
      // mask so the 32-bit value is treated as unsigned
      val seconds = b.getInt().toLong & 0xffffffffL
      Instant.ofEpochSecond(seconds)
    } else {
      throw new TruncatedReadException(name, defaultLength, size)
    }
  }

  /** Appends the Instant's epoch seconds, narrowed to 32 bits. */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      case when: Instant =>
        val seconds = when.getEpochSecond()
        b.putInt(seconds.toInt)
      case _ =>
        throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is an Instant. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _ : Instant => true
      case _ => false
    }
}


/**
  * Encoder and decoder for the float32 DataType,
  * [[DataTypes.Float32]].  Values occupy four octets.
  */
object Float32 extends AbstractDataType(
  DataTypes.Float32, "float32", 4)
{
  /** Reads a four-octet floating point value. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Float = {
    val size = b.limit()
    if ( size == defaultLength ) {
      b.getFloat()
    } else {
      throw new TruncatedReadException(name, defaultLength, size)
    }
  }

  /** Appends a Float, or a Double after narrowing it to a Float. */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      case wide: Double =>
        val narrowed = wide.toFloat
        b.putFloat(narrowed)
      case value: Float => b.putFloat(value)
      case _ => throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is a Float or a Double. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _ : Float => true
      case _ : Double => true
      case _ => false
    }
}


/**
  * Encoder and decoder for the float64 DataType,
  * [[DataTypes.Float64]].  Values occupy eight octets, or four octets
  * when a reduced length is used.
  */
object Float64 extends AbstractDataType(
  DataTypes.Float64, "float64", 8, 4, 8)
{
  /** Only lengths of exactly 4 and 8 are acceptable. */
  protected override def isValidLength(length: Int): scala.Boolean =
    length match {
      case 4 | 8 => true
      case _ => false
    }

  /** Reads either a four-octet float (widened) or an eight-octet
    * double, according to the buffer's limit. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Double = {
    val size = b.limit()
    if (size == 8) {
      b.getDouble()
    } else if (size == 4) {
      b.getFloat().toDouble
    } else {
      throw new TruncatedReadException(name, minimumLength, size)
    }
  }

  /** Appends the value as a float when `len` is 4 and as a double
    * otherwise, converting between Float and Double as needed. */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer = {
    val asFloat = len == 4
    obj match {
      case wide: Double =>
        if (asFloat) { b.putFloat(wide.toFloat) } else { b.putDouble(wide) }
      case narrow: Float =>
        if (asFloat) { b.putFloat(narrow) } else { b.putDouble(narrow.toDouble) }
      case _ =>
        throw new IllegalFieldSpecifierException
    }
  }

  /** Returns true when `obj` is a Float or a Double. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _ : Float => true
      case _ : Double => true
      case _ => false
    }
}


/**
  * Encoder and decoder for the ipv4Address DataType,
  * [[DataTypes.IPv4Address]].  Values occupy four octets.
  */
object IPv4Address extends AbstractDataType(
  DataTypes.IPv4Address, "ipv4Address", 4)
{
  /** Reads four octets as an Int and wraps it in an IPv4Address. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement):
      org.cert.netsa.data.net.IPv4Address =
  {
    val size = b.limit()
    if ( size == defaultLength ) {
      val packed = b.getInt()
      new org.cert.netsa.data.net.IPv4Address(packed)
    } else {
      throw new TruncatedReadException(name, defaultLength, size)
    }
  }

  /** Appends an IPv4Address (as its bytes) or an Int (verbatim). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      // an Int is assumed to be already encoded
      case packed: Int => b.putInt(packed)
      case addr: org.cert.netsa.data.net.IPv4Address =>
        val octets = addr.toBytes
        b.put(octets)
      case _ => throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is an IPv4Address or an Int. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _ : org.cert.netsa.data.net.IPv4Address => true
      case _ : Int => true
      case _ => false
    }
}


/**
  * Encoder and decoder for the ipv6Address DataType,
  * [[DataTypes.IPv6Address]].  Values occupy sixteen octets.
  */
object IPv6Address extends AbstractDataType(
  DataTypes.IPv6Address, "ipv6Address", 16)
{
  /** Reads sixteen octets and builds an IPv6Address from them. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement):
      org.cert.netsa.data.net.IPv6Address =
  {
    val size = b.limit()
    if ( size == defaultLength ) {
      val octets = new Array[Byte](16)
      b.get(octets)
      org.cert.netsa.data.net.IPv6Address(octets)
    } else {
      throw new TruncatedReadException(name, defaultLength, size)
    }
  }

  /** Appends the bytes of an IPv6Address. */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      case addr: org.cert.netsa.data.net.IPv6Address =>
        val octets = addr.toBytes
        b.put(octets)
      case _ =>
        throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is an IPv6Address. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _: org.cert.netsa.data.net.IPv6Address => true
      case _ => false
    }
}


/**
  * Encoder and decoder for the macAddress DataType,
  * [[DataTypes.MacAddress]].  Values occupy six octets.
  */
object MacAddress extends AbstractDataType(
  DataTypes.MacAddress, "macAddress", 6)
{
  /** Wraps the six-octet buffer in a ListView. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): ListView = {
    val size = b.limit()
    if ( size == defaultLength ) {
      new ListView(b)
    } else {
      throw new TruncatedReadException(name, defaultLength, size)
    }
  }

  /** Appends a byte array or ListView whose length is exactly the
    * length of this data type; anything else is rejected. */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      case raw: Array[Byte] if raw.length == minimumLength =>
        b.put(raw)
      case view: ListView if view.length == minimumLength =>
        b.put(view.toArray)
      case _ =>
        // wrong type, or right type with the wrong length
        throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is a byte array or a ListView. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _ : Array[Byte] => true
      case _ : ListView => true
      case _ => false
    }
}


/**
  * Implements encoding and decoding of the octetArray DataType,
  * [[DataTypes.OctetArray]].
  *
  * Values are accepted as either a byte array or a ListView and are
  * returned as a ListView.
  */
object OctetArray extends AbstractDataType(
  DataTypes.OctetArray, "octetArray")
{
  /** Wraps the buffer in a ListView after checking that the buffer's
    * limit is an acceptable length. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): ListView =
    if ( !isValidLength(b.limit()) ) {
      throw new TruncatedReadException(name, minimumLength, b.limit())
    } else {
      new ListView(b)
    }

  /**
    * Appends the octets to the buffer.
    *
    * When `len` is VARLEN the octets are preceded by a length prefix:
    * a single octet when the size is less than 255, otherwise the
    * octet 0xff followed by a 16-bit size.  For any other `len`,
    * exactly `len` octets are written: the value is truncated to its
    * first `len` octets or padded on the right with zero octets.
    *
    * @throws IllegalFieldSpecifierException if `obj` is neither a
    *     byte array nor a ListView.
    */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer = {
    val arr: Array[Byte] = obj match {
      case x: Array[Byte] => x
      case x: ListView => x.toArray
      case _ => throw new IllegalFieldSpecifierException
    }
    val sz = arr.length
    if (len == VARLEN) {
      if (sz < 0xff) {
        b.put(sz.toByte)
          .put(arr)
      } else {
        b.put(0xff.toByte)
          .putShort(sz.toShort)
          .put(arr)
      }
    } else {
      // fixed size, truncate or expand as needed
      if (sz > len) {
        // the offset argument indexes `arr`, not the buffer, so the
        // copy must start at 0 to keep the leading octets
        b.put(arr, 0, len)
      } else {
        b.put(arr)
        for ( _ <- sz until len ) {
          b.put(0.toByte)
        }
        b
      }
    }
  }

  /** Returns true when `obj` is a byte array or a ListView. */
  def checkType(obj: Any): Boolean =
    obj match {
      case (_ : Array[Byte] | _ : ListView) => true
      case _ => false
    }

  /** Returns the number of octets encode() writes for `obj` when the
    * length is VARLEN: the size plus a one- or three-octet prefix. */
  protected override def getLength(obj: Any, len: Int): Int = {
    val sz = obj match {
      case x: Array[Byte] => x.length
      case x: ListView => x.length
      case _ => throw new RuntimeException("Programmer error")
    }
    if (sz < 0xff) {
      sz + 1
    } else {
      sz + 3
    }
  }
}


/**
  * Encoder and decoder for the signed16 DataType,
  * [[DataTypes.Signed16]].  Values occupy one or two octets.
  */
object Signed16 extends AbstractDataType(
  DataTypes.Signed16, "signed16", 2, 1, 2,
  Short.MinValue.toLong, Short.MaxValue.toLong)
{
  /** Reads a sign-extended value of one or two octets. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Short = {
    val size = b.limit()
    if (size == 2) {
      b.getShort()
    } else if (size == 1) {
      b.get().toShort
    } else {
      throw new TruncatedReadException(name, minimumLength, size)
    }
  }

  /** Appends the value as a signed number via encodeNumber(). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    encodeNumber(b, len, signed = true, obj)

  /** Returns true when `obj` is one of the integer types. */
  def checkType(obj: Any): Boolean = checkTypeNumber(obj)
}


/**
  * Encoder and decoder for the signed32 DataType,
  * [[DataTypes.Signed32]].  Values occupy one to four octets.
  */
object Signed32 extends AbstractDataType(
  DataTypes.Signed32, "signed32", 4, 1, 4,
  Int.MinValue.toLong, Int.MaxValue.toLong)
{
  /** Reads a sign-extended value of one to four octets. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Int = {
    val size = b.limit()
    size match {
      case 4 => b.getInt()
      case 3 =>
        // no primitive accessor for three octets; BigInt interprets
        // the array as a two's-complement number
        val octets = new Array[Byte](size)
        b.get(octets)
        BigInt(octets).toInt
      case 2 => b.getShort().toInt
      case 1 => b.get().toInt
      case _ =>
        throw new TruncatedReadException(name, minimumLength, size)
    }
  }

  /** Appends the value as a signed number via encodeNumber(). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    encodeNumber(b, len, signed = true, obj)

  /** Returns true when `obj` is one of the integer types. */
  def checkType(obj: Any): Boolean = checkTypeNumber(obj)
}


/**
  * Encoder and decoder for the signed64 DataType,
  * [[DataTypes.Signed64]].  Values occupy one to eight octets.
  */
object Signed64 extends AbstractDataType(
  DataTypes.Signed64, "signed64", 8, 1, 8, Long.MinValue, Long.MaxValue)
{
  /** Reads a sign-extended value of one to eight octets. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Long = {
    val size = b.limit()
    size match {
      case 8 => b.getLong()
      case 4 => b.getInt().toLong
      case 2 => b.getShort().toLong
      case 1 => b.get().toLong
      case 3 | 5 | 6 | 7 =>
        // lengths lacking a primitive accessor: BigInt interprets
        // the array as a two's-complement number
        val octets = new Array[Byte](size)
        b.get(octets)
        BigInt(octets).toLong
      case _ =>
        throw new TruncatedReadException(name, minimumLength, size)
    }
  }

  /** Appends the value as a signed number via encodeNumber(). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    encodeNumber(b, len, signed = true, obj)

  /** Returns true when `obj` is one of the integer types. */
  def checkType(obj: Any): Boolean = checkTypeNumber(obj)
}


/**
  * Encoder and decoder for the signed8 DataType,
  * [[DataTypes.Signed8]].  Values occupy a single octet.
  */
object Signed8 extends AbstractDataType(
  DataTypes.Signed8, "signed8", 1, 1, 1,
  Byte.MinValue.toLong, Byte.MaxValue.toLong)
{
  /** Reads a single signed octet. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Byte = {
    val size = b.limit()
    if ( size == defaultLength ) {
      b.get()
    } else {
      throw new TruncatedReadException(name, defaultLength, size)
    }
  }

  /** Appends the value as a signed number via encodeNumber(). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    encodeNumber(b, len, signed = true, obj)

  /** Returns true when `obj` is one of the integer types. */
  def checkType(obj: Any): Boolean = checkTypeNumber(obj)
}


/**
  * Implements encoding and decoding of the string DataType,
  * [[DataTypes.String]].
  *
  * Strings are exchanged on the wire as UTF-8 octets; every length
  * computed by this object is therefore a count of UTF-8 octets, not
  * of characters.
  */
object String extends AbstractDataType(DataTypes.String, "string")
{
  /** Reads all octets up to the buffer's limit and decodes them as
    * UTF-8. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): JString =
    if ( !isValidLength(b.limit()) ) {
      throw new TruncatedReadException(name, minimumLength, b.limit())
    } else {
      val ba = new Array[Byte](b.limit())
      b.get(ba)
      new JString(ba, "UTF-8")
    }

  /**
    * Appends the UTF-8 encoding of the string to the buffer.
    *
    * When `len` is VARLEN the octets are preceded by a length prefix:
    * a single octet when the size is less than 255, otherwise the
    * octet 0xff followed by a 16-bit size.  For any other `len`,
    * exactly `len` octets are written: the encoding is truncated to
    * its first `len` octets or padded on the right with zero octets.
    *
    * @throws IllegalFieldSpecifierException if `obj` is not a string.
    */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer = {
    val arr: Array[Byte] = obj match {
      case x: JString => x.getBytes("UTF-8")
      case _ => throw new IllegalFieldSpecifierException
    }
    val sz = arr.length
    if (len == VARLEN) {
      if (sz < 0xff) {
        b.put(sz.toByte)
          .put(arr)
      } else {
        b.put(0xff.toByte)
          .putShort(sz.toShort)
          .put(arr)
      }
    } else {
      // fixed size, truncate or expand as needed
      if (sz > len) {
        // the offset argument indexes `arr`, not the buffer, so the
        // copy must start at 0 to keep the leading octets
        b.put(arr, 0, len)
      } else {
        b.put(arr)
        for ( _ <- sz until len ) {
          b.put(0.toByte)
        }
        b
      }
    }
  }

  /** Returns true when `obj` is a string. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _: JString => true
      case _ => false
    }

  /** Returns the number of octets encode() writes for `obj` when the
    * length is VARLEN: the UTF-8 size plus a one- or three-octet
    * prefix. */
  protected override def getLength(obj: Any, len: Int): Int =
    obj match {
      case x: JString =>
        // measure the UTF-8 encoding, as encode() does; the character
        // count differs for any non-ASCII text
        val sz = x.getBytes("UTF-8").length
        if (sz < 0xff) {
          sz + 1
        } else {
          sz + 3
        }
      case _ =>
        throw new RuntimeException("Programmer error")
    }
}


/**
  * Encoder and decoder for the subTemplateList DataType,
  * [[DataTypes.SubTemplateList]].  Values are exchanged as
  * [[org.cert.netsa.io.ipfix.SubTemplateList]] objects.
  */
object SubTemplateList extends AbstractDataType(
  DataTypes.SubTemplateList, "subTemplateList")
{
  /** Wraps the buffer in a CollectedSubTemplateList after checking
    * that the buffer's limit is an acceptable length. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement):
      SubTemplateList =
  {
    val available = b.limit()
    if ( isValidLength(available) ) {
      new org.cert.netsa.io.ipfix.CollectedSubTemplateList(b, s)
    } else {
      throw new TruncatedReadException(name, minimumLength, available)
    }
  }

  /** Appends the list to the buffer.  For VARLEN, a three-octet
    * length prefix is written first and its length field is filled in
    * once the list's size is known; otherwise the list is padded with
    * zero octets up to `len`. */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      case list: org.cert.netsa.io.ipfix.SubTemplateList =>
        val start = b.position()
        if (len != VARLEN) {
          list.toBuffer(b, s)
          val written = b.position() - start
          if (written > len) {
            // FIXME ERROR CODE
            throw new IllegalFieldSpecifierException
          }
          // FIXME: also error if too short?
          var remaining = len - written
          while (remaining > 0) {
            b.put(0.toByte)
            remaining -= 1
          }
          b
        } else {
          // placeholder prefix: 0xff marker then a 16-bit length
          b.put(0xff.toByte)
          b.putShort(0.toShort)
          list.toBuffer(b, s)
          val written = b.position() - start - 3
          assert(written >= 0 && written < 0xFFFF)
          // go back and store the real length after the marker
          b.putShort(start + 1, written.toShort)
        }
      case _ =>
        throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is a SubTemplateList. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _: org.cert.netsa.io.ipfix.SubTemplateList => true
      case _ => false
    }

  /** Returns the list's octet length plus the three-octet prefix. */
  protected override def getLength(obj: Any, len: Int): Int =
    obj match {
      case list: org.cert.netsa.io.ipfix.SubTemplateList =>
        val payload = list.octetLength
        3 + payload
      case _ =>
        throw new RuntimeException("Programmer error")
    }
}


/**
  * Encoder and decoder for the subTemplateMultiList DataType,
  * [[DataTypes.SubTemplateMultiList]].  Values are exchanged as
  * [[org.cert.netsa.io.ipfix.SubTemplateMultiList]] objects.
  */
object SubTemplateMultiList extends AbstractDataType(
  DataTypes.SubTemplateMultiList, "subTemplateMultiList")
{
  /** Wraps the buffer in a CollectedSubTemplateMultiList after
    * checking that the buffer's limit is an acceptable length. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement):
      SubTemplateMultiList =
  {
    val available = b.limit()
    if ( isValidLength(available) ) {
      new org.cert.netsa.io.ipfix.CollectedSubTemplateMultiList(b, s)
    } else {
      throw new TruncatedReadException(name, minimumLength, available)
    }
  }

  /** Appends the list to the buffer.  For VARLEN, a three-octet
    * length prefix is written first and its length field is filled in
    * once the list's size is known; otherwise the list is padded with
    * zero octets up to `len`. */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    obj match {
      case list: org.cert.netsa.io.ipfix.SubTemplateMultiList =>
        val start = b.position()
        if (len != VARLEN) {
          list.toBuffer(b, s)
          val written = b.position() - start
          if (written > len) {
            // FIXME ERROR CODE
            throw new IllegalFieldSpecifierException
          }
          // FIXME: also error if too short?
          var remaining = len - written
          while (remaining > 0) {
            b.put(0.toByte)
            remaining -= 1
          }
          b
        } else {
          // placeholder prefix: 0xff marker then a 16-bit length
          b.put(0xff.toByte)
          b.putShort(0.toShort)
          list.toBuffer(b, s)
          val written = b.position() - start - 3
          assert(written >= 0 && written < 0xFFFF)
          // go back and store the real length after the marker
          b.putShort(start + 1, written.toShort)
        }
      case _ =>
        throw new IllegalFieldSpecifierException
    }

  /** Returns true when `obj` is a SubTemplateMultiList. */
  def checkType(obj: Any): Boolean =
    obj match {
      case _: org.cert.netsa.io.ipfix.SubTemplateMultiList => true
      case _ => false
    }

  /** Returns the list's octet length plus the three-octet prefix. */
  protected override def getLength(obj: Any, len: Int): Int =
    obj match {
      case list: org.cert.netsa.io.ipfix.SubTemplateMultiList =>
        val payload = list.octetLength
        3 + payload
      case _ =>
        throw new RuntimeException("Programmer error")
    }
}


/**
  * Encoder and decoder for the unsigned16 DataType,
  * [[DataTypes.Unsigned16]].  Values occupy one or two octets and
  * are returned as an Int so that the full unsigned range fits.
  */
object Unsigned16 extends AbstractDataType(
  DataTypes.Unsigned16, "unsigned16", 2, 1, 2, 0L, 0xffffL)
{
  /** Reads a zero-extended value of one or two octets. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Int = {
    val size = b.limit()
    if (size == 2) {
      b.getShort() & 0xffff
    } else if (size == 1) {
      b.get() & 0xff
    } else {
      throw new TruncatedReadException(name, minimumLength, size)
    }
  }

  /** Appends the value as an unsigned number via encodeNumber(). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    encodeNumber(b, len, signed = false, obj)

  /** Returns true when `obj` is one of the integer types. */
  def checkType(obj: Any): Boolean = checkTypeNumber(obj)
}


/**
  * Implements encoding and decoding of the unsigned32 DataType,
  * [[DataTypes.Unsigned32]].
  *
  * Values occupy one to four octets and are returned as a Long so
  * that the full unsigned range fits.
  */
object Unsigned32 extends AbstractDataType(
  DataTypes.Unsigned32, "unsigned32", 4, 1, 4, 0L, 0xffffffffL)
{
  /** Reads a zero-extended value of one to four octets.
    *
    * @throws TruncatedReadException if the buffer's limit is not
    *     between 1 and 4.
    */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Long =
    b.limit() match {
      case 1 => 0xffL & b.get()
      case 2 => 0xffffL & b.getShort()
      case 4 => 0xffffffffL & b.getInt()
      case 3 =>
        val ba = new Array[Byte](3)
        b.get(ba)
        // signum 1 treats the octets as a non-negative magnitude; a
        // signum of 0 is only legal when every octet is zero
        BigInt(1, ba).toLong
      case _ =>
        throw new TruncatedReadException(name, minimumLength, b.limit())
    }

  /** Appends the value as an unsigned number via encodeNumber(). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    encodeNumber(b, len, false, obj)

  /** Returns true when `obj` is one of the integer types. */
  def checkType(obj: Any): Boolean = checkTypeNumber(obj)
}


/**
  * Implements encoding and decoding of the unsigned64 DataType,
  * [[DataTypes.Unsigned64]].
  *
  * Values occupy one to eight octets and are returned as a Long; an
  * eight-octet value is returned with its bits unchanged.
  */
object Unsigned64 extends AbstractDataType(
  DataTypes.Unsigned64, "unsigned64", 8, 1, 8, 0L, 0xffffffffffffffffL)
{
  /** Reads a zero-extended value of one to eight octets.
    *
    * @throws TruncatedReadException if the buffer's limit is not
    *     between 1 and 8.
    */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Long =
    b.limit() match {
      case 1 => 0xffL & b.get()
      case 2 => 0xffffL & b.getShort()
      case 4 => 0xffffffffL & b.getInt()
      case 8 => b.getLong()
      case 3 | 5 | 6 | 7 =>
        val ba = new Array[Byte](b.limit())
        b.get(ba)
        // signum 1 treats the octets as a non-negative magnitude; a
        // signum of 0 is only legal when every octet is zero
        BigInt(1, ba).toLong
      case _ =>
        throw new TruncatedReadException(name, minimumLength, b.limit())
    }

  /** Appends the value as an unsigned number via encodeNumber(). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    encodeNumber(b, len, false, obj)

  /** Returns true when `obj` is one of the integer types. */
  def checkType(obj: Any): Boolean = checkTypeNumber(obj)
}


/**
  * Encoder and decoder for the unsigned8 DataType,
  * [[DataTypes.Unsigned8]].  Values occupy a single octet and are
  * returned as a Short so that the full unsigned range fits.
  */
object Unsigned8 extends AbstractDataType(
  DataTypes.Unsigned8, "unsigned8", 1, 1, 1, 0L, 0xffL)
{
  /** Reads a single zero-extended octet. */
  protected def decode(b: ByteBuffer, s: Session, ie: InfoElement): Short = {
    val size = b.limit()
    if ( size == defaultLength ) {
      val widened = b.get() & 0xff
      widened.toShort
    } else {
      throw new TruncatedReadException(name, defaultLength, size)
    }
  }

  /** Appends the value as an unsigned number via encodeNumber(). */
  protected def encode(b: ByteBuffer, s: Session, len: Int, obj: Any): ByteBuffer =
    encodeNumber(b, len, signed = false, obj)

  /** Returns true when `obj` is one of the integer types. */
  def checkType(obj: Any): Boolean = checkTypeNumber(obj)
}

// @LICENSE_FOOTER@
//
// Copyright 2015-2022 Carnegie Mellon University. All Rights Reserved.
//
// This material is based upon work funded and supported by the
// Department of Defense and Department of Homeland Security under
// Contract No. FA8702-15-D-0002 with Carnegie Mellon University for the
// operation of the Software Engineering Institute, a federally funded
// research and development center sponsored by the United States
// Department of Defense. The U.S. Government has license rights in this
// software pursuant to DFARS 252.227.7014.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING
// INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON
// UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR
// IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF
// FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS
// OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT
// MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT,
// TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Released under a GNU GPL 2.0-style license, please see LICENSE.txt or
// contact permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public
// release and unlimited distribution. Please see Copyright notice for
// non-US Government use and distribution.
//
// Carnegie Mellon(R) and CERT(R) are registered in the U.S. Patent and
// Trademark Office by Carnegie Mellon University.
//
// This software includes and/or makes use of third party software each
// subject to its own license as detailed in LICENSE-thirdparty.tx
//
// DM20-1143
