package pl.touk.nussknacker.ui.process.marshall

import pl.touk.nussknacker.engine.canonicalgraph.canonicalnode._
import pl.touk.nussknacker.engine.canonicalgraph.{CanonicalProcess, canonicalnode}
import pl.touk.nussknacker.engine.graph.EdgeType
import pl.touk.nussknacker.engine.graph.EdgeType.SubprocessOutput
import pl.touk.nussknacker.engine.graph.node._
import pl.touk.nussknacker.restmodel.displayedgraph.displayablenode.Edge
import pl.touk.nussknacker.restmodel.displayedgraph.{DisplayableProcess, ProcessProperties, displayablenode}
import pl.touk.nussknacker.restmodel.process.ProcessingType

object ProcessConverter {

  /**
    * Converts a canonical process to its displayable (graph) form.
    *
    * Delegates directly to [[toDisplayable]]; no additional checks are performed here.
    */
  def toDisplayableOrDie(canonicalProcess: CanonicalProcess, processingType: ProcessingType): DisplayableProcess = {
    toDisplayable(canonicalProcess, processingType)
  }

  /**
    * Flattens a canonical (nested) process into a displayable graph of nodes and edges.
    *
    * Every start branch is converted independently. The resulting node and edge lists are concatenated in
    * branch order. Process properties are copied from the process metadata.
    *
    * @param process        canonical process to convert
    * @param processingType processing type stored in the resulting [[DisplayableProcess]]
    * @return displayable process whose id is `process.metaData.id`
    */
  def toDisplayable(process: CanonicalProcess, processingType: ProcessingType): DisplayableProcess = {
    // unzip + flatten keeps the same branch order as a pairwise concatenation, but is done in a single pass
    // and (unlike reduceLeft) does not throw on an empty collection of branches
    val (nodes, edges) = unzipListTuple(process.allStartNodes.toList.map(toGraphInner))
    val props = ProcessProperties(
      typeSpecificProperties = process.metaData.typeSpecificData,
      additionalFields = process.metaData.additionalFields,
      subprocessVersions = process.metaData.subprocessVersions
    )
    DisplayableProcess(process.metaData.id, props, nodes, edges, processingType)
  }

  /**
    * Returns the node data of every node in all branches of the process.
    *
    * Branch-end markers are not returned, because [[toGraphInner]] does not emit them as nodes.
    */
  def findNodes(process: CanonicalProcess) : List[NodeData] = {
    process.allStartNodes.toList.flatMap(branch => toGraphInner(branch)._1)
  }

  /**
    * Recursively converts one canonical branch into its displayable nodes and the edges between them.
    *
    * Nested branches (filter false path, switch cases/default, split branches, subprocess outputs) are converted
    * as well. Each nested branch gets an edge from the owning node to the branch's first node.
    */
  private def toGraphInner(nodes: List[canonicalnode.CanonicalNode]): (List[NodeData], List[displayablenode.Edge]) =
    nodes match {
      // A branch end is not a displayable node; the predecessor's edge already targets the join node (see createNextEdge)
      case canonicalnode.FlatNode(BranchEndData(_)) :: _ => (List(), List())
      case canonicalnode.FlatNode(data) :: tail =>
        val (tailNodes, tailEdges) = toGraphInner(tail)
        (data :: tailNodes, createNextEdge(data.id, tail) ::: tailEdges)
      case canonicalnode.FilterNode(data, nextFalse) :: tail =>
        val (nextFalseNodes, nextFalseEdges) = toGraphInner(nextFalse)
        val nextFalseEdgesConnectedToFilter = createNextEdge(data.id, nextFalse, Some(EdgeType.FilterFalse)) ::: nextFalseEdges
        // the "true" path of a filter continues in the same list as the filter itself
        val (tailNodes, tailEdges) = toGraphInner(tail)
        (data :: nextFalseNodes ::: tailNodes, createNextEdge(data.id, tail, Some(EdgeType.FilterTrue)) ::: nextFalseEdgesConnectedToFilter ::: tailEdges)
      case canonicalnode.SwitchNode(data, nexts, defaultNext) :: tail =>
        val (defaultNextNodes, defaultNextEdges) = toGraphInner(defaultNext)
        val defaultNextEdgesConnectedToSwitch = createNextEdge(data.id, defaultNext, Some(EdgeType.SwitchDefault)) ::: defaultNextEdges
        val (tailNodes, tailEdges) = toGraphInner(tail)
        val (nextNodes, nextEdges) = unzipListTuple(nexts.map { c =>
          val (nextNodeNodes, nextNodeEdges) = toGraphInner(c.nodes)
          (nextNodeNodes, createNextEdge(data.id, c.nodes, Some(EdgeType.NextSwitch(c.expression))) ::: nextNodeEdges)
        })
        (data :: defaultNextNodes ::: nextNodes ::: tailNodes, createNextEdge(data.id, tail) ::: nextEdges ::: defaultNextEdgesConnectedToSwitch ::: tailEdges)
      case canonicalnode.SplitNode(data, nexts) :: tail =>
        val (tailNodes, tailEdges) = toGraphInner(tail)
        val (branchNodes, branchEdges) = unzipListTuple(nexts.map(toGraphInner))
        val connecting = nexts.flatMap(createNextEdge(data.id, _, None))
        (data :: branchNodes ::: tailNodes, connecting ::: branchEdges ::: tailEdges)
      case canonicalnode.Subprocess(data, outputs) :: tail =>
        val (tailNodes, tailEdges) = toGraphInner(tail)
        val (outputNodes, outputEdges) = unzipListTuple(outputs.values.toList.map(toGraphInner))
        // each output branch is connected with an edge labelled by the subprocess output name
        val connecting = outputs
          .flatMap{ case (name, outputBranch) => createNextEdge(data.id, outputBranch, Some(SubprocessOutput(name))) }.toList
        (data :: outputNodes ::: tailNodes, connecting ::: outputEdges ::: tailEdges)
      case Nil =>
        (List(),List())
    }

  /**
    * Creates the edge from node `id` to the first node of `tail`, or no edge when `tail` is empty.
    *
    * When the first node is a branch end, the edge points at the join node referenced by the branch end
    * definition instead of at the branch-end marker itself.
    */
  private def createNextEdge(id: String, tail: List[CanonicalNode], edgeType: Option[EdgeType] = None): List[displayablenode.Edge] = {
    tail.headOption.map {
      case FlatNode(BranchEndData(BranchEndDefinition(_, joinId))) => displayablenode.Edge(id, joinId, edgeType)
      case n => displayablenode.Edge(id, n.id, edgeType)
    }.toList
  }

  /** Unzips a list of list pairs and concatenates each side, preserving element order. */
  private def unzipListTuple[A, B](a: List[(List[A], List[B])]): (List[A], List[B]) = {
    val (aList, bList) = a.unzip
    (aList.flatten, bList.flatten)
  }

  /**
    * Rebuilds a canonical (nested) process from its displayable graph form.
    *
    * Every [[StartingNodeData]] node, in `process.nodes` order, starts one branch. The first branch becomes the
    * main node list and the remaining ones become additional branches. If several nodes share an id, the first
    * one is used.
    */
  def fromDisplayable(process: DisplayableProcess): CanonicalProcess = {
    // strict map instead of mapValues, which is a lazily re-evaluated view (and deprecated since Scala 2.13)
    val nodesMap = process.nodes.groupBy(_.id).map { case (id, nodesWithId) => id -> nodesWithId.head }
    val edgesFromMapStart = process.edges.groupBy(_.from)
    val rootsUnflattened = findRootNodes(process).map(headNode => unFlattenNode(nodesMap, None)(headNode, edgesFromMapStart))
    val (nodes, additionalBranches) = rootsUnflattened match {
      case firstBranch :: otherBranches => (firstBranch, otherBranches)
      case Nil => (List.empty[canonicalnode.CanonicalNode], List.empty[List[canonicalnode.CanonicalNode]])
    }
    CanonicalProcess(process.metaData, nodes, additionalBranches)
  }

  /** Returns the nodes that start a branch (instances of [[StartingNodeData]]), in `process.nodes` order. */
  private def findRootNodes(process: DisplayableProcess): List[NodeData] =
    process.nodes.collect { case startingNode: StartingNodeData => startingNode }

  /**
    * Recursively converts the graph reachable from node `n` into a canonical branch.
    *
    * @param nodesMap     all nodes by id
    * @param stopAtJoin   edge through which `n` was reached, `None` for a root node. A [[Join]] reached through an
    *                     edge is turned into a branch-end marker and traversal stops there.
    * @param n            node to start from
    * @param edgesFromMap outgoing edges grouped by source node id. Each followed edge is removed from its source's
    *                     list in the map passed down to the recursive call.
    * @throws NoSuchElementException when an edge points to a node id missing from `nodesMap`
    */
  private def unFlattenNode(nodesMap: Map[String, NodeData], stopAtJoin: Option[Edge])
                           (n: NodeData, edgesFromMap: Map[String, List[displayablenode.Edge]]): List[canonicalnode.CanonicalNode] = {
    def unflattenEdgeEnd(id: String, e: displayablenode.Edge): List[canonicalnode.CanonicalNode] = {
      unFlattenNode(nodesMap, Some(e))(nodesMap(e.to), edgesFromMap.updated(id, edgesFromMap(id).filterNot(_ == e)))
    }

    def getEdges(id: String): List[Edge] = edgesFromMap.getOrElse(id, List())

    (n, stopAtJoin) match {
      case (data: Filter, _) =>
        val filterEdges = getEdges(data.id)
        // the "true" path continues in the same list as the filter; the "false" path is nested in the filter node
        val next = filterEdges.find(_.edgeType.contains(EdgeType.FilterTrue)).map(unflattenEdgeEnd(data.id, _)).toList.flatten
        val nextFalse = filterEdges.find(_.edgeType.contains(EdgeType.FilterFalse)).map(unflattenEdgeEnd(data.id, _)).toList.flatten
        canonicalnode.FilterNode(data, nextFalse) :: next
      case (data: Switch, _) =>
        val nexts = getEdges(data.id).collect { case e@displayablenode.Edge(_, _, Some(EdgeType.NextSwitch(edgeExpr))) =>
          canonicalnode.Case(edgeExpr, unflattenEdgeEnd(data.id, e))
        }
        val default = getEdges(data.id).find(_.edgeType.contains(EdgeType.SwitchDefault)).map { e =>
          unflattenEdgeEnd(data.id, e)
        }.toList.flatten
        canonicalnode.SwitchNode(data, nexts, default) :: Nil
      case (data: Split, _) =>
        val nexts = getEdges(data.id).map(unflattenEdgeEnd(data.id, _))
        canonicalnode.SplitNode(data, nexts) :: Nil
      case (data: SubprocessInput, _) =>
        //TODO error handling?
        // NOTE: throws NoSuchElementException / ClassCastException when an outgoing edge has no SubprocessOutput edge type
        val nexts = getEdges(data.id).map(e => e.edgeType.get.asInstanceOf[SubprocessOutput].name -> unflattenEdgeEnd(data.id, e)).toMap
        canonicalnode.Subprocess(data, nexts) :: Nil
      case (data: Join, Some(edgeConnectedToJoin)) =>
        // The "from" node's id is used as the branchId because, for now, branch expressions are stored inside
        // Join nodes, and this is a convenient way to connect the two.
        val branchId = edgeConnectedToJoin.from
        canonicalnode.FlatNode(BranchEndData(BranchEndDefinition(branchId, data.id))) :: Nil
      case (data, _) =>
        // any other node (including a Join reached as a root) stays flat and traversal continues along its edges
        canonicalnode.FlatNode(data) :: getEdges(data.id).flatMap(unflattenEdgeEnd(data.id, _))
    }
  }

}
