// Copyright 2015-2022 by Carnegie Mellon University
// See license information in LICENSE.txt

package org.cert.netsa.io.silk

import scala.collection.mutable.ListBuffer
import java.io.{EOFException, InputStream}
import org.apache.hadoop.conf.Configuration

import io.{BufferReader, LzoInputStreamBuffer, RawInputStreamBuffer,
  SnappyInputStreamBuffer, ZlibInputStreamBuffer}
import io.BufferUtil.getInt32
import io.ipset.{
  IPSetV2Reader, IPSetV3IPv4Reader, IPSetV3IPv6Reader,
  IPSetV4IPv4Reader, IPSetV4IPv6Reader, IPSetV5Reader}

import org.cert.netsa.data.net.IPBlock


/**
  * A reader of binary SiLK IPset files. This is usable as an
  * [[scala.collection.Iterator Iterator]] over
  * [[org.cert.netsa.data.net.IPBlock IPBlock]] objects.
  *
  * @example This example uses the single argument form of the
  * [[IPSetReader$ companion object's]] `ofInputStream()` method to
  * read the IPset file "example.set"; the code may be used outside of
  * Hadoop.
  * {{{
  * val stream = new java.io.FileInputStream("example.set")
  * val ipset = IPSetReader.ofInputStream(stream)
  * ipset.hasNext
  * val ipblock = ipset.next()
  * println(ipblock.min + "/" + ipblock.prefixLength)
  * }}}
  *
  * @see [[IPSetReader$ the companion object]] for more details
  */
abstract class IPSetReader protected (
  protected[this] val reader: BufferReader,
  protected[this] val header: Header
) extends Iterator[IPBlock]
{
  /**
    * Whether the IPSetReader returns an [[scala.collection.Iterator
    * Iterator]] over [[org.cert.netsa.data.net.IPv6Block IPv6Blocks]]
    * or [[org.cert.netsa.data.net.IPv4Block IPv4Blocks]].
    *
    * @return `true` if the IPBlock objects are IPv6Block
    */
  def containsIPv6: Boolean

  /**
    * Tests whether this iterator can provide another IPBlock.
    *
    * @return `true` if a subsequent call to `next()` will yield an
    * element, `false` otherwise.
    */
  override def hasNext: Boolean

  /**
    * Produces the next IPBlock of this iterator.
    *
    * @return the next IPBlock of this iterator if `hasNext` is true
    *
    * @throws java.util.NoSuchElementException when the iterator is
    * depleted
    */
  override def next(): IPBlock

  /**
    * Octet-length of the 256-bit bitmap.
    */
  protected val bitmap256Length: Int = 32

  /**
    * For IPset file formats v4 and v5, a special value that indicates
    * that a 256-bit bitmap follows the IP address.
    */
  protected val cidrBitmapFollows: Int = 0x81

  /**
    * Whether to byte-swap the bytes as they are read from the buffer
    * (true when the file is not big-endian).
    */
  protected[this] val swap = !header.isBigEndian

  /**
    * Current buffer of bytes being processed. Only the first
    * `bufLength` octets are valid.
    */
  protected[this] var buffer: Array[Byte] = Array.empty[Byte]
  /**
    * Current position (octet offset) in the current buffer.
    */
  protected[this] var bufOffset = 0
  /**
    * Number of valid octets in the current buffer.
    */
  private[this] var bufLength = 0
  /**
    * Whether the end of the input stream has been reached.
    */
  private[this] var endOfFile = false

  /**
    * Decomposes the 256-bit bitmap starting at `offset` in `buffer`
    * into runs of set bits.
    *
    * The bitmap is read as eight 32-bit words (byte-swapped when
    * `swap` is true). Bit `b` (least-significant bit = 0) of word `w`
    * stands for offset `32 * w + b`.
    *
    * Each element of the result is a two-element list. The first value
    * is the offset of the first set bit of the run (0 to 255). The
    * second value is the run length, a power of 2 (1 to 256). Every run
    * starts on a multiple of its own length, so it maps to a single
    * CIDR block. Runs are returned in increasing offset order.
    */
  protected[this] def handleBitmap256(buffer: Array[Byte], offset: Int, swap: Boolean): List[List[Int]] = {
    val full = 0xffffffff
    val words = Vector.tabulate(8)(j => getInt32(buffer, offset + 4 * j, swap))

    if (words.forall(_ == full)) {
      // Every bit is set: a single 256-long run.
      List(List[Int](0, 256))
    } else {
      val buf = new ListBuffer[List[Int]]

      // Records the set bits of the aligned 2-bit pair whose low bit
      // is at `start`; `bits` holds the pair's value (0 to 3).
      def addPair(start: Int, bits: Int): Unit = bits match {
        case 1 => buf += List(start, 1)
        case 2 => buf += List(start + 1, 1)
        case 3 => buf += List(start, 2)
        case _ =>
      }

      var i = 0
      while (i < 8) {
        if (full == words(i)) {
          if ((0x1 == (i & 0x1)) || (full != words(i + 1))) {
            // Cannot grow: i is odd, or the next word is not full
            buf += List(32 * i, 32)
            i += 1
          } else if ((0x2 == (i & 0x2)) || (full != words(i + 2))) {
            // Cannot grow past 64: i is not a multiple of 4, or the
            // third word is not full
            buf += List(32 * i, 64)
            i += 2
          } else if (full != words(i + 3)) {
            // Three full words followed by a partial one: emit a
            // 64-run and a 32-run; the partial word is handled next
            buf += List(32 * i, 64)
            buf += List(32 * (i + 2), 32)
            i += 3
          } else {
            // Four full words starting at i (0 or 4). They cannot
            // merge into 256 because the all-full case was handled
            // above.
            buf += List(32 * i, 128)
            i += 4
          }
        } else {
          var pos = 0
          var y = words(i)
          while (y != 0) {
            val base = pos + 32 * i
            if ((0 == (pos & 0xf)) &&
              ((0 == (y & 0xffff)) || (0xffff == (y & 0xffff)))) {
              // 16-aligned group of 16 bits that are all set or all clear
              if (0x1 == (y & 0x1)) buf += List(base, 16)
              y = y >>> 16
              pos += 16
            } else if ((0 == (pos & 0x7)) &&
              ((0 == (y & 0xff)) || (0xff == (y & 0xff)))) {
              // 8-aligned group of 8 bits that are all set or all clear
              if (0x1 == (y & 0x1)) buf += List(base, 8)
              y = y >>> 8
              pos += 8
            } else {
              // Mixed nibble: either fully set, or split into two
              // aligned 2-bit pairs (low pair first to keep order)
              val q = y & 0xf
              if (0xf == q) {
                buf += List(base, 4)
              } else {
                addPair(base, q & 0x3)
                addPair(base + 2, q >>> 2)
              }
              y = y >>> 4
              pos += 4
            }
          }
          i += 1
        }
      }
      buf.toList
    }
  }

  /**
    * Returns log2 of a run length produced by handleBitmap256.
    *
    * @throws Exception if `c` is not a power of 2 between 1 and 256
    */
  private[this] def runLengthLog2(c: Int): Int = {
    if (c < 1 || c > 256 || 0 != (c & (c - 1))) {
      throw new Exception(s"Length ${c} is not a power of 2")
    }
    Integer.numberOfTrailingZeros(c)
  }

  /**
    * Takes the number of consecutive high bits returned by
    * handleBitmap256 and returns the CIDR prefix for IPv6 addresses
    * (1 gives 128, 256 gives 120).
    *
    * @throws Exception if `c` is not a power of 2 between 1 and 256
    */
  protected[this] def lenToCidrV6(c: Int) : Int = 128 - runLengthLog2(c)

  /**
    * Takes the number of consecutive high bits returned by
    * handleBitmap256 and returns the CIDR prefix for IPv4 addresses
    * (1 gives 32, 256 gives 24).
    *
    * @throws Exception if `c` is not a power of 2 between 1 and 256
    */
  protected[this] def lenToCidrV4(c : Int) : Int = 32 - runLengthLog2(c)


  /**
    * Returns true if `avail` octets are available in the buffer
    * starting at `bufOffset`, reading more data from the stream as
    * many times as needed.
    *
    * Any unread octets are kept at the front of the new buffer, and
    * `bufOffset` is reset to 0. On end-of-file the reader is closed,
    * the buffer is emptied, and `false` is returned from then on.
    */
  protected[this] def checkAvailable(avail: Int): Boolean = {
    while (!endOfFile && bufOffset + avail > bufLength) {
      // Copy the unread tail *before* fetching more data, and only the
      // valid octets (the array may be longer than bufLength).
      // NOTE(review): this also guards against a reader that reuses
      // its array between calls -- confirm BufferReader's contract.
      val leftover =
        if (bufOffset < bufLength) buffer.slice(bufOffset, bufLength)
        else Array.empty[Byte]
      try {
        val (newBuffer, newLength) = reader.getNextBuffer()
        if (leftover.isEmpty) {
          buffer = newBuffer
          bufLength = newLength
        } else {
          buffer = leftover ++ newBuffer
          bufLength = leftover.length + newLength
        }
        bufOffset = 0
      } catch {
        case _: EOFException =>
          reader.close()
          buffer = Array.empty
          bufLength = 0
          bufOffset = 0
          endOfFile = true
      }
    }
    !endOfFile
  }

  /**
    * Reads and throws away `skipLen` bytes from the stream. An
    * end-of-file error in this function is not caught.
    */
  protected[this] def skipBytes(skipLen: Int): Unit = {
    var toSkip = skipLen
    while (toSkip > 0) {
      val remaining = bufLength - bufOffset
      if (toSkip <= remaining) {
        bufOffset += toSkip
        toSkip = 0
      } else {
        // Consume the rest of this buffer, then continue skipping from
        // the start of the next one.
        toSkip -= remaining
        val (newBuffer, newLength) = reader.getNextBuffer()
        buffer = newBuffer
        bufLength = newLength
        bufOffset = 0
      }
    }
  }


}


/**
  * The IPSetReader object provides support for creating an
  * [[IPSetReader]].
  */
object IPSetReader {

  /**
    * Size in octets of the read buffer used for uncompressed input.
    */
  private[this] val BufferSize = 65536

  /**
    * Helper function for the `ofInputStream()` methods below. Reads
    * the file's header from `s` and checks that its values are valid
    * for an IPset file.
    *
    * @return the header, and whether the IPset holds IPv6 addresses
    *         (always `false` for record version 2)
    * @throws SilkDataFormatException if the file is not an IPset, or
    *         its record size, record version, or IPset header entry has
    *         an unexpected value
    */
  private[this] def checkHeader(s: InputStream): (Header, Boolean) = {
    val header = Header.readFrom(s)

    if (FileFormat.FT_IPSET != header.fileFormat) {
      throw new SilkDataFormatException(
        "File is not an IPset file")
    }
    if (1 != header.recordSize) {
      throw new SilkDataFormatException(
        "IPset file has unexpected record size: " + header.recordSize)
    }

    header.recordVersion match {
      // version 2: nothing else to check
      case 2 => (header, false)
      case 3 => (header, checkV3Entry(findIPSetEntry(header)))
      case 4 | 5 => (header, checkV4V5Entry(header, findIPSetEntry(header)))
      case _ => throw new SilkDataFormatException(
        "IPset file has unexpected record version: " + header.recordVersion)
    }
  }

  /**
    * Returns the IPset header entry of `header`.
    *
    * @throws SilkDataFormatException if there is none
    */
  private[this] def findIPSetEntry(header: Header): HeaderEntry.IPSet =
    header.headerEntries.collectFirst({
      case entry : HeaderEntry.IPSet => entry
    }).getOrElse(throw new SilkDataFormatException(
      "File is missing the IPset header entry"))

  /**
    * Validates the IPset header entry of a record-version-3 file and
    * returns whether it holds IPv6 addresses. The node and leaf sizes
    * must agree on the address family.
    */
  private[this] def checkV3Entry(hentry: HeaderEntry.IPSet): Boolean = {
    if (16 != hentry.childNode) {
      throw new SilkDataFormatException(
        "IPset file has unexpected child-per-node count: " + hentry.childNode)
    }
    val nodeIsV6 = hentry.nodeSize match {
      case 96 => true
      case 80 => false
      case _  => throw new SilkDataFormatException(
        "IPset file has unexpected node size: " + hentry.nodeSize)
    }
    val leafIsV6 = hentry.leafSize match {
      case 24 => true
      case  8 => false
      case _  => throw new SilkDataFormatException(
        "IPset file has unexpected leaf size: " + hentry.leafSize)
    }
    if (nodeIsV6 != leafIsV6) {
      throw new SilkDataFormatException(
        "IPset file has mismatched node and leaf sizes")
    }
    nodeIsV6
  }

  /**
    * Validates the IPset header entry of a record-version-4 or -5 file
    * and returns whether it holds IPv6 addresses. The address family is
    * determined by the leaf size. Version 5 accepts only 16-octet
    * leaves, and every other header-entry field must be 0.
    */
  private[this] def checkV4V5Entry(header: Header, hentry: HeaderEntry.IPSet): Boolean = {
    val isIPv6 = hentry.leafSize match {
      case 16 => true
      case  4 => false
      case  _ => throw new SilkDataFormatException(
        "IPset file has unexpected leaf size: " + hentry.leafSize)
    }
    if ( !isIPv6 && 5 == header.recordVersion ) {
      throw new SilkDataFormatException(
        "IPset file has unexpected leaf size: " + hentry.leafSize)
    }

    // all other header-entry fields are 0
    if (0 != hentry.childNode) {
      throw new SilkDataFormatException(
        "IPset file has unexpected child-per-node count: " + hentry.childNode)
    }
    if (0 != hentry.nodeCount) {
      throw new SilkDataFormatException(
        "IPset file has unexpected node count: " + hentry.nodeCount)
    }
    if (0 != hentry.nodeSize) {
      throw new SilkDataFormatException(
        "IPset file has unexpected node size: " + hentry.nodeSize)
    }
    if (0 != hentry.leafCount) {
      throw new SilkDataFormatException(
        "IPset file has unexpected leaf count: " + hentry.leafCount)
    }
    if (0 != hentry.rootIndex) {
      throw new SilkDataFormatException(
        "IPset file has unexpected root index: " + hentry.rootIndex)
    }
    isIPv6
  }


  /**
    * Helper function for the `ofInputStream()` methods below. Invokes
    * the appropriate constructor based on the file's record version
    * and address family.
    *
    * @throws SilkDataFormatException if the record version is not 2-5
    *         (normally already rejected by `checkHeader`)
    */
  private[this] def createReader(bufferReader: BufferReader, header: Header, isIPv6: Boolean): IPSetReader = {
    header.recordVersion match {
      case 2 => new IPSetV2Reader(bufferReader, header)
      case 3 =>
        if (isIPv6) new IPSetV3IPv6Reader(bufferReader, header)
        else new IPSetV3IPv4Reader(bufferReader, header)
      case 4 =>
        if (isIPv6) new IPSetV4IPv6Reader(bufferReader, header)
        else new IPSetV4IPv4Reader(bufferReader, header)
      case 5 => new IPSetV5Reader(bufferReader, header)
      case _ => throw new SilkDataFormatException(
        "IPset file has unexpected record version: " + header.recordVersion)
    }
  }


  /**
    * Creates and returns a reader from the provided input stream
    * without using Hadoop. Uncompressed, zlib-compressed, and
    * snappy-compressed streams are supported. LZO requires the
    * Hadoop-based `ofInputStream(conf, s)` overload.
    *
    * @throws SilkDataFormatException if the input stream is
    *     malformed, is not an IPset, or uses a compression method that
    *     is not supported without Hadoop.
    */
  def ofInputStream(s: InputStream): IPSetReader = {
    val (header, isIPv6) = checkHeader(s)
    val bufferReader = header.compressionMethod match {
      case CompressionMethod.NONE => RawInputStreamBuffer(s, BufferSize)
      case CompressionMethod.ZLIB => ZlibInputStreamBuffer(s)
      case CompressionMethod.SNAPPY => SnappyInputStreamBuffer(s)
      case _ => throw new SilkDataFormatException("Unsupported compression method (without Hadoop)")
    }

    createReader(bufferReader, header, isIPv6)
  }


  /**
    * Creates and returns a reader from the provided input stream,
    * using Hadoop compression codecs (zlib, LZO, snappy) to decode
    * compressed streams.
    *
    * @throws SilkDataFormatException if the input stream is malformed,
    * is not an IPset, or uses an unrecognized compression method.
    */
  def ofInputStream(conf: Configuration, s: InputStream): IPSetReader = {
    val (header, isIPv6) = checkHeader(s)
    val bufferReader = header.compressionMethod match {
      case CompressionMethod.NONE => RawInputStreamBuffer(s, BufferSize)
      case CompressionMethod.ZLIB => ZlibInputStreamBuffer(conf, s)
      case CompressionMethod.LZO1X => LzoInputStreamBuffer(conf, s)
      case CompressionMethod.SNAPPY => SnappyInputStreamBuffer(conf, s)
      case _ => throw new SilkDataFormatException("Unrecognized compression method")
    }

    createReader(bufferReader, header, isIPv6)
  }

}

// @LICENSE_FOOTER@
//
// Copyright 2015-2022 Carnegie Mellon University. All Rights Reserved.
//
// This material is based upon work funded and supported by the
// Department of Defense and Department of Homeland Security under
// Contract No. FA8702-15-D-0002 with Carnegie Mellon University for the
// operation of the Software Engineering Institute, a federally funded
// research and development center sponsored by the United States
// Department of Defense. The U.S. Government has license rights in this
// software pursuant to DFARS 252.227.7014.
//
// NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING
// INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON
// UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR
// IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF
// FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS
// OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT
// MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT,
// TRADEMARK, OR COPYRIGHT INFRINGEMENT.
//
// Released under a GNU GPL 2.0-style license, please see LICENSE.txt or
// contact permission@sei.cmu.edu for full terms.
//
// [DISTRIBUTION STATEMENT A] This material has been approved for public
// release and unlimited distribution. Please see Copyright notice for
// non-US Government use and distribution.
//
// Carnegie Mellon(R) and CERT(R) are registered in the U.S. Patent and
// Trademark Office by Carnegie Mellon University.
//
// This software includes and/or makes use of third party software each
// subject to its own license as detailed in LICENSE-thirdparty.tx
//
// DM20-1143
