How to best build a persistent binary tree from a sorted stream


For a side project I wanted a simple way to build a persistent binary search tree from a sorted stream. After some cursory searching, I could only find techniques that rely on a sorted array, where any element can be accessed by index. I ended up writing something that works, but I figure this is well-trodden territory and a canonical approach is probably documented somewhere (and probably has a name).

The makeshift code I wrote is included for clarity (it is also short).

import scala.annotation.tailrec

object TreeFromStream {
  sealed trait ImmutableTree[T] {
    def height: Int
  }
  case class ImmutableTreeNode[T](
    value: T,
    left: ImmutableTree[T],
    right: ImmutableTree[T]
  ) extends ImmutableTree[T] {
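    // Height follows the left spine only; in this construction the left
    // subtree is never shorter than the right one.
    lazy val height = left.height + 1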
  }
  case class NilTree[T]() extends ImmutableTree[T] {
    def height = 0
  }

  @tailrec
  def treeFromStream[T](
    stream: Stream[T],
    tree: ImmutableTree[T] = NilTree[T](),
    ancestors: List[ImmutableTreeNode[T]] = Nil
  ): ImmutableTree[T] = {
    (stream, ancestors) match {
      case (Stream.Empty, _) =>
        ancestors.foldLeft(tree) { case(right, root) => root.copy(right=right) }
      case (_, ancestor :: nextAncestors) if ancestor.left.height == tree.height =>
        treeFromStream(stream, ancestor.copy(right=tree), nextAncestors)
      case (next #:: rest, _) => 
        treeFromStream(
          rest, NilTree(),
          ImmutableTreeNode(next, tree, NilTree()) :: ancestors
        )
    }
  }
}

Answer by Bob Dalgleish

To create a balanced tree, which I am guessing is what you want, you will need to visit each value at least once. First collect all the values into a buffer, then recursively convert the buffer into a tree:

  def tfs[T](stream: Stream[T]): ImmutableTree[T] = {
    val ss = scala.collection.mutable.ArrayBuffer.empty[T]
    def treeFromSubsequence(start: Int, end: Int): ImmutableTree[T] =
      if (end == start) NilTree()
      else if (end - start == 1) ImmutableTreeNode(ss(start), NilTree(), NilTree())
      else {
        // Midpoint of the half-open range [start, end); it must be offset by start.
        val mid = start + (end - start) / 2
        ImmutableTreeNode(ss(mid), treeFromSubsequence(start, mid), treeFromSubsequence(mid + 1, end))
      }
    stream.foreach { x => ss += x }
    treeFromSubsequence(0, ss.length)
  }

It will visit each value exactly twice: once to collect it into the buffer and once to place it in the value field of a tree node.
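
To check that a buffered build of this kind gives a balanced search tree, here is a minimal self-contained sketch. The names `Tree`, `Node`, `Leaf`, `fromSorted`, `inOrder` and `height` are hypothetical and exist only for this example; they are not the question's `ImmutableTree` types. The sketch also assumes the input arrives as a finite `Iterator` rather than a `Stream`. It buffers the input once, splits each range at its midpoint, and provides an in-order traversal to confirm that the sorted order is preserved.

```scala
import scala.collection.mutable.ArrayBuffer

object BalancedTreeSketch {
  // Hypothetical minimal tree types for this sketch only.
  sealed trait Tree[+T]
  case object Leaf extends Tree[Nothing]
  final case class Node[+T](value: T, left: Tree[T], right: Tree[T]) extends Tree[T]

  // First pass: copy the sorted input into a buffer.
  // Second pass: recursively split each half-open range [start, end) at its midpoint.
  def fromSorted[T](sorted: Iterator[T]): Tree[T] = {
    val buf = ArrayBuffer.empty[T]
    buf ++= sorted

    def build(start: Int, end: Int): Tree[T] =
      if (start >= end) Leaf
      else {
        val mid = start + (end - start) / 2 // offset by start, not just half the length
        Node(buf(mid), build(start, mid), build(mid + 1, end))
      }

    build(0, buf.length)
  }

  // In-order traversal; for a valid search tree this reproduces the sorted input.
  def inOrder[T](tree: Tree[T]): List[T] = tree match {
    case Leaf          => Nil
    case Node(v, l, r) => inOrder(l) ::: (v :: inOrder(r))
  }

  // Height counted in nodes along the longest root-to-leaf path.
  def height[T](tree: Tree[T]): Int = tree match {
    case Leaf          => 0
    case Node(_, l, r) => 1 + math.max(height(l), height(r))
  }
}
```

For the input 1 to 7, `fromSorted` puts 4 at the root and produces a tree of height 3, and `inOrder` returns the values in their original order. The recursion depth is logarithmic in the input size, so stack depth is not a concern here, but the buffer does hold the whole input in memory.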