// Copyright 2015-2022 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.ipfix.testing

import java.time.Instant
import org.cert.netsa.data.net.IPAddress

/** A value with a format encoded by octets in the IPFIX protocol.
  *
  * Implementors supply the octets and their count. The protected
  * helpers below produce big-endian (most significant octet first)
  * encodings of primitive values for implementors to build on. */
trait Encoded {

  /** The octets of this value's encoded form, in output order. */
  def encodedBytes: Iterator[Byte]

  /** How many octets `encodedBytes` yields. */
  def encodedLength: Int

  /** The low `width` octets of `v`, most significant octet first. */
  private def bigEndian(v: Long, width: Int): Iterator[Byte] =
    Iterator.range(width - 1, -1, -1).map(k => (v >> (8 * k)).toByte)

  /** Single-octet Boolean: 1 means true and 2 means false (the IPFIX
    * convention, not the usual 1/0). */
  final protected def b1(v: Boolean): Iterator[Byte] = {
    val octet: Byte = if (v) 1 else 2
    Iterator.single(octet)
  }

  /** Single-octet encoding of an integral value (its low 8 bits). */
  final protected def b1(v: Long): Iterator[Byte] =
    Iterator.single(v.toByte)

  /** Two-octet encoding of an integral value (its low 16 bits). */
  final protected def b2(v: Long): Iterator[Byte] =
    bigEndian(v, 2)

  /** Four-octet encoding of an integral value (its low 32 bits). */
  final protected def b4(v: Long): Iterator[Byte] =
    bigEndian(v, 4)

  /** Four-octet encoding of a Float, using its IEEE 754 bit pattern. */
  final protected def b4(v: Float): Iterator[Byte] =
    b4(java.lang.Float.floatToIntBits(v).toLong)

  /** Eight-octet encoding of an integral value. */
  final protected def b8(v: Long): Iterator[Byte] =
    bigEndian(v, 8)

  /** Eight-octet encoding of a Double, using its IEEE 754 bit pattern. */
  final protected def b8(v: Double): Iterator[Byte] =
    b8(java.lang.Double.doubleToLongBits(v))

}

/** Mostly, implicit conversions from values to their encoded
  * representations. Also provides wrappers that must be applied
  * explicitly: `varLen` and `padLen` for length handling, and
  * `seconds`, `milliseconds`, `microseconds` and `nanoseconds` for
  * choosing a time encoding. */
object Encoded {

  // Basic types: fixed-width encodings built on the Encoded helpers.

  implicit class EncodedBoolean(val b: Boolean) extends Encoded {
    def encodedBytes: Iterator[Byte] = b1(b) // yes, 1 = T, 2 = F
    def encodedLength: Int = 1
  }
  implicit class EncodedByte(val b: Byte) extends Encoded {
    def encodedBytes: Iterator[Byte] = b1(b)
    def encodedLength: Int = 1
  }
  implicit class EncodedShort(val s: Short) extends Encoded {
    def encodedBytes: Iterator[Byte] = b2(s)
    def encodedLength: Int = 2
  }
  implicit class EncodedInt(val i: Int) extends Encoded {
    // Explicit widening keeps overload resolution on b4(Long).
    def encodedBytes: Iterator[Byte] = b4(i.toLong)
    def encodedLength: Int = 4
  }
  implicit class EncodedLong(val l: Long) extends Encoded {
    def encodedBytes: Iterator[Byte] = b8(l)
    def encodedLength: Int = 8
  }

  implicit class EncodedFloat(val f: Float) extends Encoded {
    def encodedBytes: Iterator[Byte] = b4(f)
    def encodedLength: Int = 4
  }

  implicit class EncodedDouble(val d: Double) extends Encoded {
    def encodedBytes: Iterator[Byte] = b8(d)
    def encodedLength: Int = 8
  }

  // Variable-length types. You probably want to wrap these in
  // Encoded.varLen or Encoded.padLen.

  /** UTF-8 encoding of a string, with no length prefix or terminator. */
  implicit class EncodedString(val s: String) extends Encoded {
    private def utf8: Array[Byte] =
      s.getBytes(java.nio.charset.StandardCharsets.UTF_8)
    def encodedBytes: Iterator[Byte] = utf8.iterator
    // Count the encoded octets directly rather than draining an iterator.
    def encodedLength: Int = utf8.length
  }

  /** The raw contents of a byte array. */
  implicit class EncodedByteArray(val a: Array[Byte]) extends Encoded {
    def encodedBytes: Iterator[Byte] = a.iterator
    def encodedLength: Int = a.length
  }

  /** The octets returned by `IPAddress.toBytes`. */
  implicit class EncodedIPAddress(val a: IPAddress) extends Encoded {
    def encodedBytes: Iterator[Byte] = a.toBytes.iterator
    def encodedLength: Int = encodedBytes.length
  }

  /** Variable-length encoding of an encoded value. Prefixes the encoded
    * value with an encoding of its length. The prefix is one octet for
    * lengths below 255. Otherwise it is the octet 255 followed by a
    * two-octet length.
    *
    * @throws IllegalArgumentException if the value is longer than 65535
    *         octets, which the two-octet length cannot represent. */
  case class varLen(e: Encoded) extends Encoded {
    // Without this check, b2 would silently truncate the length.
    require(e.encodedLength <= 0xffff,
      s"Encoded value $e too long for variable-length encoding (${e.encodedLength} octets; maximum 65535)")
    def encodedBytes: Iterator[Byte] = {
      val n = e.encodedLength
      val prefix = if (n >= 255) b1(0xff) ++ b2(n) else b1(n)
      prefix ++ e.encodedBytes
    }
    def encodedLength: Int = {
      val n = e.encodedLength
      if (n >= 255) n + 3 else n + 1
    }
  }

  /** Fixed-length encoding of an encoded value, padded to a specific
    * length. Adds zeroed padding octets to the end of the value if
    * necessary, fails if the value is too large. */
  case class padLen(e: Encoded, override val encodedLength: Int)
      extends Encoded
  {
    require(e.encodedLength <= encodedLength,
      s"Encoded value $e larger than target size $encodedLength")
    def encodedBytes: Iterator[Byte] =
      (e.encodedBytes ++ Iterator.continually[Byte](0)).take(encodedLength)
  }

  // Time types. There are a variety of ways to encode time, so choose
  // which you want with Encoded.seconds(...) or the like. The String
  // overloads accept the formats in `timeFormats`, interpreted as UTC.

  /** Four-octet count of seconds since the Unix epoch (low 32 bits). */
  case class seconds(i: Instant) extends Encoded {
    def encodedBytes: Iterator[Byte] = b4(i.getEpochSecond)
    def encodedLength: Int = 4
  }
  object seconds {
    def apply(d: String): seconds = seconds(parseTime(d))
  }

  /** Eight-octet count of milliseconds since the Unix epoch. */
  case class milliseconds(i: Instant) extends Encoded {
    def encodedBytes: Iterator[Byte] = b8(i.toEpochMilli())
    def encodedLength: Int = 8
  }
  object milliseconds {
    def apply(d: String): milliseconds = milliseconds(parseTime(d))
  }

  /** Eight-octet NTP-format timestamp with the low 11 fraction bits cleared. */
  case class microseconds(i: Instant) extends Encoded {
    def encodedBytes: Iterator[Byte] = b8(ntpEncodedMicro(i))
    def encodedLength: Int = 8
  }
  object microseconds {
    def apply(d: String): microseconds = microseconds(parseTime(d))
  }

  /** Eight-octet NTP-format timestamp with the full 32-bit fraction. */
  case class nanoseconds(i: Instant) extends Encoded {
    def encodedBytes: Iterator[Byte] = b8(ntpEncodedNano(i))
    def encodedLength: Int = 8
  }
  object nanoseconds {
    def apply(d: String): nanoseconds = nanoseconds(parseTime(d))
  }

  // Helpers for the time types.

  /** Builds a UTC formatter for pattern `s`. The formatter also accepts
    * an optional fractional second of 0 to 9 digits. */
  private def makeDateTimeFormatter(s: String) = {
    import java.time.ZoneOffset
    import java.time.format.DateTimeFormatterBuilder
    import java.time.temporal.ChronoField
    new DateTimeFormatterBuilder().
      appendPattern(s).
      appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true).
      toFormatter().
      withZone(ZoneOffset.UTC)
  }

  private val timeFormats = Array(
    "yyyy-M-d'T'H:m:s",
    "yyyy/M/d'T'H:m:s",
    "yyyy-M-d' 'H:m:s",
    "yyyy/M/d' 'H:m:s"
  ).map(makeDateTimeFormatter)

  /** Parses `s` with the first formatter in `timeFormats` that accepts it.
    *
    * @throws IllegalArgumentException if no formatter accepts `s`. */
  private def parseTime(s: String): Instant = {
    import scala.util.Try
    // Lazy: later formatters are only tried if earlier ones fail.
    val parsed = timeFormats.iterator.flatMap(f =>
      Try(Instant.from(f.parse(s))).toOption)
    require(parsed.hasNext, s"Illegal date string '${s}'")
    parsed.next()
  }

  /** Seconds from the NTP epoch (1900-01-01) to the Unix epoch (1970-01-01). */
  private final val NtpEpochOffset = 2208988800L

  /** 64-bit NTP timestamp of `i`. The high 32 bits hold the seconds since
    * 1900; era overflow is dropped by the shift. The low 32 bits hold the
    * binary fraction of a second, ANDed with `fractionMask`. */
  private def ntpTimestamp(i: Instant, fractionMask: Long): Long = {
    val secs = (i.getEpochSecond + NtpEpochOffset) << 32
    val frac = ((i.getNano.toLong << 32) / 1000000000L) & fractionMask
    secs | frac
  }

  // The low 11 fraction bits are zeroed for the microsecond encoding.
  private def ntpEncodedMicro(i: Instant): Long =
    ntpTimestamp(i, 0xFFFFF800L)

  private def ntpEncodedNano(i: Instant): Long =
    ntpTimestamp(i, 0xFFFFFFFFL)

}

// @LICENSE_FOOTER@
//
// Copyright 2015-2022 Carnegie Mellon University. All Rights Reserved.
//
// This material is based upon work funded and supported by the
// Department of Defense and Department of Homeland Security under
// Contract No. FA8702-15-D-0002 with Carnegie Mellon University for the
// operation of the Software Engineering Institute, a federally funded
// research and development center sponsored by the United States
// Department of Defense. The U.S. Government has license rights in this
// software pursuant to DFARS 252.227.7014.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING
// INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON
// UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR
// IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF
// FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS
// OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT
// MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT,
// TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Released under a GNU GPL 2.0-style license, please see LICENSE.txt or
// contact permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public
// release and unlimited distribution. Please see Copyright notice for
// non-US Government use and distribution.
//
// Carnegie Mellon(R) and CERT(R) are registered in the U.S. Patent and
// Trademark Office by Carnegie Mellon University.
//
// This software includes and/or makes use of third party software each
// subject to its own license as detailed in LICENSE-thirdparty.tx
//
// DM20-1143
