// Copyright 2015-2022 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.ipfix

import scala.collection.mutable.{Set => ScalaSet}

import java.nio.file.StandardOpenOption
import java.nio.file.Path
import java.nio.file.Files
import java.nio.ByteBuffer
import java.nio.channels.{FileChannel,WritableByteChannel}


/**
  * ExportStreamTemplatesFirst extends the ExportStream class to
  * ensure that [[Template Templates]] occur in the output stream
  * before [[Record Records]].
  *
  * Everything handed to this stream is first written to a temporary
  * file.  Instances of this class only produce output on the final
  * channel when the close() method is called; at that point the
  * temporary file is re-read, reordered, copied to the final channel,
  * and then removed.
  *
  * Note: When template metadata or information element metadata is
  * being written to the stream, those records may appear before
  * some templates.
  *
  * @param finalOutput The channel that receives the reordered data
  * @param session The session shared with the parent ExportStream
  * @param tempFile The path of the temporary file backing `tempChannel`
  * @param tempChannel The channel the parent ExportStream writes to
  */
class ExportStreamTemplatesFirst private (
  finalOutput: WritableByteChannel,
  session: Session,
  tempFile: Path,
  tempChannel: FileChannel)
    extends ExportStream(tempChannel, session)
{
  /** Keep a set of template IDs that contain meta data */
  private[this] val tidIsMeta = ScalaSet.empty[Int]

  /** Becomes true on the first call to close().  The temporary file is
    * removed by that call, so a later call must not try to replay
    * it (and must not write the output a second time). */
  private[this] var finalOutputWritten = false

  // ensure outputStream points at the final location
  override def outputStream: WritableByteChannel = finalOutput

  /** Adds `template` to the stream via the parent class and, when it is
    * a metadata template of a kind for which the corresponding
    * `describeElements` / `describeTemplates` flag is off, remembers
    * its ID so that records using it are written during the
    * templates pass of close().
    *
    * @param template The template to add
    * @return The value returned by the parent class's `add`
    */
  override def add(template: Template): ExportStream = {
    val result = super.add(template)
    val id = session.getOrAdd(template)
    if ( template.isMetadataTemplate
      && (( !describeElements && template.isInfoElementMetadataTemplate )
        || ( !describeTemplates && template.isTemplateMetadataTemplate )))
    {
      tidIsMeta += id
    }
    result
  }

  /** Closes the parent stream (and with it the temporary file), then
    * re-opens the temporary file for reading and copies its contents
    * to the final location with templates and metadata records
    * first.  The temporary file is removed afterwards, whether or
    * not the copy succeeded.  Calls after the first one do nothing.
    */
  override def close(): Unit = {
    if ( !finalOutputWritten ) {
      finalOutputWritten = true
      try {
        super.close()

        // the description templates are metadata as well
        templateDescriptionTID.foreach { tid => tidIsMeta += tid }
        elementDescriptionTID.foreach { tid => tidIsMeta += tid }

        val reorder = new ExportStreamTemplatesFirst.Reorder(
          session, makeTimestamp, tempFile, finalOutput, tidIsMeta)
        reorder.write()
      } finally {
        discardTempFile()
      }
    }
  }

  /** Removes the temporary file.  If it cannot be removed right now
    * (for example because a handle to it is still open on a platform
    * that forbids deleting open files), schedule removal for JVM
    * exit instead of failing close(). */
  private[this] def discardTempFile(): Unit = {
    try {
      Files.deleteIfExists(tempFile)
      ()
    } catch {
      case _: java.io.IOException =>
        tempFile.toFile().deleteOnExit()
    }
  }

}

/**
  * An [[ExportStreamTemplatesFirst]] factory.
  */
object ExportStreamTemplatesFirst {
  /**
    * An ExportStreamTemplatesFirst factory.
    *
    * Creates a temporary file (prefix "ExportStream", default
    * suffix) in the default temporary-file directory, opens it for
    * writing, and returns a stream that buffers its data there.  If
    * the stream cannot be created, the temporary file is removed
    * again before the failure is propagated.
    *
    * @param outputStream The channel that receives the final output
    * @param session The session used by the new stream
    * @return A new ExportStreamTemplatesFirst
    */
  def apply(outputStream: WritableByteChannel, session: Session):
      ExportStreamTemplatesFirst =
  {
    val tempFile = Files.createTempFile("ExportStream", null)
    try {
      val tempChannel = FileChannel.open(tempFile, StandardOpenOption.WRITE)
      new ExportStreamTemplatesFirst(outputStream, session, tempFile, tempChannel)
    } catch {
      case scala.util.control.NonFatal(e) =>
        // do not leave an orphaned temporary file behind; keep the
        // original failure as the one that is reported
        try {
          Files.deleteIfExists(tempFile)
          ()
        } catch {
          case scala.util.control.NonFatal(cleanup) => e.addSuppressed(cleanup)
        }
        throw e
    }
  }


  /**
    * This helper class processes the data in the temporary output
    * file to generate the final output file.
    *
    * The class scans the temporary file twice, once to write
    * templates and meta-data records to the final output file, and a
    * second time to write ordinary records to the final output file.
    *
    * An instance is single-use: `write()` closes the handle to the
    * temporary file when it finishes.
    *
    * @param reorderSession The session object
    * @param makeTimestamp Supplies the export time stamped into the
    *   header of each message written to `out`
    * @param tempFile The temporary output file
    * @param out The final output file
    * @param isMeta A Set of the meta-data Template IDs
    */
  private class Reorder(
    reorderSession: Session,
    makeTimestamp: ExportStream.MessageTimestamp,
    tempFile: Path,
    out: WritableByteChannel,
    isMeta: ScalaSet[Int])
  {
    /** The input buffer reading from `tempFile`.  A message length is a
      * 16-bit value, so this buffer is always larger than any single
      * message and can never be completely filled by an incomplete
      * one. */
    val inbuf = ByteBuffer.allocate(65536)
    /** The output buffer writing to `out` */
    val outbuf = ByteBuffer.allocate(65535)
    outbuf.position(Message.headerLength)
    /** A handle to the input file */
    val in = FileChannel.open(tempFile)

    /** To compute the sequence numbers in the output file, the records in
      * the temporary file must be decoded, and doing so requires an
      * instance of a [[Message]]. */
    private[this] class FakeMsg() extends Message {
      val exportTime = java.time.Instant.now()
      val exportEpochSeconds = exportTime.getEpochSecond()
      val exportEpochMilliseconds = exportTime.toEpochMilli()
      val observationID = reorderSession.observationDomain
      val session = reorderSession
      val sequenceNumber = 0L
      val dataRecordCount = 0
    }
    private[this] val fakemsg = new FakeMsg()

    /** The number of data records written to the stream; used for setting
      * the sequence number */
    private[this] var prevCount = 0

    /** The sum of the number of data records written to the stream and
      * those sitting in outbuf */
    private[this] var recCount = 0

    /** Whether this pass is writing templates (false) or records (true) */
    private[this] var writingRecords = false

    /** Fill in the [[Message]] header for whatever `outbuf` currently
      * holds (possibly nothing but the header itself), write the
      * whole message to `out`, then reset the buffer and position it
      * just after the Message header. */
    private[this] def emitMessage(): Unit = {
      val timestamp = makeTimestamp.stamp()
      outbuf.flip()
      // header layout: version, length, export time, sequence
      // number, observation domain
      outbuf.putShort(0, IPFIX_VERSION.toShort)
        .putShort(2, outbuf.limit().toShort)
        .putInt(4, timestamp.toInt)
        .putInt(8, prevCount)
        .putInt(12, reorderSession.observationDomain)
      prevCount = recCount
      // a channel is allowed to accept only part of the buffer per
      // call, so keep writing until the message is completely out
      while (outbuf.hasRemaining()) {
        out.write(outbuf)
        ()
      }
      outbuf.clear()
        .position(Message.headerLength)
      ()
    }

    /** Write the output buffer `outbuf` to the output stream `out` as one
      * [[Message]] if it holds any Sets; do nothing when it holds only
      * the space reserved for the header. */
    private[this] def writeBuffer(): Unit = {
      if (outbuf.position() > Message.headerLength) {
        emitMessage()
      }
    }

    /** Copy the contents of `buf`, representing an [[IpfixSet]], to the output
      * buffer `outbuf`, where `count` is the number of data records
      * in the buffer or 0 for a template set.  May write the output
      * buffer if there is not enough room for `buf`.  */
    private[this] def writeSet(buf: ByteBuffer, count: Int): Unit = {
      if (outbuf.remaining() < buf.remaining()) {
        writeBuffer()
      }
      outbuf.put(buf)
      recCount += count
    }

    /** Write an empty [[Message]] to the output stream `out`.  Any Sets
      * already waiting in `outbuf` are written first as their own
      * message so that they are not lost when the buffer is reused
      * for the empty message. */
    private[this] def writeEmptyMessage(): Unit = {
      writeBuffer()
      // outbuf is now positioned just after the header, so the
      // message that goes out consists of the header alone
      emitMessage()
    }

    /** Read the [[IpfixSet]] having ID `id` from `buf` and call writeSet() to
      * copy it the output buffer if appropriate.  Expects `buf` to be
      * a slice containing just the Set and the complete Set. */
    private[this] def readSet(buf: ByteBuffer, id: Int): Unit = {
      if ( id < MIN_TEMPLATE_ID ) {
        // a template set
        if ( !writingRecords ) {
          writeSet(buf, 0)
        }
      } else {
        // a record set; write it when not writing records and it is
        // meta data, or when writing records and it is not meta data
        if ( isMeta.contains(id) ^ writingRecords ) {
          // need to decode it to get number of records in order to
          // calculate the sequence number
          val set = IpfixSet.fromBuffer(buf, fakemsg)
          buf.rewind()
          writeSet(buf, set.size)
        }
      }
    }

    /** Read the [[Message]] from `buf` and process the [[IpfixSet IpfixSets]]
      * it contains.  Expects `buf` to be a slice containing just the
      * Message and the complete Message.
      *
      * A message consisting of a header only is copied to the output
      * as an empty message; since this method runs once per pass,
      * that happens in both passes.
      *
      * @throws java.lang.RuntimeException if the message is malformed
      */
    private[this] def readMessage(buf: ByteBuffer): Unit = {
      assert(buf.remaining() >= Message.headerLength)
      if (buf.remaining() < Message.headerLength + IpfixSet.headerLength) {
        if (buf.remaining() == Message.headerLength) {
          // write empty message
          writeEmptyMessage()
        } else {
          throw new RuntimeException(
            s"Message of ${buf.remaining()} octets is too short to hold a Set header")
        }
      } else {
        buf.position(Message.headerLength)
        // the check above guarantees at least one Set header here
        while (buf.remaining() >= IpfixSet.headerLength) {
          val pos = buf.position()
          val id: Int = 0xffff & buf.getShort(pos)
          val len: Int = 0xffff & buf.getShort(pos + 2)
          if (len < IpfixSet.headerLength) {
            throw new RuntimeException(
              s"Set ${id} has invalid length ${len} (smaller than a Set header)")
          } else if (buf.remaining() < len) {
            throw new RuntimeException(
              s"Set ${id} claims length ${len} but only ${buf.remaining()} octets remain in the Message")
          } else {
            val b = buf.slice()
            b.limit(len)
            readSet(b, id)
            buf.position(pos + len)
          }
        }
        if (buf.remaining() != 0) {
          throw new RuntimeException(
            s"${buf.remaining()} stray octets follow the final Set in the Message")
        }
      }
    }

    /** Process the data from the temporary file `in` and copy the
      * appropriate parts to the output file `out`.  The temporary
      * file position should be at the start of the file.
      *
      * @throws java.lang.RuntimeException if a message has an invalid
      *   length or the file ends in the middle of a message
      */
    private[this] def readInput(): Unit = {
      inbuf.clear()
      var eof = false
      while (!eof) {
        // get a block of data from the file
        eof = in.read(inbuf) < 0
        inbuf.flip()
        // process all complete messages in this block
        var more = inbuf.remaining() >= Message.headerLength
        while (more) {
          // at start of a message
          assert(IPFIX_VERSION == (0xffff & inbuf.getShort(inbuf.position())))
          val pos = inbuf.position()
          val len: Int = 0xffff & inbuf.getShort(pos + 2)
          if (len < Message.headerLength) {
            throw new RuntimeException(
              s"Message has invalid length ${len} (smaller than a Message header)")
          } else if (inbuf.remaining() < len) {
            // only part of the message is buffered; read more
            more = false
          } else {
            val b = inbuf.slice()
            b.limit(len)
            readMessage(b)
            inbuf.position(pos + len)
            more = (inbuf.remaining() >= Message.headerLength)
          }
        }
        if (eof && inbuf.hasRemaining()) {
          // nothing more will arrive to complete what is left over
          throw new RuntimeException(
            s"Temporary file ${tempFile} ends with an incomplete Message (${inbuf.remaining()} stray octets)")
        }
        // keep any partial message at the front of the buffer
        inbuf.compact()
      }
    }

    /** Process the temporary file twice and write the data to the output
      * file, then close the handle to the temporary file (also when
      * processing fails). */
    def write(): Unit = {
      try {
        // process the templates
        readInput()
        // process the records
        writingRecords = true
        in.position(0)
        readInput()
        // flush any pending data
        writeBuffer()
      } finally {
        in.close()
      }
    }
  }

}

// @LICENSE_FOOTER@
//
// Copyright 2015-2022 Carnegie Mellon University. All Rights Reserved.
//
// This material is based upon work funded and supported by the
// Department of Defense and Department of Homeland Security under
// Contract No. FA8702-15-D-0002 with Carnegie Mellon University for the
// operation of the Software Engineering Institute, a federally funded
// research and development center sponsored by the United States
// Department of Defense. The U.S. Government has license rights in this
// software pursuant to DFARS 252.227.7014.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING
// INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON
// UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR
// IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF
// FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS
// OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT
// MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT,
// TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Released under a GNU GPL 2.0-style license, please see LICENSE.txt or
// contact permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public
// release and unlimited distribution. Please see Copyright notice for
// non-US Government use and distribution.
//
// Carnegie Mellon(R) and CERT(R) are registered in the U.S. Patent and
// Trademark Office by Carnegie Mellon University.
//
// This software includes and/or makes use of third party software each
// subject to its own license as detailed in LICENSE-thirdparty.tx
//
// DM20-1143
