Migrate a Traversable that uses a visitor to an Iterable in Scala 2.13


The migration guide to Scala 2.13 explains that Traversable has been removed and that Iterable should be used instead. This change is particularly annoying for one project, which uses a visitor to implement the foreach method in the Node class of a tree:

case class Node(val subnodes: Seq[Node]) extends Traversable[Node] {
  override def foreach[A](f: Node => A) = Visitor.visit(this, f)
}

object Visitor {
  def visit[A](n: Node, f: Node => A): Unit = {
    f(n)
    for (sub <- n.subnodes) {
      visit(sub, f)
    }
  }
}

object Main extends App {
  val a = Node(Seq())
  val b = Node(Seq())
  val c = Node(Seq(a, b))
  for (Node(subnodes) <- c) {
    Console.println("Visiting a node with " + subnodes.length + " subnodes")
  }
}

Output:

Visiting a node with 2 subnodes
Visiting a node with 0 subnodes
Visiting a node with 0 subnodes

An easy way to migrate to Scala 2.13 is to first store the visited elements in a buffer (remaining) and then return that buffer's iterator:

import scala.collection.mutable

case class Node(val subnodes: Seq[Node]) extends Iterable[Node] {
  override def iterator: Iterator[Node] = {
    // Buffer every visited node, then iterate over the buffer.
    val remaining = mutable.Queue.empty[Node]
    Visitor.visit(this, item => remaining.enqueue(item))
    remaining.iterator
  }
}

// Same Visitor object
// Same Main object

This solution has the disadvantage that it introduces new allocations, which put pressure on the GC because the number of visited elements is usually quite large.

Do you have suggestions on how to migrate from Traversable to Iterable, using the existing visitor but without introducing new allocations?


As you noticed, you need to extend Iterable instead of Traversable. You can do it like this:

case class Node(name: String, subnodes: Seq[Node]) extends Iterable[Node] {
  override def iterator: Iterator[Node] = Iterator(this) ++ subnodes.flatMap(_.iterator)
}

val a = Node("a", Seq())
val b = Node("b", Seq())
val c = Node("c", Seq(a, b))
val d = Node("d", Seq(c))

for (node@Node(name, _) <- d) {
  Console.println("Visiting node " + name + " with " + node.subnodes.length + " subnodes")
}

outputs:

Visiting node d with 1 subnodes
Visiting node c with 2 subnodes
Visiting node a with 0 subnodes
Visiting node b with 0 subnodes

Then you can do more operations such as:

d.count(_.subnodes.length > 1)

Code run at Scastie.
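
A possible refinement, given the question's concern about allocations: subnodes.flatMap(_.iterator) is a strict Seq operation, so it collects all descendants into a new Seq at every level before iteration continues. The sketch below is a hypothetical variant (not the code that was run above) that goes through subnodes.iterator instead, so child iterators are created only when the traversal reaches them. It still allocates one small iterator per node, so it reduces buffering rather than eliminating allocation.

```scala
// Hypothetical variant of the Node above; only the iterator body differs.
final case class Node(name: String, subnodes: Seq[Node]) extends Iterable[Node] {
  // Iterator.flatMap is lazy: a child's iterator is built only when the
  // traversal reaches that child, and no intermediate Seq[Node] is created.
  override def iterator: Iterator[Node] =
    Iterator.single(this) ++ subnodes.iterator.flatMap(_.iterator)
}

val a = Node("a", Seq())
val b = Node("b", Seq())
val c = Node("c", Seq(a, b))
val d = Node("d", Seq(c))

// Pre-order traversal: a node first, then its subnodes from left to right.
val names = d.iterator.map(_.name).toList
```

Here names should list the nodes in the same order as the output shown earlier (d, c, a, b).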


Here is an example showing that your code can be implemented with LazyList, with no visitor needed:

case class Node(val subnodes: Seq[Node]) {
  
  def recursiveMap[A](f: Node => A): LazyList[A] = {
    def expand(node: Node): LazyList[Node] = node #:: LazyList.from(node.subnodes).flatMap(expand)
    expand(this).map(f)
  }
}

val a = Node(Seq())
val b = Node(Seq())
val c = Node(Seq(a, b))

val lazyList = c.recursiveMap { node =>
  println("computing value")
  "Visiting a node with " + node.subnodes.length + " subnodes"
}

println("started computing values")

lazyList.iterator.foreach(println)

Output:

started computing values
computing value
Visiting a node with 2 subnodes
computing value
Visiting a node with 0 subnodes
computing value
Visiting a node with 0 subnodes

If you keep a reference only to the iterator and not to the lazyList itself, the JVM is able to garbage-collect the values as you go.
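
To illustrate that last point, here is a minimal sketch that never binds the LazyList to a val. The method name nodes is an assumption made for this example. The list is produced by a def and immediately turned into an iterator, so the only long-lived reference is the iterator itself.

```scala
final case class Node(subnodes: Seq[Node]) {
  // A def rather than a val: every call builds a fresh LazyList, and this
  // class keeps no reference to its head.
  def nodes: LazyList[Node] = this #:: LazyList.from(subnodes).flatMap(_.nodes)
}

val a = Node(Seq())
val b = Node(Seq())
val c = Node(Seq(a, b))

// Only the iterator is stored; cells it has already passed become unreachable.
val it: Iterator[Int] = c.nodes.iterator.map(_.subnodes.length)
val counts = it.toList
```

Whether the consumed cells are actually collected depends on the garbage collector and on no other code holding the list, so treat this as an illustration of the reference pattern rather than a measured guarantee.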


We ended up writing a minimal Traversable trait, implementing just the methods that are used in our codebase. This way there is no additional overhead and the visitor's logic doesn't need to be changed.

import scala.collection.mutable

/** A trait for traversable collections. */
trait Traversable[+A] {
  self =>

  /** Applies a function to all elements of the collection. */
  def foreach[B](f: A => B): Unit

  /** Creates a filter of this traversable collection. */
  def withFilter(p: A => Boolean): Traversable[A] = new WithFilter(p)

  class WithFilter(p: A => Boolean) extends Traversable[A] {
    /** Applies a function to all filtered elements of the outer collection. */
    def foreach[U](f: A => U): Unit = {
      for (x <- self) {
        if (p(x)) f(x)
      }
    }

    /** Further refines the filter of this collection. */
    override def withFilter(q: A => Boolean): WithFilter = {
      new WithFilter(x => p(x) && q(x))
    }
  }

  /** Finds the first element of this collection for which the given partial
    * function is defined, and applies the partial function to it.
    */
  def collectFirst[B](pf: PartialFunction[A, B]): Option[B] = {
    for (x <- self) {
      if (pf.isDefinedAt(x)) {
        return Some(pf(x))
      }
    }
    None
  }

  /** Builds a new collection by applying a partial function to all elements
    * of this collection on which the function is defined.
    */
  def collect[B](pf: PartialFunction[A, B]): Iterable[B] = {
    val elements = mutable.Queue.empty[B]
    for (x <- self) {
      if (pf.isDefinedAt(x)) {
        elements.append(pf(x))
      }
    }
    elements
  }
}
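
For completeness, here is a sketch of how the Node and Visitor from the question might plug into such a trait. The trait below is a trimmed, hypothetical copy (foreach and withFilter only, with the filter written as an anonymous class), not the exact code above. The for loops work because a pattern or a guard in a generator desugars to withFilter followed by foreach.

```scala
import scala.collection.mutable

// Trimmed, hypothetical copy of the trait above: foreach and withFilter only.
trait Traversable[+A] { self =>
  def foreach[B](f: A => B): Unit

  def withFilter(p: A => Boolean): Traversable[A] = new Traversable[A] {
    // Visits the outer collection and forwards only the matching elements.
    def foreach[U](f: A => U): Unit =
      for (x <- self) { if (p(x)) f(x) }
  }
}

// Same visitor as in the question: pre-order walk, no buffering.
object Visitor {
  def visit[A](n: Node, f: Node => A): Unit = {
    f(n)
    for (sub <- n.subnodes) visit(sub, f)
  }
}

final case class Node(subnodes: Seq[Node]) extends Traversable[Node] {
  def foreach[B](f: Node => B): Unit = Visitor.visit(this, f)
}

val a = Node(Seq())
val b = Node(Seq())
val c = Node(Seq(a, b))

// Pattern in the generator, as in the question's Main object.
val seen = mutable.ArrayBuffer.empty[Int]
for (Node(subnodes) <- c) seen += subnodes.length

// Guard in the generator: counts the nodes without subnodes.
var leaves = 0
for (n <- c if n.subnodes.isEmpty) leaves += 1
```

Note that a local trait named Traversable shadows the deprecated scala.Traversable alias that Scala 2.13 still provides, so this assumes the trait is in scope wherever Node is defined.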