package dotty.tools.dotc
package transform

import core._
import Decorators._, Flags._, Types._, Contexts._, Symbols._, Constants._
import Flags._
import ast.Trees._
import ast.TreeTypeMap
import util.Positions._
import StdNames._
import ast.untpd
import tasty.TreePickler.Hole
import MegaPhase.MiniPhase
import SymUtils._
import NameKinds.OuterSelectName
import typer.Implicits.SearchFailureType

import scala.collection.mutable
import dotty.tools.dotc.core.StdNames._
import dotty.tools.dotc.core.quoted._


/** Translates quoted terms and types to `unpickle` method calls.
 *  Checks that the phase consistency principle (PCP) holds.
 */
class ReifyQuotes extends MacroTransformWithImplicits {
  import ast.tpd._

  override def phaseName: String = "reifyQuotes"

  /** Run the phase only on compilation units that contain quotes or splices;
   *  all other units are left untouched.
   */
  override def run(implicit ctx: Context): Unit =
    if (ctx.compilationUnit.containsQuotesOrSplices) super.run

  /** The top-level reifier: not inside a quote, no outer reifier, staging level 0,
   *  and a fresh (empty) symbol-level table.
   */
  protected def newTransformer(implicit ctx: Context): Transformer =
    new Reifier(inQuote = false, null, 0, new LevelInfo)

  /** Mutable staging-level bookkeeping, shared by a reifier and every reifier nested in it. */
  private class LevelInfo {
    /** A map from locally defined symbols to the staging levels of their definitions */
    val levelOf = new mutable.HashMap[Symbol, Int]

    /** A stack of entered symbols, to be unwound after scope exit */
    var enteredSyms: List[Symbol] = Nil
  }

  /** Requiring that `paramRefs` consists of a single reference `seq` to a Seq[Any],
   *  a tree map that replaces each hole with index `n` with `seq(n)`, applied
   *  to any arguments in the hole.
   */
  private def replaceHoles(paramRefs: List[Tree]) = new TreeMap {
    val seq :: Nil = paramRefs
    override def transform(tree: Tree)(implicit ctx: Context): Tree = tree match {
      case Hole(n, args) =>
        val arg =
          seq.select(nme.apply).appliedTo(Literal(Constant(n))).ensureConforms(tree.tpe)
        if (args.isEmpty) arg
        else arg.select(nme.apply).appliedTo(SeqLiteral(args, TypeTree(defn.AnyType)))
      case _ =>
        super.transform(tree)
    }
  }

  /** If `tree` has holes, convert it to a function taking a `Seq` of elements as arguments
   *  where each hole is replaced by the corresponding sequence element.
   */
  private def elimHoles(tree: Tree)(implicit ctx: Context): Tree =
    if (tree.existsSubTree(_.isInstanceOf[Hole]))
      Lambda(
        MethodType(defn.SeqType.appliedTo(defn.AnyType) :: Nil, tree.tpe),
        replaceHoles(_).transform(tree))
    else tree

  /** The main transformer class
   *  @param  inQuote    we are within a `'(...)` context that is not shadowed by a nested `~(...)`
   *  @param  outer      the next outer reifier, null if this is the topmost transformer
   *  @param  level      the current level, where quotes add one and splices subtract one level
   *  @param  levels     a stacked map from symbols to the levels in which they were defined
   */
  private class Reifier(inQuote: Boolean, val outer: Reifier, val level: Int, levels: LevelInfo) extends ImplicitsTransformer {
    import levels._
    assert(level >= 0)

    /** A nested reifier for a quote (if `isQuote = true`) or a splice (if not) */
    def nested(isQuote: Boolean): Reifier =
      new Reifier(isQuote, this, if (isQuote) level + 1 else level - 1, levels)

    /** We are in a `~(...)` context that is not shadowed by a nested `'(...)` */
    def inSplice = outer != null && !inQuote

    /** A list of embedded quotes (if `inSplice = true`) or splices (if `inQuote = true`) */
    val embedded = new mutable.ListBuffer[Tree]

    /** A map from type ref T to expressions of type `quoted.Type[T]`.
     *  These will be turned into splices using `addTags`
     */
    val importedTags = new mutable.LinkedHashMap[TypeRef, Tree]()

    /** Assuming importedTags = `Type1 -> tag1, ..., TypeN -> tagN`, the expression
     *
     *      { type <Type1> = <tag1>.unary_~
     *        ...
     *        type <TypeN> = <tagN>.unary_~
     *        <expr>
     *      }
     *
     *  references to `TypeI` in `expr` are rewired to point to the locally
     *  defined versions. As a side effect, each `<tagI>.unary_~` is transformed
     *  as a splice, which registers `tagI` in `embedded`. `importedTags` is
     *  cleared afterwards.
     */
    def addTags(expr: Tree)(implicit ctx: Context): Tree =
      if (importedTags.isEmpty) expr
      else {
        val itags = importedTags.toList
        val typeDefs = for ((tref, tag) <- itags) yield {
          val rhs = transform(tag.select(tpnme.UNARY_~))
          val alias = ctx.typeAssigner.assignType(untpd.TypeBoundsTree(rhs, rhs), rhs, rhs)
          val original = tref.symbol.asType
          val local = original.copy(
            owner = ctx.owner,
            flags = Synthetic,
            info = TypeAlias(tag.tpe.select(tpnme.UNARY_~)))
          ctx.typeAssigner.assignType(untpd.TypeDef(original.name, alias), local)
        }
        importedTags.clear()
        Block(typeDefs,
          new TreeTypeMap(substFrom = itags.map(_._1.symbol), substTo = typeDefs.map(_.symbol))
            .apply(expr))
      }

    /** Enter the staging level of the symbol defined by `tree`, if applicable.
     *  Only class symbols and symbols not owned by a type are tracked, and a
     *  symbol that is already entered keeps its original level.
     */
    def markDef(tree: Tree)(implicit ctx: Context) = tree match {
      case tree: DefTree =>
        val sym = tree.symbol
        if ((sym.isClass || !sym.maybeOwner.isType) && !levelOf.contains(sym)) {
          levelOf(sym) = level
          enteredSyms = sym :: enteredSyms
        }
      case _ =>
    }

    /** Does the level of `sym` match the current level?
     *  An exception is made for inline vals in macros. These are also OK if their level
     *  is one higher than the current level, because on execution such values
     *  are constant expression trees and we can pull out the constant from the tree.
     *  Symbols with no recorded level are always OK.
     */
    def levelOK(sym: Symbol)(implicit ctx: Context): Boolean = levelOf.get(sym) match {
      case Some(l) =>
        l == level ||
        sym.is(Inline) && sym.owner.is(Macro) && sym.info.isValueType && l - 1 == level
      case None =>
        true
    }

    /** Issue a "splice outside quotes" error at `pos`.
     *  NOTE(review): an earlier comment said the body of an inline method is exempt,
     *  but no such check is made here — confirm whether callers handle that case.
     */
    def spliceOutsideQuotes(pos: Position)(implicit ctx: Context) =
      ctx.error(i"splice outside quotes", pos)

    /** Try to heal phase-inconsistent reference to type `T` using a local type definition.
     *  On success, the inferred `quoted.Type[T]` tag is recorded in `importedTags`.
     *  At level 0 nothing is recorded; this is only expected inside macro bodies (asserted).
     *  @return None      if successful
     *  @return Some(msg) if unsuccessful where `msg` is a potentially empty error message
     *                    to be added to the "inconsistent phase" message.
     */
    def tryHeal(tp: Type, pos: Position)(implicit ctx: Context): Option[String] = tp match {
      case tp: TypeRef =>
        if (level == 0) {
          assert(ctx.owner.is(Macro))
          None
        } else {
          val reqType = defn.QuotedTypeType.appliedTo(tp)
          val tag = ctx.typer.inferImplicitArg(reqType, pos)
          tag.tpe match {
            case fail: SearchFailureType =>
              Some(i"""
                      |
                      | The access would be accepted with the right type tag, but
                      | ${ctx.typer.missingArgMsg(tag, reqType, "")}""")
            case _ =>
              importedTags(tp) = nested(isQuote = false).transform(tag)
              None
          }
        }
      case _ =>
        Some("")
    }

    /** Check reference to `sym` for phase consistency, where `tp` is the underlying type
     *  by which we refer to `sym`. References to members of a type are checked via
     *  the owner's `this` type instead.
     */
    def check(sym: Symbol, tp: Type, pos: Position)(implicit ctx: Context): Unit = {
      val isThis = tp.isInstanceOf[ThisType]
      def symStr =
        if (!isThis) sym.show
        else if (sym.is(ModuleClass)) sym.sourceModule.show
        else i"${sym.name}.this"
      if (!isThis && sym.maybeOwner.isType)
        check(sym.owner, sym.owner.thisType, pos)
      else if (sym.exists && !sym.isStaticOwner && !levelOK(sym))
        for (errMsg <- tryHeal(tp, pos))
          ctx.error(em"""access to $symStr from wrong staging level:
                        | - the definition is at level ${levelOf(sym)},
                        | - but the access is at level $level.$errMsg""", pos)
    }

    /** Check all named types and this-types in a given type for phase consistency. */
    def checkType(pos: Position)(implicit ctx: Context): TypeAccumulator[Unit] = new TypeAccumulator[Unit] {
      def apply(acc: Unit, tp: Type): Unit = reporting.trace(i"check type level $tp at $level") {
        tp match {
          case tp: NamedType if tp.symbol.isSplice =>
            // Inside a quote, a spliced type is checked by the enclosing (outer) reifier.
            // Outside a quote, a spliced term is an error and the type is not traversed.
            if (inQuote) outer.checkType(pos).foldOver(acc, tp)
            else if (tp.isTerm) spliceOutsideQuotes(pos)
          case tp: NamedType =>
            check(tp.symbol, tp, pos)
            foldOver(acc, tp)
          case tp: ThisType =>
            check(tp.cls, tp, pos)
            foldOver(acc, tp)
          case _ =>
            foldOver(acc, tp)
        }
      }
    }

    /** If `tree` refers to a locally defined symbol (either directly, or in a pickled type),
     *  check that its staging level matches the current level. References to types
     *  that are phase-incorrect can still be healed as follows:
     *
     *  If `T` is a reference to a type at the wrong level, heal it by setting things up
     *  so that we later add a type definition
     *
     *     type T' = ~quoted.Type[T]
     *
     *  to the quoted text and rename T to T' in it. This is done later in `split` via
     *  `addTags`. `checkLevel` itself (through `tryHeal`) only records what needs to be
     *  done in the `importedTags` map of the current reifier.
     *
     *  @return `tree` unchanged
     */
    private def checkLevel(tree: Tree)(implicit ctx: Context): Tree = {
      tree match {
        case (_: Ident) | (_: This) =>
          check(tree.symbol, tree.tpe, tree.pos)
        case (_: UnApply)  | (_: TypeTree) =>
          checkType(tree.pos).apply((), tree.tpe)
        case Select(_, OuterSelectName(_, _)) =>
          checkType(tree.pos).apply((), tree.tpe.widen)
        case _: Bind =>
          checkType(tree.pos).apply((), tree.symbol.info)
        case _: Template =>
          checkType(tree.pos).apply((), tree.symbol.owner.asClass.givenSelfType)
        case _ =>
      }
      tree
    }

    /** Split `body` into a core and a list of embedded splices.
     *  Then if inside a splice, make a hole from these parts.
     *  If outside a splice, generate a call to `scala.quoted.Unpickler.unpickleType` or
     *  `scala.quoted.Unpickler.unpickleExpr` that matches `tpe` with
     *  core and splices as arguments.
     */
    private def quotation(body: Tree, quote: Tree)(implicit ctx: Context) = {
      val (body1, splices) = nested(isQuote = true).split(body)
      if (inSplice)
        makeHole(body1, splices, quote.tpe)
      else {
        // Build `x1 :: x2 :: ... :: Nil` at element type `tpe`
        def liftList(list: List[Tree], tpe: Type): Tree = {
          list.foldRight[Tree](ref(defn.NilModule)) { (x, acc) =>
            acc.select("::".toTermName).appliedToType(tpe).appliedTo(x)
          }
        }
        val isType = quote.tpe.isRef(defn.QuotedTypeClass)
        ref(if (isType) defn.Unpickler_unpickleType else defn.Unpickler_unpickleExpr)
          .appliedToType(if (isType) body1.tpe else body1.tpe.widen)
          .appliedTo(
            liftList(PickledQuotes.pickleQuote(body1).map(x => Literal(Constant(x))), defn.StringType),
            liftList(splices, defn.AnyType))
      }
    }.withPos(quote.pos)

    /** If inside a quote, split `body` into a core and a list of embedded quotes
     *  and make a hole from these parts. Otherwise report a "splice outside quotes"
     *  error and return `splice` unchanged.
     */
    private def splice(body: Tree, splice: Tree)(implicit ctx: Context): Tree = {
      if (inQuote) {
        val (body1, quotes) = nested(isQuote = false).split(body)
        makeHole(body1, quotes, splice.tpe)
      }
      else {
        spliceOutsideQuotes(splice.pos)
        splice
      }
    }.withPos(splice.pos)

    /** Transform `tree` and return the resulting tree and all `embedded` quotes
     *  or splices as a pair, after performing the `addTags` transform.
     *  Embedded trees that contain holes are turned into functions by `elimHoles`.
     */
    private def split(tree: Tree)(implicit ctx: Context): (Tree, List[Tree]) = {
      val tree1 = addTags(transform(tree))
      (tree1, embedded.toList.map(elimHoles))
    }

    /** Register `body` as an `embedded` quote or splice
     *  and return a hole with `splices` as arguments and the given type `tpe`.
     *  The hole's index is the position of `body` in `embedded`.
     */
    private def makeHole(body: Tree, splices: List[Tree], tpe: Type)(implicit ctx: Context): Hole = {
      val idx = embedded.length
      embedded += body
      Hole(idx, splices).withType(tpe).asInstanceOf[Hole]
    }

    override def transform(tree: Tree)(implicit ctx: Context): Tree =
      reporting.trace(i"reify $tree at $level", show = true) {
        // Transform children, then pop every symbol entered since `lastEntered`
        // from the level table, even if the transform throws.
        def mapOverTree(lastEntered: List[Symbol]) =
          try super.transform(tree)
          finally
            while (enteredSyms ne lastEntered) {
              levelOf -= enteredSyms.head
              enteredSyms = enteredSyms.tail
            }
        tree match {
          case Quoted(quotedTree) =>
            quotation(quotedTree, tree)
          case Select(body, _) if tree.symbol.isSplice =>
            splice(body, tree)
          case Block(stats, _) =>
            val last = enteredSyms
            stats.foreach(markDef)
            mapOverTree(last)

          case Inlined(call, bindings, InlineSplice(expansion @ Select(body, name))) =>
            // To maintain phase consistency, we move the binding of the this parameter into the spliced code
            val (splicedBindings, stagedBindings) = bindings.partition {
              case vdef: ValDef => vdef.symbol.is(Synthetic) // Assume that only _this bindings are tagged with Synthetic
              case _ => false
            }

            val tree1 =
              if (level == 0) cpy.Inlined(tree)(call, stagedBindings, Splicer.splice(seq(splicedBindings, body)))
              else seq(stagedBindings, cpy.Select(expansion)(cpy.Inlined(tree)(call, splicedBindings, body), name))
            val tree2 = transform(tree1)

            // Value discarding may have turned `{ e }` into `{ e; () }`; `InlineSplice` looked
            // through that wrapper, so restore the trailing `()` for Unit-typed inlined trees.
            if (tree.tpe =:= defn.UnitType) Block(tree2 :: Nil, Literal(Constant(())))
            else tree2
          case _: Import =>
            tree
          case tree: DefDef if tree.symbol.is(Macro) && level == 0 =>
            markDef(tree)
            // Check the macro code as if it appeared in a quoted context; the transformed
            // result is discarded and the macro's right-hand side is dropped.
            nested(isQuote = true).transform(tree)
            cpy.DefDef(tree)(rhs = EmptyTree)
          case _ =>
            markDef(tree)
            checkLevel(mapOverTree(enteredSyms))
        }
      }

    /** Extracts the splice selection (a `Select` whose symbol is a splice) from an
     *  inline expansion. The selection may be the whole tree, or be wrapped, possibly
     *  repeatedly, in blocks of the form `{ stat; () }` or `{ expr }`.
     */
    object InlineSplice {
      def unapply(tree: Tree)(implicit ctx: Context): Option[Select] = {
        tree match {
          case expansion: Select if expansion.symbol.isSplice =>
            Some(expansion)
          case Block(List(stat), Literal(Constant(()))) => unapply(stat)
          case Block(Nil, expr) => unapply(expr)
          case _ => None
        }
      }
    }
  }
}