I'm trying to figure out a neat way of traversing a graph Scala-style, preferably with vals and immutable data types.

Given the following graph,

val graph = Map(0 -> Set(1),
                1 -> Set(2),
                2 -> Set(0, 3, 4),
                3 -> Set[Int](),
                4 -> Set(3))

I'd like the output to be the depth-first traversal starting at a given node. Starting at 1, for instance, should yield 1 2 3 0 4.

I can't seem to figure out a nice way of doing this without mutable collections or vars. Any help would be appreciated.



Tail-recursive solution:

  def traverse(graph: Map[Int, Set[Int]], start: Int): List[Int] = {
    def childrenNotVisited(parent: Int, visited: List[Int]) =
      graph(parent) filter (x => !visited.contains(x))

    def loop(stack: Set[Int], visited: List[Int]): List[Int] = {
      if (stack.isEmpty) visited
      else loop(childrenNotVisited(stack.head, visited) ++ stack.tail,
        stack.head :: visited)
    }
    loop(Set(start), Nil).reverse
  }
Marimuthu Madasamy
    In `loop()`, `stack` is a `Set`, meaning it's unordered. We decide the next node to visit by taking the `head` of this unordered `stack`. How do we know this will traverse depth first if we don't know the order of the elements in the `stack`? – Adam Mackler Jul 23 '18 at 16:20
  • Why did you use a `Set` for the stack? You should use an ordered data structure such as `List`. You'd also better use a `Set` instead of a `List` for visited, because `visited.contains(x)` takes linear time on a `List`, and use another variable to record the visit order (the result). – Sam Sep 21 '20 at 02:09
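Following the two comments above, a minimal sketch of the same loop with a `List` as the stack and a `Set` for the visited nodes might look like this. The name `dfsOrder` is hypothetical. Nodes are marked visited when they are popped rather than when they are pushed, and a popped node that is already visited is skipped, so the result is a depth-first preorder; the iteration order of each successor `Set` still decides which child is explored first.

```scala
import scala.annotation.tailrec

// The question's graph; Set[Int]() keeps the value type Set[Int].
val graph: Map[Int, Set[Int]] = Map(
  0 -> Set(1),
  1 -> Set(2),
  2 -> Set(0, 3, 4),
  3 -> Set[Int](),
  4 -> Set(3))

def dfsOrder(graph: Map[Int, Set[Int]], start: Int): List[Int] = {
  @tailrec
  def loop(stack: List[Int], visited: Set[Int], order: List[Int]): List[Int] =
    stack match {
      case Nil => order.reverse // order was built by prepending, so flip it once
      case node :: rest if visited(node) => loop(rest, visited, order) // stale entry: skip
      case node :: rest =>
        // Push unvisited successors on top so they are explored before `rest`.
        val next = graph.getOrElse(node, Set.empty[Int]).filterNot(visited).toList
        loop(next ::: rest, visited + node, node :: order)
    }

  loop(List(start), Set.empty, Nil)
}
```

Each `visited` lookup is effectively constant time, and every edge is examined once when its source is popped, so the traversal runs in O(|V| + |E|).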

This is one variant, I guess:

graph.foldLeft((List[Int](), 1)) {
  (s, e) => if (e._2.size == 0) (0 :: s._1, s._2) else (s._2 :: s._1, (s._2 + 1))
}

Update: This is an expanded version. Here I fold left over the entries of the map, starting out with a tuple of an empty list and the number 1. For each entry I check the size of its successor set and create a new tuple accordingly. The resulting list comes out in reverse order.

val init = (List[Int](), 1)
val (result, _) = graph.foldLeft(init) {
  (s, elem) =>
    val (stack, count) = s
    if (elem._2.size == 0)
      (0 :: stack, count)
    else
      (count :: stack, count + 1)
}
Here is a recursive solution (I hope I understood your requirements correctly):

def traverse(graph: Map[Int, Set[Int]], node: Int, visited: Set[Int] = Set()): List[Int] = 
    List(node) ++ (graph(node) -- visited flatMap(traverse(graph, _, visited + node)))

traverse(graph, 1)

Also, please note that this function is NOT tail recursive.
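Because `visited` above only grows along the current path, a node reachable by two different paths can be explored more than once. A sketch that instead threads one visited set through the whole recursion with `foldLeft` could look like this; it reuses the name `traverse`, but the body is my own assumption, and it is still not tail recursive.

```scala
val graph: Map[Int, Set[Int]] = Map(
  0 -> Set(1),
  1 -> Set(2),
  2 -> Set(0, 3, 4),
  3 -> Set[Int](),
  4 -> Set(3))

// Depth-first preorder from `start`. Not tail recursive: the call stack grows with path length.
def traverse(graph: Map[Int, Set[Int]], start: Int): List[Int] = {
  // State threaded through every call: (all nodes visited so far, preorder in reverse).
  def go(node: Int, state: (Set[Int], List[Int])): (Set[Int], List[Int]) = {
    val (visited, order) = state
    val successors = graph.getOrElse(node, Set.empty[Int])
    successors.foldLeft((visited + node, node :: order)) {
      // Skip successors already reached through an earlier sibling's subtree.
      case (acc @ (seen, _), next) => if (seen(next)) acc else go(next, acc)
    }
  }

  go(start, (Set.empty, Nil))._2.reverse
}
```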

  • Right. I thought about something like this, but this doesn't include `4` in the example graph I gave. – aioobe Mar 29 '11 at 11:43
  • @aioobe: Are you sure? Here is my output: `List(1, 2, 0, 3, 4)` – tenshi Mar 29 '11 at 11:50
  • @Easy Angel, ah, sorry. Must have messed something up. That's what I get now too... – aioobe Mar 29 '11 at 12:01
  • @aioobe: What is the expected output there? – tenshi Mar 29 '11 at 12:11
  • `1, 3, 0, 2` or `1, 3, 2, 0`. – aioobe Mar 29 '11 at 12:14
  • @aioobe: Ok, I thought that this graph is directed, so it should work correctly for directed graphs. – tenshi Mar 29 '11 at 12:23
  • @aioobe: Sorry, I probably misunderstood something, but if the graph (in your second example) is directed, then there is no path 3 -> 2 or 3 -> 0. (But perhaps I am solving the wrong problem: I am trying to find paths instead of just traversing the graph ignoring direction.) – tenshi Mar 29 '11 at 12:35
  • Right. When I output 3, then 2, it is because 3 has no unvisited successors, thus I pick an arbitrary unvisited node and go from there. – aioobe Mar 29 '11 at 12:37
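One reading of aioobe's clarification is a depth-first forest: explore depth first from the start node, and once nothing more is reachable, continue from an arbitrary unvisited node. A hedged sketch of that, with a hypothetical `dfsAll` helper that restarts whenever the stack runs empty:

```scala
import scala.annotation.tailrec

val graph: Map[Int, Set[Int]] = Map(
  0 -> Set(1),
  1 -> Set(2),
  2 -> Set(0, 3, 4),
  3 -> Set[Int](),
  4 -> Set(3))

// Visits every vertex: depth first from `start`, then from arbitrary unvisited vertices.
def dfsAll(graph: Map[Int, Set[Int]], start: Int): List[Int] = {
  // Include vertices that only appear as successors, not as keys.
  val allNodes: Set[Int] = graph.keySet ++ graph.values.flatten

  @tailrec
  def loop(stack: List[Int], visited: Set[Int], order: List[Int]): List[Int] =
    stack match {
      case Nil =>
        // Nothing left to explore from the current root: pick any unvisited vertex.
        (allNodes -- visited).headOption match {
          case Some(root) => loop(List(root), visited, order)
          case None       => order.reverse
        }
      case node :: rest if visited(node) => loop(rest, visited, order)
      case node :: rest =>
        val next = graph.getOrElse(node, Set.empty[Int]).filterNot(visited).toList
        loop(next ::: rest, visited + node, node :: order)
    }

  loop(List(start), Set.empty, Nil)
}
```

For a made-up graph such as `Map(1 -> Set(3), 0 -> Set(2), 2 -> Set[Int](), 3 -> Set[Int]())` (the second graph from the comments is not shown), starting at 1 yields either 1, 3, 0, 2 or 1, 3, 2, 0, depending on which unvisited node `headOption` picks.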

I don't know if you are still looking for an answer after 6 years, but here it is :)

It also returns a topological ordering of the graph and whether the graph is cyclic:

case class Node(label: Int)

case class Graph(adj: Map[Node, Set[Node]]) {
  case class DfsState(discovered: Set[Node] = Set(), activeNodes: Set[Node] = Set(), tsOrder: List[Node] = List(),
                      isCylic: Boolean = false)

  def dfs: (List[Node], Boolean) = {
    def dfsVisit(currState: DfsState, src: Node): DfsState = {
      val newState = currState.copy(discovered = currState.discovered + src, activeNodes = currState.activeNodes + src,
        isCylic = currState.isCylic || adj(src).exists(currState.activeNodes))

      // Re-check `discovered` at every step: a sibling's subtree may already have reached `n`.
      val finalState = adj(src).foldLeft(newState) { (state, n) => if (state.discovered(n)) state else dfsVisit(state, n) }
      finalState.copy(tsOrder = src :: finalState.tsOrder, activeNodes = finalState.activeNodes - src)
    }

    val stateAfterSearch = adj.keys.foldLeft(DfsState()) { (state, n) => if (state.discovered(n)) state else dfsVisit(state, n) }
    (stateAfterSearch.tsOrder, stateAfterSearch.isCylic)
  }
}
Aarsh Shah
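The `activeNodes` set in the answer above is what detects cycles: a node is active while its `dfsVisit` call is still running, so an edge into an active node is a back edge. A stripped-down sketch of just that check, using a hypothetical `hasCycle` function with the visited state threaded through `foldLeft`:

```scala
val graph: Map[Int, Set[Int]] = Map(
  0 -> Set(1),
  1 -> Set(2),
  2 -> Set(0, 3, 4),
  3 -> Set[Int](),
  4 -> Set(3))

// True if the directed graph contains a cycle (including a self-loop).
def hasCycle[A](adj: Map[A, Set[A]]): Boolean = {
  // `active` = nodes on the current DFS path; `state` = (discovered nodes, cycle found so far).
  def visit(node: A, active: Set[A], state: (Set[A], Boolean)): (Set[A], Boolean) = {
    val (discovered, cyclic) = state
    val onPath = active + node
    val successors = adj.getOrElse(node, Set.empty[A])
    // An edge back into the current path is a back edge, i.e. a cycle.
    successors.foldLeft((discovered + node, cyclic || successors.exists(onPath))) {
      case (acc @ (seen, _), next) => if (seen(next)) acc else visit(next, onPath, acc)
    }
  }

  // Start a search from every vertex not yet discovered, including successor-only vertices.
  val allNodes = adj.keySet ++ adj.values.flatten
  allNodes.foldLeft((Set.empty[A], false)) {
    case (acc @ (seen, _), n) => if (seen(n)) acc else visit(n, Set.empty, acc)
  }._2
}
```

Edges into nodes that are discovered but no longer active are cross or forward edges and are correctly ignored.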
I haven't fully understood your solution, but if I am not mistaken its time complexity is at least O(|V|^2), since the following line is O(|V|):

val newResult = result :+ node

Namely, appending an element to the right of a list.
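To illustrate the cost difference: `:+` on a `List` copies the whole list, so building an n-element result that way is O(n^2), while prepending with `::` and reversing once at the end is O(n). `Vector` is an alternative when appending reads more naturally. A small sketch (the size `n` is arbitrary):

```scala
// Builds the same sequence 1..n three ways.
val n = 2000

// `:+` copies the whole List on every append: O(n) per step, O(n^2) overall.
val viaAppend = (1 to n).foldLeft(List.empty[Int])((acc, x) => acc :+ x)

// `::` is O(1); a single reverse at the end keeps the total at O(n).
val viaPrepend = (1 to n).foldLeft(List.empty[Int])((acc, x) => x :: acc).reverse

// Vector supports effectively constant-time appends.
val viaVector = (1 to n).foldLeft(Vector.empty[Int])((acc, x) => acc :+ x)
```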

Furthermore, the code is not tail recursive, which might be a problem if, for example, the recursion depth is limited by the environment you are using.

The following code solves a few DFS related graph problems on directed graphs. It is not the most elegant code, but if I am not mistaken it is:

  1. Tail recursive.
  2. Uses only immutable collections (and iterators on them).
  3. Has optimal time complexity O(|V| + |E|) and space complexity O(|V|).

The code:

import scala.annotation.tailrec
import scala.util.Try

/**
 * Created with IntelliJ IDEA.
 * User: mishaelr
 * Date: 5/14/14
 * Time: 5:18 PM
 */
object DirectedGraphTraversals {

  type Graph[Vertex] = Map[Vertex, Set[Vertex]]

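  // The start vertex counts as entered, so a cycle leading back to it does not re-enter it.
  def dfs[Vertex](graph: Graph[Vertex], initialVertex: Vertex) =
    dfsRec(DfsNeighbours)(graph, List(DfsNeighbours(graph, initialVertex, Set(), Set())), Set(initialVertex), Set(), List())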

  def topologicalSort[Vertex](graph: Graph[Vertex]) =
    graphDfsRec(TopologicalSortNeighbours)(graph, graph.keySet, Set(), Set(), List())

  def stronglyConnectedComponents[Vertex](graph: Graph[Vertex]) = {
    val exitOrder = graphDfsRec(DfsNeighbours)(graph, graph.keySet, Set(), Set(), List())
    val reversedGraph = reverse(graph)

    // Walk vertices in decreasing exit order; each DFS on the reversed graph yields one component.
    exitOrder.foldLeft((Set[Vertex](), List[Set[Vertex]]())) {
      case (acc @ (visitedAcc, connectedComponentsAcc), vertex) =>
        if (visitedAcc(vertex)) acc
        else {
          val connectedComponent = dfsRec(DfsNeighbours)(reversedGraph, List(DfsNeighbours(reversedGraph, vertex, visitedAcc, visitedAcc)),
            visitedAcc + vertex, visitedAcc, List()).toSet
          (visitedAcc ++ connectedComponent, connectedComponent :: connectedComponentsAcc)
        }
    }._2
  }

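  def reverse[Vertex](graph: Graph[Vertex]): Graph[Vertex] = {
    val reverseList = for {
      (vertex, neighbours) <- graph.toList
      neighbour <- neighbours
    } yield (neighbour, vertex)

    // Group the reversed edges by their new source vertex.
    reverseList.groupBy(_._1).map { case (vertex, edges) => vertex -> edges.map(_._2).toSet }
  }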

  private sealed trait NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]): (Vertex, Iterator[Vertex])
  }

  private object DfsNeighbours extends NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]): (Vertex, Iterator[Vertex]) =
      (vertex, graph.getOrElse(vertex, Set.empty[Vertex]).iterator)
  }

  private object TopologicalSortNeighbours extends NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]): (Vertex, Iterator[Vertex]) = {
      val neighbours = graph.getOrElse(vertex, Set.empty[Vertex])
      // A neighbour that was entered but not yet exited lies on the current path: a back edge, hence a cycle.
      if (neighbours.exists(neighbour => entered(neighbour) && !exited(neighbour)))
        throw new IllegalArgumentException("The graph is not a DAG, it contains cycles: " + graph)
      (vertex, neighbours.iterator)
    }
  }

  @tailrec
  private def dfsRec[Vertex](neighboursFunc: NeighboursFunc)(graph: Graph[Vertex], toVisit: List[(Vertex, Iterator[Vertex])],
                                                             entered: Set[Vertex], exited: Set[Vertex],
                                                             exitStack: List[Vertex]): List[Vertex] = {
    toVisit match {
      case List() => exitStack
      case (currentVertex, neighbours) :: tl =>
        // Resume the current vertex's neighbour iterator, skipping vertices already entered.
        val filtered = neighbours.filterNot(entered)
        if (filtered.hasNext) {
          val nextNeighbour = filtered.next()
          dfsRec(neighboursFunc)(graph, neighboursFunc(graph, nextNeighbour, entered, exited) :: toVisit,
            entered + nextNeighbour, exited, exitStack)
        } else
          // All neighbours done: exit the vertex and push it onto the exit stack.
          dfsRec(neighboursFunc)(graph, tl, entered, exited + currentVertex, currentVertex :: exitStack)
    }
  }

  @tailrec
  private def graphDfsRec[Vertex](neighboursFunc: NeighboursFunc)(graph: Graph[Vertex], notVisited: Set[Vertex],
                                                                  entered: Set[Vertex], exited: Set[Vertex], order: List[Vertex]): List[Vertex] = {
    if (notVisited.isEmpty) order
    else {
      val orderSuffix = dfsRec(neighboursFunc)(graph, List(neighboursFunc(graph, notVisited.head, entered, exited)),
        entered + notVisited.head, exited, List())
      graphDfsRec(neighboursFunc)(graph, notVisited -- orderSuffix, entered ++ orderSuffix, exited ++ orderSuffix, orderSuffix ::: order)
    }
  }
}

object DirectedGraphTraversalsExamples extends App {
  import DirectedGraphTraversals._

  val graph = Map(
    "B" -> Set("D", "C"),
    "A" -> Set("B", "D"),
    "D" -> Set("E"),
    "E" -> Set("C"))

  println("dfs A " +  dfs(graph, "A"))
  println("dfs B " +  dfs(graph, "B"))

  println("topologicalSort " +  topologicalSort(graph))

  println("reverse " + reverse(graph))
  println("stronglyConnectedComponents graph " + stronglyConnectedComponents(graph))

  val graph2 = graph + ("C" -> Set("D"))
  println("stronglyConnectedComponents graph2 " + stronglyConnectedComponents(graph2))
  println("topologicalSort graph2 " + Try(topologicalSort(graph2)))
}

Marimuthu Madasamy's answer does indeed work.

Here is the generic version of it:

val graph = Map(0 -> Set(1),
  1 -> Set(2),
  2 -> Set(0, 3, 4),
  3 -> Set[Int](),
  4 -> Set(3))

def traverse[T](graph: Map[T, Set[T]], start: T): List[T] = {
  def childrenNotVisited(parent: T, visited: List[T]) =
    graph(parent) filter (x => !visited.contains(x))

  def loop(stack: Set[T], visited: List[T]): List[T] = {
    if (stack.isEmpty) visited
    else loop(childrenNotVisited(stack.head, visited) ++ stack.tail,
      stack.head :: visited)
  }
  loop(Set(start), Nil).reverse
}


Note: You have to make sure that instances of T correctly implement equals and hashCode. Using case classes with primitive fields is an easy way to get there.
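A quick illustration of why this matters: a plain class keeps reference equality, so two nodes with the same id are different `Set` elements and different `Map` keys, whereas a case class gets structural `equals`/`hashCode` generated from its fields. `PlainNode` and `Node` are hypothetical names:

```scala
// Plain class: inherits reference equality from AnyRef.
class PlainNode(val id: Int)

// Case class: equals/hashCode are derived from `id`.
final case class Node(id: Int)

val plain = Set(new PlainNode(1), new PlainNode(1)) // two distinct elements
val cased = Set(Node(1), Node(1))                   // collapses to one element

// A freshly built key still finds the entry, which graph(parent) in traverse relies on.
val successors = Map(Node(1) -> Set(Node(2)))(Node(1))
```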

It seems this question is more involved than I originally thought. I wrote another recursive solution. It's still not tail recursive. I also tried hard to make it a one-liner, but in that case readability would suffer a lot, so I decided to declare several vals this time:

def traverse(graph: Map[Int, Set[Int]], node: Int, result: List[Int] = Nil): List[Int] = {
  val newResult = result :+ node
  val currentEdges = graph(node) -- newResult
  // No unvisited successors: fall back to any unvisited node in the graph.
  val realEdges = if (currentEdges.isEmpty) graph.keySet -- newResult else currentEdges

  realEdges.foldLeft(newResult)((r, n) => if (r contains n) r else traverse(graph, n, r))
}

In my previous answer I tried to find all paths from the given node in a directed graph, but that was wrong according to the requirements. This answer tries to follow directed edges, and if it can't, it just takes some unvisited node and continues from there.

I want to revise Marimuthu Madasamy's answer. The code uses a Set for the stack, which is an unordered data structure, and a List for visited, whose contains method takes linear time, so the overall time complexity is O(E * V), which is not efficient (E is the number of edges and V the number of vertices). I would rather use a List for the stack, a Set for visited (named discovered here), and an additional List for the result, i.e. the visited nodes in order.

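def dfs(stack: List[Int], discovered: Set[Int], orderedVisited: List[Int]): List[Int] = {
  // Successors of `start` that have not been discovered yet (uses the question's `graph`).
  def childrenNotVisited(start: Int) =
    graph(start).filterNot(discovered).toList

  if (stack.isEmpty)
    orderedVisited.reverse // nodes were prepended, so restore the visit order
  else {
    val nextNodes = childrenNotVisited(stack.head)
    dfs(nextNodes ::: stack.tail, discovered ++ nextNodes, stack.head :: orderedVisited)
  }
}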

val start = 0
val visitOrder = dfs(List(start), Set(start), Nil)
