package dotty.tools.repl

import java.io.{File => JFile, PrintStream}

import dotty.tools.dotc.ast.Trees._
import dotty.tools.dotc.ast.{tpd, untpd}
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Denotations.Denotation
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.Mode
import dotty.tools.dotc.core.NameKinds.SimpleNameKind
import dotty.tools.dotc.core.NameOps._
import dotty.tools.dotc.core.Names.Name
import dotty.tools.dotc.core.StdNames._
import dotty.tools.dotc.core.Symbols.{Symbol, defn}
import dotty.tools.dotc.interactive.Completion
import dotty.tools.dotc.printing.SyntaxHighlighting
import dotty.tools.dotc.reporting.MessageRendering
import dotty.tools.dotc.reporting.{Message, Diagnostic}
import dotty.tools.dotc.util.Spans.Span
import dotty.tools.dotc.util.{SourceFile, SourcePosition}
import dotty.tools.dotc.{CompilationUnit, Driver}
import dotty.tools.io._
import org.jline.reader._

import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.util.Using

/** Immutable snapshot of a REPL session, threaded through every interaction
 *  so that no mutable session state is needed.
 *
 *  To compile valid code, the REPL wraps a single `MemberDef` that cannot be
 *  top-level in a wrapper object. Each such wrapper needs a unique
 *  identifier, which is tracked by `objectIndex`.
 *
 *  Free expressions such as `1 + 1` need to be bound to a name in order to be
 *  of use. They are therefore given an identifier of the form `resX`, where
 *  `X` starts at 0 and each subsequent expression needing an identifier gets
 *  the previous identifier plus one. This counter is `valIndex`.
 *
 *  @param objectIndex the index of the next wrapper
 *  @param valIndex    the index of the next value binding for free expressions
 *  @param imports     user-defined imports, keyed by the object index that introduced them
 *  @param context     the most recent compiler context
 */
case class State(
  objectIndex: Int,
  valIndex: Int,
  imports: Map[Int, List[tpd.Import]],
  context: Context
)

/** Main REPL instance, orchestrating input, compilation and presentation
 *
 *  @param settings    compiler arguments passed to `setup` when building the
 *                     initial context
 *  @param out         stream all REPL output is printed to; while input is
 *                     interpreted, `System.out` and `System.err` are also
 *                     redirected to it
 *  @param classLoader optional class loader handed to `Rendering`
 */
class ReplDriver(settings: Array[String],
                 out: PrintStream = Console.out,
                 classLoader: Option[ClassLoader] = None) extends Driver {

  /** Overridden to `false` in order to not have to give sources on the
   *  commandline
   */
  override def sourcesRequired: Boolean = false

  /** Create a fresh and initialized context with IDE mode enabled
   *
   *  The context reads positions and comments, runs in interactive mode and
   *  has `YcookComments` enabled; the user `settings` are applied via `setup`
   *  before the context base is initialized.
   */
  private def initialCtx = {
    val rootCtx = initCtx.fresh.addMode(Mode.ReadPositions | Mode.Interactive | Mode.ReadComments)
    rootCtx.setSetting(rootCtx.settings.YcookComments, true)
    val ictx = setup(settings, rootCtx)._2
    ictx.base.initialize()(ictx)
    ictx
  }

  /** the initial, empty state of the REPL session */
  final def initialState: State = State(0, 0, Map.empty, rootCtx)

  /** Reset state of repl to the initial state
   *
   *  This method is responsible for performing an all encompassing reset. As
   *  such, when the user enters `:reset` this method should be called to reset
   *  everything properly
   */
  protected def resetToInitial(): Unit = {
    rootCtx = initialCtx
    // Without an explicit output directory, compiled classes go to an
    // in-memory virtual directory instead of the file system.
    if (rootCtx.settings.outputDir.isDefault(rootCtx))
      rootCtx = rootCtx.fresh
        .setSetting(rootCtx.settings.outputDir, new VirtualDirectory("<REPL compilation output>"))
    compiler = new ReplCompiler
    rendering = new Rendering(classLoader)
  }

  private var rootCtx: Context = _
  private var compiler: ReplCompiler = _
  private var rendering: Rendering = _

  // initialize the REPL session as part of the constructor so that once `run`
  // is called, we're in business
  resetToInitial()

  /** Run REPL with `state` until `:quit` command found
   *
   *  This method is the main entry point into the REPL. Its effects are not
   *  observable outside of the CLI, for this reason, most helper methods are
   *  `protected final` to facilitate testing.
   */
  final def runUntilQuit(initialState: State = initialState): State = {
    val terminal = new JLineTerminal

    /** Blockingly read a line, getting back a parse result */
    def readLine(state: State): ParseResult = {
      val completer: Completer = { (_, line, candidates) =>
        val comps = completions(line.cursor, line.line, state)
        candidates.addAll(comps.asJava)
      }
      implicit val ctx = state.context
      try {
        val line = terminal.readLine(completer)
        ParseResult(line)(state)
      } catch {
        case _: EndOfFileException |
            _: UserInterruptException => // Ctrl+D or Ctrl+C
          Quit
      }
    }

    @tailrec def loop(state: State): State = {
      val res = readLine(state)
      if (res == Quit) state
      else loop(interpret(res)(state))
    }

    // the terminal is closed even if the loop terminates with an exception
    try withRedirectedOutput { loop(initialState) }
    finally terminal.close()
  }

  /** Parse and interpret `input` against `state`, returning the resulting state */
  final def run(input: String)(implicit state: State): State = withRedirectedOutput {
    val parsed = ParseResult(input)(state)
    interpret(parsed)
  }

  // TODO: i5069 -- currently a no-op that returns `state` unchanged
  final def bind(name: String, value: Any)(implicit state: State): State = state

  // redirecting the output allows us to test `println` in scripted tests.
  // Both stdout and stderr go to `out`; the previous streams are restored even
  // if `op` throws.
  private def withRedirectedOutput(op: => State): State = {
    val savedOut = System.out
    val savedErr = System.err
    try {
      System.setOut(out)
      System.setErr(out)
      op
    }
    finally {
      System.setOut(savedOut)
      System.setErr(savedErr)
    }
  }

  /** Start a new compiler run from the root context with a fresh store
   *  reporter, returning `state` updated with that run's context
   */
  private def newRun(state: State) = {
    val run = compiler.newRun(rootCtx.fresh.setReporter(newStoreReporter), state)
    state.copy(context = run.runContext)
  }

  /** Extract possible completions at the index of `cursor` in `expr`
   *
   *  `expr` is type checked with errors allowed (it is usually incomplete);
   *  when `typeCheck` yields no tree, no candidates are returned.
   */
  protected final def completions(cursor: Int, expr: String, state0: State): List[Candidate] = {
    def makeCandidate(completion: Completion) = {
      val displ = completion.label
      new Candidate(
        /* value    = */ displ,
        /* displ    = */ displ, // displayed value
        /* group    = */ null,  // can be used to group completions together
        /* descr    = */ null,  // TODO use for documentation?
        /* suffix   = */ null,
        /* key      = */ null,
        /* complete = */ false  // if true adds space when completing
      )
    }
    implicit val state = newRun(state0)
    compiler
      .typeCheck(expr, errorsAllowed = true)
      .map { tree =>
        val file = SourceFile.virtual("<completions>", expr, maybeIncomplete = true)
        val unit = CompilationUnit(file)(state.context)
        unit.tpdTree = tree
        implicit val ctx = state.context.fresh.setCompilationUnit(unit)
        val srcPos = SourcePosition(file, Span(cursor))
        val (_, completions) = Completion.completions(srcPos)
        completions.map(makeCandidate)
      }
      .getOrElse(Nil)
  }

  /** Interpret a single parse result: compile code, report syntax errors or
   *  execute a command. Unless the `XreplDisableDisplay` setting is enabled, an
   *  empty line is printed afterwards.
   */
  private def interpret(res: ParseResult)(implicit state: State): State = {
    val newState = res match {
      case parsed: Parsed if parsed.trees.nonEmpty =>
        compile(parsed, state)

      case SyntaxErrors(_, errs, _) =>
        displayErrors(errs)
        state

      case cmd: Command =>
        interpretCommand(cmd)

      case SigKill => // TODO
        state

      case _ => // new line, empty tree
        state
    }
    implicit val ctx: Context = newState.context
    if (!ctx.settings.XreplDisableDisplay.value)
      out.println()
    newState
  }

  /** Compile `parsed` trees and evolve `state` in accordance */
  private def compile(parsed: Parsed, istate: State): State = {
    def extractNewestWrapper(tree: untpd.Tree): Name = tree match {
      case PackageDef(_, (obj: untpd.ModuleDef) :: Nil) => obj.name.moduleClassName
      case _ => nme.NO_NAME
    }

    // NOTE(review): assumes the compiler's phase list always contains a
    // `CollectTopLevelImports` phase; `.get` throws if it does not.
    def extractTopLevelImports(ctx: Context): List[tpd.Import] =
      ctx.phases.collectFirst { case phase: CollectTopLevelImports => phase.imports }.get

    implicit val state = {
      val state0 = newRun(istate)
      state0.copy(context = state0.context.withSource(parsed.source))
    }
    compiler
      .compile(parsed)
      .fold(
        displayErrors,
        {
          case (unit: CompilationUnit, newState: State) =>
            val newestWrapper = extractNewestWrapper(unit.untpdTree)
            val newImports = extractTopLevelImports(newState.context)
            var allImports = newState.imports
            // imports are recorded under the object index of the run that introduced them
            if (newImports.nonEmpty)
              allImports += (newState.objectIndex -> newImports)
            val newStateWithImports = newState.copy(imports = allImports)

            val warnings = newState.context.reporter
              .removeBufferedMessages(newState.context)
              .map(rendering.formatError)

            implicit val ctx: Context = newState.context
            val (updatedState, definitions) =
              if (!ctx.settings.XreplDisableDisplay.value)
                renderDefinitions(unit.tpdTree, newestWrapper)(newStateWithImports)
              else
                (newStateWithImports, Seq.empty)

            // output is printed in the order it was put in. warnings should be
            // shown before infos (eg. typedefs) for the same line. column
            // ordering is mostly to make tests deterministic
            implicit val diagnosticOrdering: Ordering[Diagnostic] =
              Ordering[(Int, Int, Int)].on(d => (d.pos.line, -d.level, d.pos.column))

            (definitions ++ warnings)
              .sorted
              .map(_.msg)
              .foreach(out.println)

            updatedState
        }
      )
  }

  /** Render the type definitions and members of the newest wrapper object in
   *  `tree`, returning the (possibly adjusted) state and the rendered,
   *  syntax-highlighted diagnostics. If no wrapper named `newestWrapper` is
   *  found, nothing is rendered.
   */
  private def renderDefinitions(tree: tpd.Tree, newestWrapper: Name)(implicit state: State): (State, Seq[Diagnostic]) = {
    implicit val ctx = state.context

    /** Whether `denot` is a `resN` binding with `N < state.valIndex` whose type is `Unit` */
    def resAndUnit(denot: Denotation) = {
      val sym = denot.symbol
      val name = sym.name.show
      // Parse the suffix that follows the full prefix rather than assuming the
      // prefix is exactly three characters long.
      val hasValidNumber =
        name.drop(str.REPL_RES_PREFIX.length).toIntOption.exists(_ < state.valIndex)
      name.startsWith(str.REPL_RES_PREFIX) && hasValidNumber && sym.info == defn.UnitType
    }

    /** Render the type aliases, methods and vals declared by `symbol`.
     *  `valIndex` is decreased by the number of vals matching `resAndUnit`.
     */
    def extractAndFormatMembers(symbol: Symbol): (State, Seq[Diagnostic]) = if (symbol.info.exists) {
      val info = symbol.info
      val defs =
        info.bounds.hi.finalResultType
          .membersBasedOnFlags(required = Method, excluded = Accessor | ParamAccessor | Synthetic | Private)
          .filterNot { denot =>
            denot.symbol.owner == defn.AnyClass ||
            denot.symbol.owner == defn.ObjectClass ||
            denot.symbol.isConstructor
          }

      val vals =
        info.fields
          .filterNot(_.symbol.isOneOf(ParamAccessor | Private | Synthetic | Artifact | Module))
          .filter(_.symbol.name.is(SimpleNameKind))

      val typeAliases =
        info.bounds.hi.typeMembers.filter(_.symbol.info.isTypeAlias)

      val formattedMembers =
        typeAliases.map(rendering.renderTypeAlias) ++
        defs.map(rendering.renderMethod) ++
        vals.flatMap(rendering.renderVal)

      (state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), formattedMembers)
    }
    else (state, Seq.empty)

    def isSyntheticCompanion(sym: Symbol) =
      sym.is(Module) && sym.is(Synthetic)

    /** Render member classes of `sym`, skipping synthetic companions and REPL wrappers */
    def typeDefs(sym: Symbol): Seq[Diagnostic] = sym.info.memberClasses
      .collect {
        case x if !isSyntheticCompanion(x.symbol) && !x.symbol.name.isReplWrapperName =>
          rendering.renderTypeDef(x)
      }

    ctx.atPhase(ctx.typerPhase.next) {
      // Display members of wrapped module:
      tree.symbol.info.memberClasses
        .find(_.symbol.name == newestWrapper.moduleClassName)
        .map { wrapperModule =>
          val formattedTypeDefs = typeDefs(wrapperModule.symbol)
          val (newState, formattedMembers) = extractAndFormatMembers(wrapperModule.symbol)
          val highlighted = (formattedTypeDefs ++ formattedMembers)
            .map(d => new Diagnostic(d.msg.mapMsg(SyntaxHighlighting.highlight), d.pos, d.level))
          (newState, highlighted)
        }
        .getOrElse {
          // user defined a trait/class/object, so no module needed
          (state, Seq.empty)
        }
    }
  }

  /** Interpret `cmd` to action and propagate potentially new `state` */
  private def interpretCommand(cmd: Command)(implicit state: State): State = cmd match {
    case UnknownCommand(cmd) =>
      out.println(s"""Unknown command: "$cmd", run ":help" for a list of commands""")
      state

    case AmbiguousCommand(cmd, matching) =>
      out.println(s""""$cmd" matches ${matching.mkString(", ")}. Try typing a few more characters. Run ":help" for a list of commands""")
      state

    case Help =>
      out.println(Help.text)
      state

    case Reset =>
      resetToInitial()
      initialState

    case Imports =>
      for {
        objectIndex <- 1 to state.objectIndex
        imp <- state.imports.getOrElse(objectIndex, Nil)
      } out.println(imp.show(state.context))
      state

    case Load(path) =>
      val file = new JFile(path)
      if (file.exists) {
        import scala.util.{Failure, Success}
        Using(scala.io.Source.fromFile(file, "UTF-8"))(_.mkString) match {
          case Success(contents) =>
            run(contents)
          case Failure(ex) =>
            // e.g. `path` is a directory or unreadable: report it instead of
            // letting the exception escape and end the REPL session
            out.println(s"""Couldn't read file "${file.getCanonicalPath}": ${ex.getMessage}""")
            state
        }
      }
      else {
        out.println(s"""Couldn't find file "${file.getCanonicalPath}"""")
        state
      }

    case TypeOf(expr) =>
      compiler.typeOf(expr)(newRun(state)).fold(
        displayErrors,
        res => out.println(SyntaxHighlighting.highlight(res)(state.context))
      )
      state

    case DocOf(expr) =>
      compiler.docOf(expr)(newRun(state)).fold(
        displayErrors,
        res => out.println(res)
      )
      state

    case Quit =>
      // end of the world!
      state
  }

  /** shows all errors nicely formatted */
  private def displayErrors(errs: Seq[Diagnostic])(implicit state: State): State = {
    errs.map(rendering.formatError).map(_.msg).foreach(out.println)
    state
  }
}
