// Copyright 2015-2022 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.ipfix


import java.nio.{ByteBuffer,BufferUnderflowException}
import scala.collection.immutable.{Set => ScalaSet, TreeMap}
import scala.collection.mutable.ArrayBuffer


/**
  * Represents the contents of an IPFIX SubTemplateMultiList (STML)
  * structured data value.
  *
  * A SubTemplateMultiList contains zero or more [[Record Records]],
  * and those Records may match one or more IPFIX [[Template
  * Templates]].
  *
  * The [[SubTemplateMultiList$ companion object]] has factory methods
  * that return a [[CollectedSubTemplateMultiList]] instance when
  * reading a SubTemplateMultiList or an [[ExportSubTemplateMultiList]]
  * instance when generating a SubTemplateMultiList for writing.
  *
  */
sealed abstract class SubTemplateMultiList protected () extends ListElement()
{
  /** Number of octets that begin an STML */
  protected val headerLength = 1

  /** Number of octets that begin a group of records that share the same
    * Template (an StmlGroup). */
  protected val stmlGroupHeaderLength = 4

  /**
    * Appends the SubTemplateMultiList to a buffer for writing to an
    * IPFIX stream.  Assumes all [[Template Templates]] used by the
    * [[Record Records]] in the SubTemplateMultiList have already been
    * added to the [[Session]] and written to the buffer.
    *
    * @param outbuf The buffer the list is appended to.
    * @param session The session used to map Templates to template IDs.
    * @return The buffer `outbuf`.
    */
  def toBuffer(outbuf: ByteBuffer, session: Session): ByteBuffer
  // abstract; updating description from ListElement.toBuffer

  // implements ListElement.octetLength
  //
  // The result is `headerLength` plus the `octetLength` of every Record
  // produced by `iterator`.
  //
  // NOTE(review): nothing here adds `stmlGroupHeaderLength` for the
  // groups, while the `toBuffer` implementations in this file write a
  // 4-octet header in front of each group; this value may therefore be
  // smaller than the number of octets `toBuffer` emits -- confirm.
  def octetLength: Int =
    iterator.foldLeft(headerLength) { (total, rec) => total + rec.octetLength }

  /** Renders the list as `STML(Semantics: <semantics> [r1, r2, ...])`
    * using each Record's `toString`. */
  override def toString(): String =
    iterator.mkString(s"STML(Semantics: ${semantics} [", ", ", "])")

  /** Renders the list as `subTemplateMultiList(<semantics> [r1, r2, ...])`
    * using each Record's `formatted` representation. */
  override def formatted: String = {
    val pieces = iterator.map(_.formatted)
    pieces.mkString(s"subTemplateMultiList($semantics [", ", ", "])")
  }

}


/**
  * A [[SubTemplateMultiList]] factory.
  *
  * Reading produces a [[CollectedSubTemplateMultiList]]; every other
  * factory method produces an [[ExportSubTemplateMultiList]].
  */
object SubTemplateMultiList {
  /**
    * Creates a new SubTemplateMultiList by reading data from a
    * buffer.
    *
    * @param buffer The ByteBuffer containing the data representing the list.
    * @param session The IPFIX session from which the buffer was read.
    * @return A list backed by the contents of `buffer`.
    */
  def fromBuffer(buffer: ByteBuffer, session: Session): SubTemplateMultiList = {
    new CollectedSubTemplateMultiList(buffer, session)
  }

  /**
    * Creates an empty SubTemplateMultiList to which [[Record Records]]
    * may be appended.
    *
    * @param semantics The semantics for elements of this list.
    */
  def apply(semantics: ListSemantics = ListSemantics.Undefined)
      : SubTemplateMultiList =
    new ExportSubTemplateMultiList(semantics)

  /**
    * Creates a SubTemplateMultiList from another
    * SubTemplateMultiList.
    *
    * @param stml The list to copy.
    * @param deep When `false` the new list contains references to the
    * same items as the existing list; when `true` the items are copied
    * as well.
    */
  def apply(stml: SubTemplateMultiList, deep: Boolean): SubTemplateMultiList = {
    new ExportSubTemplateMultiList(stml, deep)
  }

  /**
    * Creates a SubTemplateMultiList from another
    * SubTemplateMultiList without doing a deep copy: the new list
    * contains references to the same items as the existing list.
    *
    * @param stml The list to copy.
    */
  def apply(stml: SubTemplateMultiList): SubTemplateMultiList =
    apply(stml, deep = false)

}


/**
  * The CollectedSubTemplateMultiList class is used when reading a
  * [[SubTemplateMultiList]] from a data stream.  The header of every
  * group in the list is parsed when the object is created, but the
  * [[Record Records]] inside a group are not realized until they are
  * requested.
  *
  * Use the methods in the [[SubTemplateMultiList]]
  * [[SubTemplateMultiList$ companion object]] to create a
  * CollectedSubTemplateMultiList instance.
  *
  * @param buffer The ByteBuffer containing the data representing the list.
  * @param session The IPFIX session from which the buffer was read.
  *
  * @throws IllegalSubTemplateMultiListException if `buffer` is too
  * short to hold the list header, or if the length stored in a group
  * header is smaller than the group header or larger than the data
  * that remains in `buffer`.
  * @throws java.nio.BufferUnderflowException if fewer octets than a
  * group header remain where a group is expected.
  *
  * @see The [[SubTemplateMultiList$ SubTemplateMultiList companion
  * object]] for a factory method.
  */
final class CollectedSubTemplateMultiList(
  buffer: ByteBuffer,
  val session: Session)
    extends SubTemplateMultiList()
{
  // Empty groups (groups that mention a template but contain no
  // records) may appear in the data; the record iterator skips them.

  if (buffer.remaining() < headerLength) {
    throw new IllegalSubTemplateMultiListException(
      "Not enough bytes for subTemplateMultiList " +
        s"(${buffer.remaining()} of ${headerLength})")
  }

  override protected val semanticId = readSemanticId(buffer)

  /** The List of STML Groups in this STML. */
  private[this] val groups: Array[CollectedStmlGroup] = {
    // absolute position: assumes `buffer` begins at the start of the
    // list -- TODO confirm against callers
    buffer.position(headerLength)
    val ab = ArrayBuffer.empty[CollectedStmlGroup]
    while ( buffer.hasRemaining() ) {
      // each group parses its own header from a slice; the group's
      // length then tells us where the next group begins
      val sg = new CollectedStmlGroup(buffer.slice())
      buffer.position(buffer.position() + sg.octets)
      ab += sg
    }
    ab.toArray
  }

  /** CollectedStmlGroup represents a Group of Records that share the
    * same Template within this CollectedSubTemplateMultiList.
    *
    * @param groupbuf A buffer positioned at the first octet of the
    * group header.  Its limit is reduced to the length of the group.
    */
  private[this] final class CollectedStmlGroup(val groupbuf: ByteBuffer)
  {
    if (groupbuf.remaining() < stmlGroupHeaderLength) {
      throw new BufferUnderflowException()
    }
    /** The template ID from the group header (unsigned 16 bits). */
    val tid = 0xffff & groupbuf.getShort()
    /** The length of the group, header included (unsigned 16 bits). */
    val octets = 0xffff & groupbuf.getShort()
    if (octets < groupbuf.position()) {
      throw new IllegalSubTemplateMultiListException(
        s"Length specified within STML Group ($octets) is too short")
    }
    // Report a group that claims more data than is present as a
    // malformed list instead of letting ByteBuffer.limit() fail with a
    // bare IllegalArgumentException.
    if (octets > groupbuf.limit()) {
      throw new IllegalSubTemplateMultiListException(
        s"Length specified within STML Group ($octets) exceeds" +
          s" the available data (${groupbuf.limit()})")
    }
    groupbuf.limit(octets)

    /** The Template for `tid`, looked up in the session on first use. */
    lazy val template = session.getTemplate(tid).getOrElse {
      throw new IllegalSubTemplateMultiListException(
        s"Cannot get template for id $tid")
    }
    lazy val containsList = template.containsList

    /** The Records in this group, parsed on first use. */
    lazy val records: Array[Record] = {
      groupbuf.position(stmlGroupHeaderLength)
      val ab = ArrayBuffer.empty[CollectedRecord]
      while ( groupbuf.hasRemaining() ) {
        ab += template.readRecord(groupbuf, session, None)
      }
      ab.toArray
    }

    lazy val size = records.size
  }

  // implements ListElement.iterator
  def iterator: Iterator[Record] = new CollectedStmlRecordIter()

  /** A Record iterator that uses a CollectedStmlGroup iterator.
    *
    * `hasNext` moves past groups that have no records, so it returns
    * `true` only when `next()` is able to return a Record.  Moving to a
    * group realizes the Records of that group.
    */
  private[this] class CollectedStmlRecordIter() extends Iterator[Record] {
    private[this] val groupIter = groups.iterator
    private[this] var recIter: Iterator[Record] = Iterator.empty

    def hasNext: Boolean = {
      while ( !recIter.hasNext && groupIter.hasNext ) {
        recIter = groupIter.next().records.iterator
      }
      recIter.hasNext
    }
    def next(): Record = {
      if ( !hasNext ) {
        throw new IndexOutOfBoundsException("no more entries")
      }
      recIter.next()
    }
  }

  // implements ListElement.apply
  def apply(idx: Int): Record = {
    // this is slow but it shouldn't be called often
    var result = Option.empty[Record]
    val iter = groups.iterator
    // index within the whole list of the first record of the group
    // being examined
    var start = 0
    while ( result.isEmpty && iter.hasNext ) {
      val sg = iter.next()
      if ( idx - start < sg.size ) {
        result = Option(sg.records(idx - start))
      } else {
        start += sg.size
      }
    }
    result.getOrElse { throw new IndexOutOfBoundsException() }
  }

  // implements ListElement.size
  lazy val size: Int = {
    var sz = 0
    for ( sg <- groups ) {
      sz += sg.size
    }
    sz
  }

  // implements ListElement.octetLength
  //
  // Counts what `toBuffer` writes: the list header, then for each
  // group either the group's recorded length (when the group is
  // copied verbatim) or a group header plus the length of each Record
  // (when the Records are written individually).
  override def octetLength: Int =
    groups.foldLeft(headerLength) { (len, sg) =>
      if ( !sg.containsList ) {
        len + sg.octets
      } else {
        len + stmlGroupHeaderLength +
          sg.records.iterator.map(_.octetLength).sum
      }
    }

  // implements ListElement.allTemplates
  def allTemplates: ScalaSet[Template] = {
    var s = ScalaSet.empty[Template]
    for ( sg <- groups ) {
      if ( !sg.containsList ) {
        s = s + sg.template
      } else {
        for ( rec <- sg.records ) {
          s = s ++ rec.allTemplates
        }
      }
    }
    s
  }

  // implements ListElement.allBasicListElements
  final def allBasicListElements: ScalaSet[InfoElement] = {
    var s = ScalaSet.empty[InfoElement]
    for {
      sg <- groups.iterator
      if sg.containsList
      rec <- sg.records
    } {
      s = s ++ rec.allBasicListElements
    }
    s
  }

  // implements SubTemplateMultiList.toBuffer
  //
  // Note that the `session` parameter is the session being written to;
  // it is used to find the ID to write for each group's Template.
  def toBuffer(outbuf: ByteBuffer, session: Session): ByteBuffer = {
    outbuf.put(semanticId.toByte)
    for ( sg <- groups ) {
      if ( !sg.containsList ) {
        // no nested lists: copy the group's record data verbatim
        sg.groupbuf.position(stmlGroupHeaderLength)
        outbuf.putShort(session(sg.template).toShort)
          .putShort(sg.octets.toShort)
          .put(sg.groupbuf)
      } else {
        // store 0 for the length, write the record, update length
        // (since only template IDs should change, the final length
        // should be equal to sg.octets)
        val pos = outbuf.position()
        outbuf.putShort(session(sg.template).toShort)
          .putShort(0.toShort)
        for ( rec <- sg.records ) {
          rec.toBuffer(outbuf, session)
        }
        outbuf.putShort(pos + 2, ((outbuf.position() - pos).toShort))
      }
    }
    outbuf
  }

}


/**
  * The ExportSubTemplateMultiList class is used to incrementally
  * build a SubTemplateMultiList and export it to a stream.
  *
  * Use the methods in the [[SubTemplateMultiList]]
  * [[SubTemplateMultiList$ companion object]] to create an
  * ExportSubTemplateMultiList instance.
  *
  * @param groups A Map from the index of the first Record of a group
  * to the ExportStmlGroup containing that Record and the Records that
  * follow it and use the same Template
  * @param semanticId The semantics for elements of this list.
  *
  * @see The [[SubTemplateMultiList$ SubTemplateMultiList companion
  * object]] for factory methods.
  */
final class ExportSubTemplateMultiList private (
  private[this] var groups: TreeMap[Int, ExportSubTemplateMultiList.ExportStmlGroup],
  protected val semanticId: Short)
    extends SubTemplateMultiList()
{
  import ExportSubTemplateMultiList.ExportStmlGroup

  /**
    * Creates an empty list.
    *
    * @param semantics The semantics for elements of this list.
    */
  def this(semantics: ListSemantics) =
    this(TreeMap.empty[Int, ExportSubTemplateMultiList.ExportStmlGroup],
      semantics.value)

  /**
    * Creates a SubTemplateMultiList from another SubTemplateMultiList.  When
    * `deep` is true, creates a deep copy of the list.  When `deep` is false,
    * the new list contains references to the same items as the existing list.
    */
  def this(stml: SubTemplateMultiList, deep: Boolean = false) =
    this(ExportSubTemplateMultiList.getRecords(stml, deep),
      stml.semantics.value)

  /** The number of Records in the list. */
  private[this] var recCount =
    if ( groups.isEmpty ) { 0 } else { 1 + groups.last._2.endIndex }

  /**
    * Finds the group that holds the Record at position `idx`.
    *
    * @throws java.lang.IndexOutOfBoundsException if no group covers `idx`.
    */
  private[this] def groupContaining(idx: Int): ExportStmlGroup =
    groups.find { case (start, sg) => start <= idx && idx <= sg.endIndex }
      .map { _._2 }
      .getOrElse { throw new IndexOutOfBoundsException("bad index") }

  /**
    * Appends a record to the list.
    *
    * @return This list.
    */
  def append(rec: Record): ExportSubTemplateMultiList = {
    if ( groups.nonEmpty && groups.last._2.template == rec.template ) {
      // append to last group
      groups.last._2.append(rec)
    } else {
      // create a new group
      val sg = new ExportStmlGroup(rec, recCount)
      groups = groups + ((sg.startIndex, sg))
    }
    recCount += 1
    this
  }

  /**
    * Updates a record in the list, replacing the previous record
    * unless `idx` equals `size`, in which case the record is
    * appended.  The valid range for `idx` is 0 to the size of this
    * list, inclusive.
    *
    * @throws java.lang.IndexOutOfBoundsException if `idx` is outside
    * the valid range.
    */
  def update(idx: Int, rec: Record): Unit = {
    if ( idx == recCount ) {
      append(rec)
    } else {
      val sg = groupContaining(idx)
      if ( sg.template == rec.template ) {
        // same template; replace record
        sg.records.update(idx - sg.startIndex, rec)
      } else {
        // different template: the new record gets a group of its own,
        // and the old group is divided around it
        val sgNew = new ExportStmlGroup(rec, idx)
        val (before, fromIdx) = sg.records.splitAt(idx - sg.startIndex)
        // keep the records in "before" and "after".  the first record
        // of "fromIdx" is the one being replaced.
        val after = fromIdx.drop(1)
        if ( after.nonEmpty ) {
          val sgTail = new ExportStmlGroup(after.head, idx + 1)
          sgTail.records = after
          // must read sg.endIndex before it is adjusted below
          sgTail.endIndex = sg.endIndex
          groups = groups + ((sgTail.startIndex, sgTail))
        }
        if ( before.nonEmpty ) {
          sg.records = before
          sg.endIndex = idx - 1
        }
        // when "before" is empty, sgNew has the same key as sg and
        // takes its place in the map
        groups = groups + ((sgNew.startIndex, sgNew))
      }
    }
  }

  // implements ListElement.iterator
  def iterator: Iterator[Record] = new ExportStmlRecordIter()

  // implements ListElement.apply
  def apply(idx: Int): Record = {
    val sg = groupContaining(idx)
    sg.records(idx - sg.startIndex)
  }

  // implements ListElement.size
  def size: Int = recCount

  // implements ListElement.octetLength
  //
  // Counts what `toBuffer` writes: the list header, then for each
  // group a group header followed by the group's Records.
  override def octetLength: Int =
    groups.valuesIterator.foldLeft(headerLength) { (len, sg) =>
      len + stmlGroupHeaderLength +
        sg.records.iterator.map(_.octetLength).sum
    }

  // implements ListElement.allTemplates
  def allTemplates: ScalaSet[Template] = {
    var s = ScalaSet.empty[Template]
    for ( sg <- groups.valuesIterator ) {
      if ( !sg.containsList ) {
        s = s + sg.template
      } else {
        for ( rec <- sg.records ) {
          s = s ++ rec.allTemplates
        }
      }
    }
    s
  }

  // implements ListElement.allBasicListElements
  final def allBasicListElements: ScalaSet[InfoElement] = {
    var s = ScalaSet.empty[InfoElement]
    for ( sg <- groups.valuesIterator ) {
      if ( sg.containsList ) {
        for ( rec <- sg.records ) {
          s = s ++ rec.allBasicListElements
        }
      }
    }
    s
  }

  // implements SubTemplateMultiList.toBuffer
  def toBuffer(outbuf: ByteBuffer, session: Session): ByteBuffer = {
    outbuf.put(semanticId.toByte)
    for (sg <- groups.valuesIterator) {
      val pos = outbuf.position()
      // write template id and 0 as a placeholder for length
      outbuf.putShort(session(sg.template).toShort).putShort(0.toShort)
      for (rec <- sg.records) {
        rec.toBuffer(outbuf, session)
      }
      // update length
      outbuf.putShort(pos + 2, ((outbuf.position() - pos).toShort))
    }
    outbuf
  }

  /** A Record iterator that walks the groups in index order and the
    * Records within each group.  `hasNext` returns `true` only when
    * `next()` is able to return a Record. */
  private[this] class ExportStmlRecordIter() extends Iterator[Record] {
    private[this] val groupIter = groups.valuesIterator
    private[this] var recIter: Iterator[Record] = Iterator.empty

    def hasNext: Boolean = {
      while ( !recIter.hasNext && groupIter.hasNext ) {
        recIter = groupIter.next().records.iterator
      }
      recIter.hasNext
    }
    def next(): Record = {
      if ( !hasNext ) {
        throw new IndexOutOfBoundsException("no more entries")
      }
      recIter.next()
    }
  }

}


/**
  * Contains private definitions to support the
  * [[ExportSubTemplateMultiList]] class.
  */
object ExportSubTemplateMultiList {

  /**
    * Returns a Tree containing positional indexes mapped to
    * StmlGroups, where the Records in the StmlGroups are either
    * references to the Records in `stml` when `deep` is false or
    * complete copies of those Records when `deep` is `true`.  This is
    * a helper method for an auxiliary constructor.
    *
    * Consecutive Records that use the same Template are placed in the
    * same group; the key of a group is the index of its first Record.
    */
  private def getRecords(stml: SubTemplateMultiList, deep: Boolean):
      TreeMap[Int, ExportStmlGroup] =
  {
    val recs: Iterator[Record] =
      if ( deep ) {
        stml.iterator.map { r => Record(r, deep) }
      } else {
        stml.iterator
      }

    var tree = TreeMap.empty[Int, ExportStmlGroup]
    // the group that received the previous record, if any
    var current = Option.empty[ExportStmlGroup]
    // position of the record being handled within the whole list
    var idx = 0
    for (r <- recs) {
      current match {
        case Some(sg) if sg.template == r.template =>
          // grow the group in place; append() also advances endIndex
          sg.append(r)
        case _ =>
          val sg = new ExportStmlGroup(r, idx)
          tree = tree + ((sg.startIndex, sg))
          current = Some(sg)
      }
      idx += 1
    }
    tree
  }

  /** ExportStmlGroup represents a Group of Records that share the
    * same Template within this ExportSubTemplateMultiList.
    *
    * @param rec The first Record of the group; its Template becomes
    * the Template of the group.
    * @param idx The position of `rec` within the list.
    */
  private class ExportStmlGroup (rec: Record, idx: Int) {
    val template: Template = rec.template
    lazy val containsList = template.containsList
    /** Position within the list of the first Record of the group. */
    var startIndex = idx
    /** Position within the list of the last Record of the group. */
    var endIndex = idx
    var records: ArrayBuffer[Record] = ArrayBuffer(rec)

    /** Adds a Record to the end of the group and advances `endIndex`. */
    def append(rec: Record): Unit = {
      records += rec
      endIndex += 1
    }
  }

}

// @LICENSE_FOOTER@
//
// Copyright 2015-2022 Carnegie Mellon University. All Rights Reserved.
//
// This material is based upon work funded and supported by the
// Department of Defense and Department of Homeland Security under
// Contract No. FA8702-15-D-0002 with Carnegie Mellon University for the
// operation of the Software Engineering Institute, a federally funded
// research and development center sponsored by the United States
// Department of Defense. The U.S. Government has license rights in this
// software pursuant to DFARS 252.227.7014.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING
// INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON
// UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR
// IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF
// FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS
// OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT
// MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT,
// TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Released under a GNU GPL 2.0-style license, please see LICENSE.txt or
// contact permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public
// release and unlimited distribution. Please see Copyright notice for
// non-US Government use and distribution.
//
// Carnegie Mellon(R) and CERT(R) are registered in the U.S. Patent and
// Trademark Office by Carnegie Mellon University.
//
// This software includes and/or makes use of third party software each
// subject to its own license as detailed in LICENSE-thirdparty.txt
//
// DM20-1143
