package epic.parser

import breeze.config.CommandLineParser
import epic.parser.GenerativeParser
import epic.trees._

/**
 * Smoke test for rule-count extraction on a single hand-written tree.
 *
 * The program parses a bracketed tree, runs it through [[StandardTreeProcessor]],
 * collapses unary chains, applies the default annotator of [[GenerativeParser]],
 * extracts counts from the resulting instance, and finally asserts that the unary
 * rule rewriting `S` (parent `SBAR`) as `VBD` (parent `S`) was counted exactly once.
 * The annotated tree and the extracted counts are printed along the way.
 *
 * @author dlwh
 **/
object TinyRuleTest extends App {
  // Every preterminal in this fixture is wrapped in a same-labelled unary node,
  // e.g. `(RB (RB A))`.
  val (tree,words) = Tree.fromString(
    """
      |(TOP (S
      |      (NP
      |        (NP
      |          (RB (RB A) )
      |          (PDT (PDT QQQ) )
      |          (DT (DT ZZZZ) ))
      |        (SBAR
      |         (SBAR
      |          (WHNP (WP which) )
      |          (S (VBD qqqed) ))))
      |       (VP
      |        (VBP (VBP vbes) )
      |         (NP
      |          (DT (DT the) )
      |          (NNS (NNS dogs) )))
      |      (. (. .) )))
      |
    """.stripMargin)
  // The fixture has nine leaves; anything else means the string was not read as intended.
  assert(words.length == 9, s"expected 9 words in the fixture but read ${words.length}: $words")

  // Standard processing first, then unary-chain collapsing without keeping the chains.
  private val processedTree = new StandardTreeProcessor().apply(tree.map(AnnotatedLabel(_)))
  val tp = UnaryChainCollapser.collapseUnaryChains(processedTree, keepChains = false)

  val ti = TreeInstance("...", tp, words)
  val ann = GenerativeParser.defaultAnnotator().apply(ti)
  println(ann.tree)

  val counts = GenerativeParser.extractCounts(IndexedSeq(ann))
  println(counts)

  // Labels carrying one level of parent annotation: S under SBAR, and VBD under S.
  val s_sbar = AnnotatedLabel("S", parents = IndexedSeq("SBAR"))
  val vbd_s = AnnotatedLabel("VBD", parents = IndexedSeq("S"))

  // The third component of the extracted counts is indexed by (parent label, unary rule).
  private val expectedRule = UnaryRule(s_sbar, vbd_s, IndexedSeq.empty)
  private val observedCount = counts._3(s_sbar, expectedRule)
  assert(observedCount == 1.0, s"expected a count of 1.0 for $expectedRule but found $observedCount")
}

/**
 * Variant of the tiny-rule check that reads its tree from a treebank on disk.
 *
 * Two configuration properties are read from the command line:
 *  - `treePath`: location handed to [[ProcessedTreebank]]; required.
 *  - `maxLength`: maximum sentence length, parsed as an `Int`; defaults to 10000.
 *
 * The check itself uses only the training tree at index 3 of the loaded treebank.
 */
object TinyRuleTestRedux extends App {
  val (config, rawArgs) = CommandLineParser.parseArguments(args)
  val treePath = config.getProperty("treePath")
  val maxSentenceLength: Int = config.getProperty("maxLength").getOrElse("10000").toInt

  val treebank = readTreebank(treePath, maxSentenceLength)

  /**
   * Loads the treebank stored at `path`.
   *
   * @param path      treebank location, if one was configured
   * @param maxLength maximum sentence length passed on to [[ProcessedTreebank]]
   * @return the treebank built from the given location
   * @throws RuntimeException if `path` is empty
   */
  def readTreebank(path: Option[String], maxLength: Int): ProcessedTreebank =
    path
      .map(location => ProcessedTreebank(new java.io.File(location), maxLength))
      .getOrElse(throw new RuntimeException(
        "No treebank location configured: the 'treePath' property is required"))

  /**
   * Prints the raw and annotated form of the training tree at index 3, extracts
   * counts from the annotated tree, prints them, and asserts that the unary rule
   * rewriting `S` (parent `SBAR`) as `VBD` (parent `S`) has a count above 0.5.
   */
  def runTinyRuleTest(): Unit = {
    val annotator = GenerativeParser.defaultAnnotator()

    // Select the single tree of interest before annotating, so the annotator is
    // not run over the rest of the training set only to have the results dropped.
    val rawTrees = treebank.trainTrees.slice(3, 4)
    assert(
      rawTrees.nonEmpty,
      s"the check uses the training tree at index 3, but the treebank has only " +
        s"${treebank.trainTrees.size} training trees")
    val annotatedTrees = rawTrees.map(t => annotator(t))

    println("=====TRAINING TREES=====")
    rawTrees.zip(annotatedTrees).foreach { case (raw, annotated) =>
      println("-----")
      printf("TREE: [EPIC, RAW]\n%s\n", raw.render())
      printf("TREE: [EPIC, PROC]\n%s\n", annotated.render())
      println("-----")
    }

    val counts = GenerativeParser.extractCounts(annotatedTrees)
    println(counts)

    // Labels carrying one level of parent annotation: S under SBAR, and VBD under S.
    val sUnderSbar = AnnotatedLabel("S", parents = IndexedSeq("SBAR"))
    val vbdUnderS = AnnotatedLabel("VBD", parents = IndexedSeq("S"))

    // The third component of the extracted counts is indexed by (parent label, unary rule).
    val expectedRule = UnaryRule(sUnderSbar, vbdUnderS, IndexedSeq.empty)
    val observedCount = counts._3(sUnderSbar, expectedRule)
    assert(
      observedCount > 0.5,
      s"expected a count above 0.5 for $expectedRule but found $observedCount")
  }

  runTinyRuleTest()
}

