package scala.quoted.util

import scala.quoted._

/** Maps quoted expressions, typically by rewriting selected sub-expressions.
 *
 *  Implementors provide `transform`. `transformChildren` is a helper that rebuilds an
 *  expression after handing each of its sub-expressions to `transform`, together with
 *  the type expected at that position.
 */
trait ExprMap {

  /** Map an expression `e` with a type `tpe` */
  def transform[T](e: Expr[T])(using qctx: QuoteContext, tpe: Type[T]): Expr[T]

  /** Map the subexpressions of an expression `e` with a type `tpe`.
   *
   *  Each child tree that is a standalone expression (`isExpr`) is passed to `transform`
   *  with the type expected at its position; other trees are traversed structurally.
   *  Closures, type trees, patterns, imports and `return` expressions are not traversed.
   */
  def transformChildren[T](e: Expr[T])(using qctx: QuoteContext, tpe: Type[T]): Expr[T] = {
    import qctx.tasty._
    final class MapChildren() {

      /** Transforms a block/template statement; imports are kept as-is. */
      def transformStatement(tree: Statement)(using ctx: Context): Statement = {
        tree match {
          case tree: Term =>
            // A statement's value is discarded, so any result type is acceptable.
            transformTerm(tree, defn.AnyType)
          case tree: Definition =>
            transformDefinition(tree)
          case tree: Import =>
            tree
        }
      }

      /** Transforms the right-hand sides of `val`/`def` definitions and the bodies of classes. */
      def transformDefinition(tree: Definition)(using ctx: Context): Definition = {
        def localCtx(definition: Definition): Context = definition.symbol.localContext
        tree match {
          case tree: ValDef =>
            // The right-hand side is transformed inside the definition's local context,
            // against the declared type of the val.
            given Context = localCtx(tree)
            val rhs1 = tree.rhs.map(x => transformTerm(x, tree.tpt.tpe))
            ValDef.copy(tree)(tree.name, tree.tpt, rhs1)
          case tree: DefDef =>
            given Context = localCtx(tree)
            DefDef.copy(tree)(tree.name, tree.typeParams, tree.paramss, tree.returnTpt, tree.rhs.map(x => transformTerm(x, tree.returnTpt.tpe)))
          case tree: TypeDef =>
            tree
          case tree: ClassDef =>
            val newBody = transformStats(tree.body)
            ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, newBody)
        }
      }

      /** Rebuilds `tree` (expected to have type `tpe`) with its children transformed. */
      def transformTermChildren(tree: Term, tpe: Type)(using ctx: Context): Term = tree match {
        case Ident(name) =>
          tree
        case Select(qualifier, name) =>
          Select.copy(tree)(transformTerm(qualifier, qualifier.tpe), name)
        case This(qual) =>
          tree
        case Super(qual, mix) =>
          tree
        case tree @ Apply(fun, args) =>
          // Each argument is transformed against its corresponding parameter type.
          val MethodType(_, tpes, _) = fun.tpe.widen
          Apply.copy(tree)(transformTerm(fun, defn.AnyType), transformTerms(args, tpes))
        case TypeApply(fun, args) =>
          TypeApply.copy(tree)(transformTerm(fun, defn.AnyType), args)
        case _: Literal =>
          tree
        case New(tpt) =>
          New.copy(tree)(transformTypeTree(tpt))
        case Typed(expr, tpt) =>
          val tp = tpt.tpe match
            // TODO improve code
            // A `T*` ascription: the underlying expression is transformed against `Seq[T]`.
            case AppliedType(TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "<repeated>"), List(tp0: Type)) =>
              // TODO rewrite without using quotes
              type T
              val qtp: quoted.Type[T] = tp0.seal.asInstanceOf[quoted.Type[T]]
              given qtp.type = qtp
              '[Seq[T]].unseal.tpe
            case tp => tp
          Typed.copy(tree)(transformTerm(expr, tp), transformTypeTree(tpt))
        case tree: NamedArg =>
          NamedArg.copy(tree)(tree.name, transformTerm(tree.value, tpe))
        case Assign(lhs, rhs) =>
          Assign.copy(tree)(lhs, transformTerm(rhs, lhs.tpe.widen))
        case Block(stats, expr) =>
          Block.copy(tree)(transformStats(stats), transformTerm(expr, tpe))
        case If(cond, thenp, elsep) =>
          If.copy(tree)(
            transformTerm(cond, defn.BooleanType),
            transformTerm(thenp, tpe),
            transformTerm(elsep, tpe))
        case _: Closure =>
          tree
        case Match(selector, cases) =>
          Match.copy(tree)(transformTerm(selector, selector.tpe), transformCaseDefs(cases, tpe))
        case Return(expr) =>
          // FIXME
          // ctx.owner seems to be set to the wrong symbol
          // Return.copy(tree)(transformTerm(expr, expr.tpe))
          tree
        case While(cond, body) =>
          While.copy(tree)(transformTerm(cond, defn.BooleanType), transformTerm(body, defn.AnyType))
        case Try(block, cases, finalizer) =>
          // Catch handlers produce the value of the whole `try`, so they are transformed
          // against `tpe` (as `Match` cases are); the finalizer's value is discarded.
          Try.copy(tree)(transformTerm(block, tpe), transformCaseDefs(cases, tpe), finalizer.map(x => transformTerm(x, defn.AnyType)))
        case Repeated(elems, elemtpt) =>
          Repeated.copy(tree)(transformTerms(elems, elemtpt.tpe), elemtpt)
        case Inlined(call, bindings, expansion) =>
          // NOTE(review): an earlier variant considered using call.symbol.localContext for
          // the expansion; the current context is used here.
          Inlined.copy(tree)(call, transformDefinitions(bindings), transformTerm(expansion, tpe))
      }

      /** Hands `tree` to `transform` when it is a standalone expression, otherwise recurses. */
      def transformTerm(tree: Term, tpe: Type)(using ctx: Context): Term =
        tree match
          case _: Closure =>
            tree
          case _: Inlined =>
            transformTermChildren(tree, tpe)
          case _ if tree.isExpr =>
            type X
            val expr = tree.seal.asInstanceOf[Expr[X]]
            val t = tpe.seal.asInstanceOf[quoted.Type[X]]
            transform(expr)(using qctx, t).unseal
          case _ =>
            transformTermChildren(tree, tpe)

      def transformTypeTree(tree: TypeTree)(using ctx: Context): TypeTree = tree

      /** Transforms the guard (as a Boolean) and body (as `tpe`); the pattern is kept. */
      def transformCaseDef(tree: CaseDef, tpe: Type)(using ctx: Context): CaseDef =
        CaseDef.copy(tree)(tree.pattern, tree.guard.map(x => transformTerm(x, defn.BooleanType)), transformTerm(tree.rhs, tpe))

      def transformTypeCaseDef(tree: TypeCaseDef)(using ctx: Context): TypeCaseDef = {
        TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
      }

      def transformStats(trees: List[Statement])(using ctx: Context): List[Statement] =
        trees mapConserve (transformStatement(_))

      def transformDefinitions(trees: List[Definition])(using ctx: Context): List[Definition] =
        trees mapConserve (transformDefinition(_))

      /** Transforms each term against the type at the same position in `tpes`.
       *
       *  Like `mapConserve`, returns `trees` itself when no element changed.
       *  Fails with a descriptive message if the two lists differ in length.
       */
      def transformTerms(trees: List[Term], tpes: List[Type])(using ctx: Context): List[Term] = {
        require(trees.length == tpes.length,
          s"ExprMap: ${trees.length} terms but ${tpes.length} expected types")
        val trees1 = trees.zip(tpes).map { case (tree, tp) => transformTerm(tree, tp) }
        if (trees1.corresponds(trees)(_ eq _)) trees else trees1
      }

      def transformTerms(trees: List[Term], tpe: Type)(using ctx: Context): List[Term] =
        trees.mapConserve(x => transformTerm(x, tpe))

      def transformTypeTrees(trees: List[TypeTree])(using ctx: Context): List[TypeTree] =
        trees mapConserve (transformTypeTree(_))

      def transformCaseDefs(trees: List[CaseDef], tpe: Type)(using ctx: Context): List[CaseDef] =
        trees mapConserve (x => transformCaseDef(x, tpe))

      def transformTypeCaseDefs(trees: List[TypeCaseDef])(using ctx: Context): List[TypeCaseDef] =
        trees mapConserve (transformTypeCaseDef(_))

    }
    new MapChildren().transformTermChildren(e.unseal, tpe.unseal.tpe).asExprOf[T]
  }

}
