// Copyright 2015-2022 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.ipfix

import com.typesafe.scalalogging.StrictLogging
import java.util.BitSet
import scala.collection.concurrent.TrieMap
import scala.collection.immutable.Queue


/**
  * A session contains information about a single IPFIX session, from
  * a single observation domain.  Specifically, a session takes care
  * of handling the template library that maps template IDs to
  * [[Template templates]], and it keeps track of message sequence
  * numbers.
  *
  * The Session class's constructor is not accessible.  Instead, one
  * of its subclasses must be used: [[StreamSession]] is used when
  * reading data from a file or from a TCP stream, and
  * [[DatagramSession]] is used when reading data over UDP.
  *
  * @param infoModel The information model used by the session.
  * @param transport The source information for this session.  If the
  * source is a file, `transport` is a String containing the file
  * name.  If the source is a network connection, `transport` is an
  * InetSocketAddress representing the remote side of this session's
  * connection.
  * @param observationDomain The observation domain for the session,
  * an unsigned 32-bit value.  Transport sessions that have a single
  * observation domain typically have an observation domain ID of
  * zero.
  *
  * @see [[Session$ The companion object]] for additional details.
  */
sealed class Session private (
  final val infoModel: InfoModel,
  final val transport: AnyRef,
  final val observationDomain: Int)
    extends StrictLogging
{
  // import from the companion object
  import Session.{SequenceCallback,TemplateCallback}

  /** Maps a templateId to a Template.  "protected" since
    * ReadOnlySession needs access.  Reassigned by `withdraw` while
    * holding this Session's lock. */
  protected final var idGetTmpl = TrieMap.empty[Int, Template]

  /** Maps a Template to its templateId in this Session. */
  protected final var tmplGetId = TrieMap.empty[Template, Int]

  /** Maps a templateId to its metadata */
  protected final var templateMetadata = TrieMap.empty[Int, TemplateMetadata]

  /** Callbacks to invoke when a template arrives or is withdrawn. */
  @volatile
  private final var templateCallbacks = Queue.empty[TemplateCallback]

  /** Callbacks to invoke for an unexpected sequence number. */
  @volatile
  private final var sequenceCallbacks = Queue.empty[SequenceCallback]

  /** A read-only copy of this Session that may not be modified.
    * Created as needed when getPersistentSession() is called and
    * discarded (reset to None) whenever the templates or the template
    * metadata change.  Only accessed while holding this Session's lock. */
  private final var persistent: Option[Session] = None

  /**
    * The sequence number that is expected for the next message in
    * this session.
    *
    * It is a value between 0 and 2^32-1, or -1 if no message
    * has yet been observed in this session.
    */
  var expectedSequence: Long = -1

  /** For determining which Template IDs are in use.  (This is the Java
    * BitSet since the Scala BitSet is no better than a Set.)  The bits
    * below MIN_TEMPLATE_ID are set here so reserved IDs are never
    * handed out. */
  private[this] val usedIds = new BitSet(MIN_TEMPLATE_ID)
  // Reserved IDs (note: leaves MIN_TEMPLATE_ID unset)
  usedIds.set(0, MIN_TEMPLATE_ID)


  /**
    * Auxiliary constructor
    *
    * Creates a new session from the given session group and
    * observation domain ID.  The new session creates a new
    * information model that inherits from the SessionGroup's
    * information model.
    *
    * @param group  the session group that owns the session
    * @param id     the observation domain ID
    */
  protected def this(group: SessionGroup, id: Int) =
    this(InfoModel.inheritInfoModel(group.infoModel), group.transport, id)


  /**
    * Notifies all template callbacks that a template was withdrawn.
    */
  private[this] def notifyWithdrawal(t: Template, id: Int): Unit = {
    for ( cb: TemplateCallback <- templateCallbacks ) {
      cb.withdrawnTemplate(this, t, id)
    }
  }

  /**
    * Notifies all template callbacks that a template was added.
    */
  private[this] def notifyNew(t: Template, id: Int): Unit = {
    for ( cb: TemplateCallback <- templateCallbacks ) {
      cb.newTemplate(this, t, id)
    }
  }

  /**
    * Internal function that removes a template when a withdrawal
    * template is received on the input stream.  That is, when add()
    * is called with a withdrawal template.
    *
    * When `tid` is TEMPLATE_SET_ID every (non-options) template is
    * withdrawn; when `tid` is OPTIONS_TEMPLATE_SET_ID every options
    * template is withdrawn (RFC 7011, Sec 8.1); otherwise only the
    * template with ID `tid` is withdrawn.
    *
    * @return the withdrawn template for a single-ID withdrawal, or
    *     `None` for a withdrawal of a whole set of templates
    */
  // This function is protected because the DatagramSession class
  // overrides it in order to disable it per RFC 7011, Sec 8.4.
  protected def withdraw(t: Template, tid: Int): Option[Template] = {
    assert(t.isWithdrawalTemplate)
    if ( tid != TEMPLATE_SET_ID && tid != OPTIONS_TEMPLATE_SET_ID ) {
      // A request to withdraw a single template

      // note: should log when id is not found
      remove(tid)
    } else {
      // A request to withdraw all templates (when tid is
      // TEMPLATE_SET_ID) or all option templates (when tid is
      // OPTIONS_TEMPLATE_SET_ID).
      val withdrawOptions = (tid == OPTIONS_TEMPLATE_SET_ID)
      val withdrawn = synchronized {
        // partition into (1) the templates to keep--those whose
        // isOptionsTemplate status differs from the kind being
        // withdrawn--and (2) those being withdrawn
        val (kept, removed) = idGetTmpl.partition {
          case (_, tmpl) => tmpl.isOptionsTemplate != withdrawOptions
        }
        idGetTmpl = kept
        // regenerate the tmpl->id mapping structure
        tmplGetId = TrieMap.empty[Template, Int] ++ {
          for ( (id, tmpl) <- kept ) yield (tmpl, id)
        }
        // the withdrawn IDs become available again
        removed.keys.foreach(id => usedIds.clear(id))
        if ( removed.nonEmpty ) {
          persistent = None
        }
        removed
      }
      // inline notifyWithdrawal
      for {
        cb <- templateCallbacks
        (id, tmpl) <- withdrawn
      } {
        cb.withdrawnTemplate(this, tmpl, id)
      }
      None
    }
  }

  /**
    * Registers a new template callback with this session.  A session
    * can have any number of callbacks registered to it.  Each
    * registered callback will be called whenever a template is added
    * or removed from the session template library.
    *
    * @param callback  the template callback
    */
  def register(callback: TemplateCallback): Unit = synchronized {
    templateCallbacks = templateCallbacks :+ callback
  }

  /**
    * Registers multiple templates callbacks with this session.
    *
    * @param callback  the template callbacks
    */
  def register(callback: Iterable[TemplateCallback]): Unit = synchronized {
    templateCallbacks = templateCallbacks ++ callback
  }

  /**
    * Registers a new sequence callback with this session.  A session
    * can have any number of callbacks registered to it.  Each
    * registered callback will be called whenever a message arrives
    * with an unexpected sequence number.
    *
    * @param callback  the sequence number callback
    */
  def register(callback: SequenceCallback): Unit = synchronized {
    sequenceCallbacks = sequenceCallbacks :+ callback
  }

  /**
    * Returns the ID for `template` in this Session, adding `template` to this
    * Session if it does not exist.
    *
    * If `template` must be added to the Session, an arbitrary ID is used when
    * `tid` is `None`.  If `tid` is not `None`, that ID is used if possible
    * (it is not reserved and not in use); otherwise an arbitrary ID is used.
    * Use the `add` method if `template` must use a particular ID.
    *
    * Takes no action and returns 0 if `template.size` is 0.
    *
    * @param template The Template to get or add
    * @param tid An optional ID for `template` if it does not exist in
    * the Session
    * @return The ID of `template` in this Session
    * @throws java.lang.RuntimeException when `tid` is negative or
    *     greater than MAX_TEMPLATE_ID, or when no ID is available
    */
  // NOTE(review): unlike `add`, this does not notify template callbacks;
  // confirm whether that is intended.
  def getOrAdd(template: Template, tid: Option[Int] = None): Int = synchronized
  {
    tmplGetId.get(template).getOrElse {
      if ( template.size == 0 ) {
        0
      } else {
        val id = tid match {
          case Some(t) if t < 0 || t > MAX_TEMPLATE_ID =>
            // checked before usedIds.get(), which rejects negative indexes
            throw new RuntimeException(f"Invalid template id ${t}%#06x")
          case Some(t) if !usedIds.get(t) =>
            // reserved IDs (below MIN_TEMPLATE_ID) always have their bit
            // set, so they never reach this case
            t
          case _ =>
            findUnusedId(MIN_TEMPLATE_ID, MAX_TEMPLATE_ID).getOrElse{
              throw new RuntimeException("Session Template table is full") }
        }
        tmplGetId.update(template, id)
        idGetTmpl.update(id, template)
        usedIds.set(id)
        // the library changed, so any persistent copy is stale
        persistent = None
        id
      }
    }
  }

  /** Adds a [[Template]] to this Session using `tid` as the templateId
    * and replacing the Template that previously used that ID, if any.
    * If a [[Template]] is replaced, it is returned.
    *
    * To add `template` only if it is not already present, use
    * `getOrAdd`.
    *
    * If `template.size` is 0, withdraws and returns the Template whose ID is
    * `tid`.  When `tid` is TEMPLATE_SET_ID or OPTIONS_TEMPLATE_SET_ID, all
    * templates or all options templates, respectively, are withdrawn and
    * `None` is returned.  Templates are not withdrawn if this is a
    * [[DatagramSession]].
    *
    * @throws Session.TemplateInterferenceException when a template is
    *     replaced that may not be replaced, by policy.
    *
    * @throws java.lang.RuntimeException when `tid` is not a valid
    *     template ID for `template`.
    */
  def add(template: Template, tid: Int): Option[Template] = {
    // A withdrawal whose ID is a template set ID withdraws a whole set of
    // templates (RFC 7011, Sec 8.1); every other ID must be in range.
    val withdrawsSet = template.isWithdrawalTemplate &&
      (tid == TEMPLATE_SET_ID || tid == OPTIONS_TEMPLATE_SET_ID)
    if ( !withdrawsSet && (tid < MIN_TEMPLATE_ID || tid > MAX_TEMPLATE_ID) ) {
      throw new RuntimeException(f"Invalid template id ${tid}%#06x")
    }
    if ( template.isWithdrawalTemplate ) {
      withdraw(template, tid)
    } else {
      // insert/replace the template
      val (same, old) = synchronized {
        val previous = idGetTmpl.put(tid, template)
        if ( template == previous.orNull ) {
          // New template is the same as the old
          (true, None)
        } else {
          // Forget the replaced template's reverse mapping, but only when
          // it still points at `tid` (tmplGetId holds one ID per Template)
          previous.foreach(p => tmplGetId.remove(p, tid))
          tmplGetId.update(template, tid)
          usedIds.set(tid)
          persistent = None
          (false, previous)
        }
      }
      if ( !same ) {
        old.foreach(t => notifyWithdrawal(t, tid))
        notifyNew(template, tid)
      }
      old
    }
  }


  /**
    * Updates the metadata for a particular template in this Session.  Ignores
    * `metadata` if its template ID is invalid.  Discards any persistent copy
    * of this Session so that later copies include the new metadata.
    */
  def addTemplateMetadata(metadata: TemplateMetadata): Unit = {
    val tid = metadata.templateId
    if ( tid >= MIN_TEMPLATE_ID && tid <= MAX_TEMPLATE_ID ) {
      synchronized {
        templateMetadata.update(tid, metadata)
        persistent = None
      }
    }
  }

  /**
    * Gets the metadata for the [[Template]] whose ID is `tid` as an
    * [[scala.Option Option]].
    */
  def getTemplateMetadata(tid: Int): Option[TemplateMetadata] = {
    templateMetadata.get(tid)
  }

  /**
    * Removes a template from this session.
    *
    * @param tid  the template ID of the template to remove
    * @return the template that was removed, or `None` if none
    */
  def remove(tid: Int): Option[Template] = {
    val result = synchronized {
      val old = idGetTmpl.remove(tid)
      old.foreach { t => {
        tmplGetId.remove(t)
        usedIds.clear(tid)
        persistent = None }
      }
      old
    }
    result.foreach { t => notifyWithdrawal(t, tid) }
    result
  }

  /** Removes a template from this session.
    *
    * @param template The template to remove
    * @return the template that was removed, or `None` if none
    */
  def remove(template: Template): Option[Template] = {
    val result = synchronized {
      val old = tmplGetId.remove(template)
      old.foreach { tid => {
        idGetTmpl.remove(tid)
        usedIds.clear(tid)
        persistent = None }
      }
      old
    }
    result.map { tid => {
      notifyWithdrawal(template, tid)
      template }
    }
  }


  /** The Template from the template library with this template ID.
    *
    * @throws NoTemplateException if the Session does not have a
    *     Template with that ID.
    *
    * @see [[getTemplate]]
    * @since 1.3.1
    */
  final def apply(tid: Int): Template = synchronized {
    idGetTmpl.get(tid).getOrElse {
      throw new NoTemplateException(tid)
    }
  }

  /** The ID used by `template` within this Session.
    *
    * @throws NoTemplateException if the Session does not know about that
    *     Template.
    *
    * @see [[getId]]
    *
    * @since 1.3.1
    */
  final def apply(template: Template): Int = synchronized {
    tmplGetId.get(template).getOrElse {
      throw new NoTemplateException(template)
    }
  }

  /**
    * Gets a Template from the template library based on the template
    * ID as an Option.
    *
    * @param tid  the template ID
    * @return The template associated with the template ID, or `None`
    */
  final def getTemplate(tid: Int): Option[Template] = synchronized {
    idGetTmpl.get(tid)
  }

  /**
    * Gets the ID for the Template within this session as an Option.
    */
  final def getId(template: Template): Option[Int] = synchronized {
    tmplGetId.get(template)
  }

  /**
    * Returns a read-only copy of this session that will not change.
    * The copy is cached and reused until this session's templates or
    * template metadata change.
    *
    * @return the persistent session
    */
  def getPersistentSession(): Session = synchronized {
    persistent.getOrElse({
      val s = new Session.ReadOnlySession(this)
      s.idGetTmpl ++= this.idGetTmpl
      s.tmplGetId ++= this.tmplGetId
      s.templateMetadata ++= this.templateMetadata
      persistent = Option(s)
      s
    })
  }


  /**
    * Returns an unused Template ID in the range `min` to `max`
    * inclusive, or `None` if every ID in that range is in use.  Callers
    * must hold this Session's lock.
    */
  private[this] def findUnusedId(min: Int, max: Int): Option[Int] = {
    // Prefer the ID just past the highest ID currently marked as used (but
    // not below `min`); when that value is too large, use the lowest
    // unused ID at or above `min`.
    val pastHighest = math.max(usedIds.length, min)
    val tid =
      if ( pastHighest <= max ) pastHighest else usedIds.nextClearBit(min)
    if ( tid > max ) None else Some(tid)
  }

  /**
    * Find a template ID not currently being used by this session.
    * Returns None when the Session has no available IDs.  (Not
    * necessarily deterministic.)
    *
    * @return a currently unused template ID
    */
  final def findUnusedId(): Option[Int] = synchronized {
    findUnusedId(MIN_TEMPLATE_ID, MAX_TEMPLATE_ID)
  }


  /**
    * Check for missing or out of sequence records.  If found, invoke
    * any SequenceCallback objects that have been registered and log a
    * message.  A log message is not written if no SequenceCallback
    * objects have been registered.
    *
    * After each message (including the first), `expectedSequence` is
    * the message's sequence number plus its data record count, modulo
    * 2^32.
    */
  /* synchronized */
  final def noteMessage(m: Message): Unit = {
    val received: Long = m.sequenceNumber
    // cache the value that was expected for this message (-1 when this is
    // the first message of the session)
    val expected = expectedSequence
    // compute the sequence number expected on the next message; this
    // must also account for the data records in the first message
    expectedSequence = (received + m.dataRecordCount) & 0xffffffffL
    // compare old expected value with received value
    if ( -1 != expected && sequenceCallbacks.nonEmpty && received != expected ) {
      logger.debug(s"[${transport}] Domain ${observationDomain}: " +
        "Out of sequence message." +
        f" Expected ${expected}%#06x, got ${received}%#06x")
      for ( cb: SequenceCallback <- sequenceCallbacks ) {
        cb.outOfSequence(this, expected, received)
      }
    }
  }

  /**
    * Returns an Iterator over the Templates in the Session.
    */
  /* synchronized */
  final def iterator: Iterator[Template] = {
    // put them into a separate list to allow for additions and
    // deletions while the iterator is active
    val list = List.empty[Template] ++ idGetTmpl.values
    list.iterator
  }

}


/**
  * Defines traits for classes that are used as callbacks on a
  * [[Session]].
  */
object Session {

  /**
    * A callback object that can be registered with a [[Session]] that
    * receives a notification whenever a message is encountered that
    * has an unexpected sequence number.  The sequence number of a
    * message should represent the number of records that have been
    * previously transmitted in this session, modulo 2<sup>32</sup>.
    *
    * An unexpected sequence number typically means that either one
    * or more messages have been dropped, or messages are arriving
    * out-of-order.
    */
  trait SequenceCallback {
    /**
      * Called when a message is encountered that has an unexpected
      * sequence number.
      *
      * @param session   the session for the message
      * @param expected  the expected sequence number
      * @param value     the encountered sequence number
      */
    def outOfSequence(session: Session, expected: Long, value: Long): Unit
  }

  /**
    * A callback object that can be registered with a [[Session]] or a
    * [[SessionGroup]] that receives a notification about each new or
    * removed [[Template]] in the session's template library.
    */
  trait TemplateCallback {
    /**
      * Called after a new template is successfully added to a
      * session.
      *
      * @param session   the session being added to
      * @param template  the template that was added
      * @param id        the ID of the template that was added
      */
    def newTemplate(session: Session, template: Template, id: Int): Unit

    /**
      * Called after a template is successfully removed from a session
      * for any reason.
      *
      * @param session   the session being removed from
      * @param template  the template that was removed
      * @param id        the ID of the template that was withdrawn
      */
    def withdrawnTemplate(session: Session, template: Template, id: Int): Unit
  }

  /**
    * An exception that gets thrown when a template tries to replace
    * an existing template.
    *
    * @param tid  the ID of the conflicting template
    */
  class TemplateInterferenceException(tid: Int) extends RuntimeException(
    f"Received template ${tid}%#06x, which does not match existing template")



  /**
    * A version of a Session that does not allow Templates to be added
    * or removed.  Every mutating method throws
    * `UnsupportedOperationException`; `getPersistentSession()` returns
    * this object itself.
    *
    * @param s  the session whose info model, transport, and observation
    *     domain this copy shares
    */
  private final class ReadOnlySession(s: Session)
      extends Session(s.infoModel, s.transport, s.observationDomain)
  {
    override def getPersistentSession(): Session = this

    override def add(template: Template, tid: Int): Option[Template] =
      throw new UnsupportedOperationException("Read only session")

    override def getOrAdd(template: Template, tid: Option[Int]): Int =
      throw new UnsupportedOperationException("Read only session")

    override def remove(tid: Int): Option[Template] =
      throw new UnsupportedOperationException("Read only session")

    override def remove(template: Template): Option[Template] =
      throw new UnsupportedOperationException("Read only session")

    override def addTemplateMetadata(metadata: TemplateMetadata): Unit =
      throw new UnsupportedOperationException("Read only session")

  }
}


/**
  * There are two types of [[Session]]s: StreamSession and
  * [[DatagramSession]].
  *
  * A StreamSession implements [[Template]] instantiation and Template
  * withdrawal handling as dictated by sections 8 and 8.1 of RFC 7011.
  * In particular, [[Session.TemplateInterferenceException]] is thrown
  * when different templates use the same template id.
  *
  * A StreamSession is created by a [[SessionGroup]] with the
  * `streamSemantics` parameter set to `true`.
  *
  * @param group  The session group that owns the session.
  * @param id     The observation domain ID.
  *
  * @throws Session.TemplateInterferenceException when different
  *     templates use the same template id.
  */
final class StreamSession(group: SessionGroup, id: Int)
    extends Session(group, id)
{
  /**
    * Adds or withdraws a template as [[Session.add]] does, then applies
    * the stream-session policy to the outcome:
    *
    *  - A withdrawal of a single template ID that was not present is
    *    logged as a warning.  A withdrawal of a whole set of templates
    *    (`tid` below MIN_TEMPLATE_ID) never reports a replaced template,
    *    so it is not logged.
    *  - A template equal to the one it replaces is logged as a
    *    re-transmission and `None` is returned.
    *  - A template that differs from the one it replaces is an error.
    *
    * @return the template that was replaced or withdrawn, if any
    * @throws Session.TemplateInterferenceException when `t` replaces a
    *     different template that has the same ID.
    */
  override def add(t: Template, tid: Int): Option[Template] = {
    val optionOld = super.add(t, tid)
    optionOld match {
      case None =>
        if ( t.isWithdrawalTemplate && tid >= MIN_TEMPLATE_ID ) {
          logger.warn(
            s"[${transport}] Domain $id: Received Template withdrawal" +
              f" for non-existent Template ${tid}%#06x")
        }
        optionOld
      case Some(_) if t.isWithdrawalTemplate =>
        // nothing else to do; the template has been removed
        optionOld
      case Some(old) if t.equals(old) =>
        // NOTE(review): Session.add returns None when it is given a
        // template equal to the existing one, so this case appears
        // unreachable; it is kept defensively.
        logger.warn(s"[${transport}] Domain $id:" +
          f" Received re-transmitted template ${tid}%#06x")
        None
      case Some(_) =>
        logger.error(
          f"[${transport}] Domain $id: Received template ${tid}%#06x" +
            " does not match existing template with same ID")
        throw new Session.TemplateInterferenceException(tid)
    }
  }
}


/**
  * There are two types of [[Session Sessions]]: [[StreamSession]] and
  * DatagramSession.
  *
  * A DatagramSession implements template instantiation and template
  * withdrawal handling as dictated by section 8.4 of RFC 7011.
  *
  * A DatagramSession is created by a [[SessionGroup]] with the
  * `streamSemantics` parameter set to `false`.
  *
  * @param group  the session group that owns the session
  * @param id     the observation domain ID
  */
final class DatagramSession(group: SessionGroup, id: Int)
    extends Session(group, id)
{
  /**
    * Ignores every template withdrawal, since template withdrawal is
    * disabled for UDP per RFC 7011, Sec 8.4.  Logs a warning and
    * returns `None`; the template library is left unchanged.
    */
  override protected def withdraw(t: Template, tid: Int): Option[Template] = {
    assert(t.isWithdrawalTemplate)
    logger.warn(s"[${transport}] Domain $id:" +
      f" Ignoring withdrawal template for ${tid}%#06x")
    None
  }
}

// @LICENSE_FOOTER@
//
// Copyright 2015-2022 Carnegie Mellon University. All Rights Reserved.
//
// This material is based upon work funded and supported by the
// Department of Defense and Department of Homeland Security under
// Contract No. FA8702-15-D-0002 with Carnegie Mellon University for the
// operation of the Software Engineering Institute, a federally funded
// research and development center sponsored by the United States
// Department of Defense. The U.S. Government has license rights in this
// software pursuant to DFARS 252.227.7014.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING
// INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON
// UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR
// IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF
// FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS
// OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT
// MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT,
// TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Released under a GNU GPL 2.0-style license, please see LICENSE.txt or
// contact permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public
// release and unlimited distribution. Please see Copyright notice for
// non-US Government use and distribution.
//
// Carnegie Mellon(R) and CERT(R) are registered in the U.S. Patent and
// Trademark Office by Carnegie Mellon University.
//
// This software includes and/or makes use of third party software each
// subject to its own license as detailed in LICENSE-thirdparty.tx
//
// DM20-1143
