// SPDX-License-Identifier: Apache-2.0

package chiseltest.simulator

import chiseltest.coverage.{Coverage, ModuleInstancesAnnotation, ModuleInstancesPass}
import firrtl2._
import firrtl2.annotations._
import firrtl2.options.Dependency
import firrtl2.passes.InlineInstances
import firrtl2.stage.{Forms, RunFirrtlTransformAnnotation}
import firrtl2.stage.TransformManager.TransformDependency
import firrtl2.transforms.EnsureNamedStatements

import scala.collection.mutable

/** Verilator generates a `coverage.dat` file with one entry for every cover statement. Unfortunately, the unique name
  * of each cover statement is lost. However, since the (System)Verilog emitter preserves the order of the cover
  * statements, we can sort the entries by line number and match them, in order, to the cover statements in LoFirrtl.
  */
private object VerilatorCoverage {

  // We run these passes in order to obtain enough meta data to be able to map the `coverage.dat`
  // generated by Verilator to the cover points in the firrtl design.
  val CoveragePasses = Seq(
    RunFirrtlTransformAnnotation(Dependency(ModuleInstancesPass)),
    RunFirrtlTransformAnnotation(Dependency(FindCoverPointsPass)),
    RunFirrtlTransformAnnotation(Dependency(EnsureNamedStatements)) // without names, no cover points
  )

  /** Collects the annotations needed to interpret Verilator's coverage output: the common coverage annotations plus
    * the [[OrderedCoverPointsAnnotation]]s produced by [[FindCoverPointsPass]].
    */
  def collectCoverageAnnotations(annos: AnnotationSeq): AnnotationSeq = {
    Coverage.collectCoverageAnnotations(annos) ++ annos.collect { case a: OrderedCoverPointsAnnotation => a }
  }

  /** Reads a Verilator `coverage.dat` file and maps every user cover entry to the name of its cover point.
    *
    * @param annos
    *   must contain exactly one [[ModuleInstancesAnnotation]] as well as the [[OrderedCoverPointsAnnotation]]s
    * @param coverageData
    *   path of the `coverage.dat` file
    * @param version
    *   Verilator version; passed through but currently not used when interpreting the data
    * @return
    *   pairs of (cover point name, hit count); the name is prefixed with the instance path unless that path is empty
    */
  def loadCoverage(annos: AnnotationSeq, coverageData: os.Path, version: (Int, Int)): List[(String, Long)] = {
    val entries = parseCoverageData(coverageData)
    verilatorCoverageToCoverageMap(entries, annos, version)
  }

  private def verilatorCoverageToCoverageMap(es: List[CoverageEntry], annos: AnnotationSeq, version: (Int, Int))
    : List[(String, Long)] = {
    // map from module name to an ordered list of cover points in said module
    val coverPoints = annos.collect { case a: OrderedCoverPointsAnnotation => a.target.module -> a.covers }.toMap
    def coversOf(module: String): List[String] = coverPoints.getOrElse(
      module,
      throw new RuntimeException(s"No OrderedCoverPointsAnnotation found for module: $module")
    )
    // map from instance path name to the name of the module
    val instToModule = annos.collect { case a: ModuleInstancesAnnotation => a }.toList match {
      case List(anno) => anno.instanceToModule.toMap
      case other      => throw new RuntimeException(s"Exactly one ModuleInstancesAnnotation is required! Found: $other")
    }

    // process the coverage entries on a per instance basis
    es.groupBy(_.path).toList.flatMap { case (name, entries) =>
      // we look up the cover points by first converting to the module name
      instToModule.get(name) match {
        case Some(module) =>
          processInstanceCoverage(name, coversOf(module), entries)
        case None if name.contains("*") =>
          // the path contains `*` wildcards: attribute the entries to every known instance path that matches it
          val pattern = wildcardPattern(name)
          instToModule.toList.flatMap {
            case (inst, module) if pattern.matcher(inst).matches() =>
              processInstanceCoverage(inst, coversOf(module), entries)
            case _ => Nil
          }
        case _ => throw new RuntimeException(s"Could not find module for instance: $name")
      }
    }
  }

  /** Compiles a path containing `*` wildcards into a regex in which every other character (in particular the `.`
    * hierarchy separator) is matched literally.
    */
  private def wildcardPattern(path: String): java.util.regex.Pattern = {
    val regex = path.split("\\*", -1).map(part => java.util.regex.Pattern.quote(part)).mkString(".*")
    java.util.regex.Pattern.compile(regex)
  }

  /** Pairs the ordered cover point names of one instance with that instance's (line-sorted) coverage entries.
    * Fails if the number of cover statements and coverage entries differ.
    */
  private def processInstanceCoverage(
    name:    String,
    covers:  List[String],
    entries: Seq[CoverageEntry]
  ): Seq[(String, Long)] = {
    assert(
      covers.size == entries.size,
      f"[$name] Missing or too many entries! ${covers.size} cover statements vs. ${entries.size} coverage entries.\n" +
        covers.mkString(", ") + "\n" + entries.mkString(", ")
    )
    covers.zip(entries).map { case (c, e) =>
      (if (name.isEmpty) c else name + "." + c) -> e.count
    }
  }

  /** Parses all user coverage entries of a `coverage.dat` file, sorted by their line number so that the entries of
    * each module appear in the same order as the cover statements were emitted.
    */
  private def parseCoverageData(coverageData: os.Path): List[CoverageEntry] = {
    assert(os.exists(coverageData), f"Could not find coverage file: $coverageData")
    val src = os.read.lines(coverageData)
    val entries = src.flatMap(parseLine).toList
    entries.sortBy(_.line)
  }

  // example lines:
  // "C '\x01f\x02Test1Module.sv\x01l\x0240\x01n\x020\x01page\x02v_user/Test1Module\x01o\x02cover\x01h\x02TOP.Test1Module' 3"
  // "C '\x01f\x02Test1Module.sv\x01l\x028\x01n\x020\x01page\x02v_user/SubModule1\x01o\x02cover\x01h\x02TOP.Test1Module.c0' 0"
  // "C '\x01f\x02Test1Module.sv\x01l\x028\x01n\x020\x01page\x02v_user/SubModule1\x01o\x02cover\x01h\x02TOP.Test1Module.c1' 0"
  // output:
  // - CoverageEntry(file = "Test1Module.sv", line = 40, path = "", count = 3)
  // - CoverageEntry(file = "Test1Module.sv", line = 8, path = "c0", count = 0)
  // - CoverageEntry(file = "Test1Module.sv", line = 8, path = "c1", count = 0)
  // Lines that are not coverage entries, or that do not belong to a `v_user` page, result in `None`.
  private def parseLine(line: String): Option[CoverageEntry] =
    if (!line.startsWith("C '\u0001")) None
    else
      line.split('\'').toList match {
        case List(_, dict, countStr) =>
          // key/value pairs are separated by the control character 0x01, keys from values by 0x02
          val entries =
            dict.drop(1).split('\u0001').map(_.split('\u0002').toList).collect { case Seq(k, v) => k -> v }.toMap
          // filter out non-user coverage before interpreting any of the other fields
          val kind = entries("page").split("/").head
          if (kind != "v_user") None
          else {
            // drop the `TOP.<toplevel>` prefix; the toplevel module itself ends up with an empty path
            val path = entries("h").split('.').toList.drop(2).mkString(".")
            val count = countStr.trim.toLong
            Some(CoverageEntry(file = entries("f"), line = entries("l").toInt, path = path, count = count))
          }
        case _ =>
          throw new RuntimeException(s"Unexpected coverage line format: $line")
      }

  /** One user coverage entry of a `coverage.dat` file; `path` is the instance path relative to the toplevel. */
  private case class CoverageEntry(file: String, line: Int, path: String, count: Long)
}

/** Generates a list of cover points in each module. This helps us map coverage points as reported by Verilator to the
  * standard coverage map required by the simulator backend interface.
  */
private object FindCoverPointsPass extends Transform {
  override def prerequisites: Seq[TransformDependency] = Forms.LowForm
  // we need to run *after* any transform that changes the hierarchy or renames cover points
  override def optionalPrerequisites: Seq[TransformDependency] =
    Seq(Dependency[InlineInstances], Dependency(EnsureNamedStatements))
  // we need to run before the emitter
  override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq(
    Dependency[LowFirrtlEmitter],
    Dependency[VerilogEmitter],
    Dependency[SystemVerilogEmitter]
  )
  // this pass only adds annotations and returns the circuit unchanged
  override def invalidates(a: Transform): Boolean = false

  override protected def execute(state: CircuitState): CircuitState = {
    val c = CircuitTarget(state.circuit.main)
    val annos = state.circuit.modules.flatMap(onModule(c, _))
    state.copy(annotations = state.annotations ++ annos)
  }

  /** Records the names of all cover statements of a module in the order in which they are visited.
    * External modules have no body and thus produce no annotation.
    */
  private def onModule(c: CircuitTarget, m: ir.DefModule): Option[OrderedCoverPointsAnnotation] = m match {
    case _:   ir.ExtModule => None
    case mod: ir.Module =>
      val covs = mutable.ListBuffer[String]()
      mod.foreachStmt(onStmt(mod.name, _, covs))
      Some(OrderedCoverPointsAnnotation(c.module(mod.name), covs.toList))
  }

  /** Recursively appends the name of every cover statement found in `s` to `covs`. */
  private def onStmt(module: String, s: ir.Statement, covs: mutable.ListBuffer[String]): Unit = s match {
    case v: ir.Verification if v.op == ir.Formal.Cover =>
      assert(
        v.name.nonEmpty,
        s"[$module] Found a cover statement without a name${v.info}. Cover statements need to be named " +
          "(e.g. by EnsureNamedStatements) in order to be mapped to Verilator coverage entries."
      )
      covs.append(v.name)
    case other => other.foreachStmt(onStmt(module, _, covs))
  }
}

/** Names of the cover statements of one module, in the order in which [[FindCoverPointsPass]] encountered them. */
private case class OrderedCoverPointsAnnotation(target: ModuleTarget, covers: List[String])
    extends SingleTargetAnnotation[ModuleTarget] {

  /** Retargets this annotation (e.g. after a rename) while keeping the ordered cover point list. */
  override def duplicate(n: ModuleTarget): OrderedCoverPointsAnnotation = this.copy(target = n)
}
