In Haskell, is there a standard function that performs a "scan" on a tree?


I have a tree:

import Data.Tree

a :: Tree Double
a = 
    Node 1 
       [Node 20 
           [Node 300 
               [Node 400 [], 
               Node 500 []], 
           Node 310 []], 
       Node 30 [], 
       Node 40 []]

I want to apply a scan operation to it, similar to the one for lists, except that instead of returning a list it should return a tree of the values accumulated along each path. For example:

scan (+) 0 a

Should reduce to:

Node 1 
    [Node 21 
        [Node 321 
            [Node 721 [], 
            Node 821 []], 
        Node 331 []], 
    Node 31 [], 
    Node 41 []]

That is, each node holds the sum accumulated along the path from the root down to it. Is there a standard function for this?
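The requested operation generalizes `scanl` from lists: along any root-to-leaf path, the values in the result are `scanl (+) 0` of the labels on that path, minus the leading seed. The sketch below checks this against the expected tree; it assumes the `Tree` type from `Data.Tree` (package `containers`), and `paths` is a hypothetical helper, not part of that module. Load it into GHCi to try it.

```haskell
import Data.Tree (Tree (..))

-- The example tree from the question.
a :: Tree Double
a = Node 1
      [ Node 20 [Node 300 [Node 400 [], Node 500 []], Node 310 []]
      , Node 30 []
      , Node 40 []
      ]

-- The result the question expects from `scan (+) 0 a`.
expected :: Tree Double
expected = Node 1
      [ Node 21 [Node 321 [Node 721 [], Node 821 []], Node 331 []]
      , Node 31 []
      , Node 41 []
      ]

-- Hypothetical helper (not in Data.Tree): every root-to-leaf path,
-- listed as the labels from the root down to the leaf.
paths :: Tree t -> [[t]]
paths (Node x []) = [[x]]
paths (Node x ts) = map (x :) (concatMap paths ts)

-- scanl along each path, dropping the seed, reproduces the expected paths.
pathScansMatch :: Bool
pathScansMatch = map (drop 1 . scanl (+) 0) (paths a) == paths expected
```

`pathScansMatch` evaluates to `True`; for instance, the path `[1,20,300,400]` scans to `[1,21,321,721]`.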


Answer (score 3):

If you want to pass an accumulator, then the definition is:

scan :: (b -> a -> b) -> b -> Tree a -> Tree b
scan f a (Node x ns) = Node a' $ map (scan f a') ns where a' = f a x

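This version is also considerably more efficient, because each accumulated value is computed once and then passed down to all of its descendants; compare this and this.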
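A self-contained sketch of this accumulator version, assuming `Tree` from `Data.Tree`. Since the accumulator type `b` may differ from the label type, the same function can also do things other than summing; `pathsTo` is a hypothetical example that collects each node's path, deepest label first.

```haskell
import Data.Tree (Tree (..))

-- The accumulator version from this answer, with a type signature added.
-- Each node stores f applied to the value inherited from its parent and
-- its own label; that single result is then handed to all of its children.
scan :: (b -> t -> b) -> b -> Tree t -> Tree b
scan f acc (Node x ns) = Node acc' (map (scan f acc') ns)
  where
    acc' = f acc x

-- The example tree from the question.
a :: Tree Double
a = Node 1
      [ Node 20 [Node 300 [Node 400 [], Node 500 []], Node 310 []]
      , Node 30 []
      , Node 40 []
      ]

-- The result the question expects.
expected :: Tree Double
expected = Node 1
      [ Node 21 [Node 321 [Node 721 [], Node 821 []], Node 331 []]
      , Node 31 []
      , Node 41 []
      ]

-- Hypothetical use with a different accumulator type: each node is
-- labeled with its path from the root, most recent label first.
pathsTo :: Tree Double -> Tree [Double]
pathsTo = scan (flip (:)) []
```

In GHCi, `scan (+) 0 a == expected` is `True`, and in `pathsTo a` the 400 leaf is labeled with the path 400, 300, 20, 1.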

Answer (score 10):

There is no standard library function that does this. For lists, Data.List already has just about any function you can think of, but Data.Tree is fairly sparse.

Fortunately the function you want is quite simple.

scan :: (a -> a -> a) -> Tree a -> Tree a
scan f ~(Node r l) = Node r $ map (fmap (f r) . scan f) l
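A runnable sketch of this first version, assuming `Tree` from `Data.Tree`. Because there is no seed argument, the question's `scan (+) 0 a` becomes `scan (+) a` here, and the labels and the results must have the same type.

```haskell
import Data.Tree (Tree (..))

-- First version from this answer, with a type signature added. There is no
-- separate seed: the root keeps its label, and each child subtree is scanned
-- recursively and then has (f r) mapped over every label it contains.
scan :: (t -> t -> t) -> Tree t -> Tree t
scan f ~(Node r l) = Node r $ map (fmap (f r) . scan f) l

-- The example tree from the question.
a :: Tree Double
a = Node 1
      [ Node 20 [Node 300 [Node 400 [], Node 500 []], Node 310 []]
      , Node 30 []
      , Node 40 []
      ]

-- The result the question expects.
expected :: Tree Double
expected = Node 1
      [ Node 21 [Node 321 [Node 721 [], Node 821 []], Node 331 []]
      , Node 31 []
      , Node 41 []
      ]
```

In GHCi, `scan (+) a == expected` evaluates to `True`.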

--edit--

The above function has a problem: in the OP's example it calculates the "721" as "1 + (20 + (300 + 400))", so the "1 + 20" calculation cannot be reused in the other branches.
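One way to see this grouping is to pass a combining function that builds a string describing each application instead of adding. In this sketch, `plus` and `labelAt` are hypothetical helpers made up for the illustration, and `Tree` comes from `Data.Tree`.

```haskell
import Data.Tree (Tree (..))

-- First version from this answer.
scan :: (t -> t -> t) -> Tree t -> Tree t
scan f ~(Node r l) = Node r $ map (fmap (f r) . scan f) l

-- Hypothetical combining function that records the shape of each
-- application as a string instead of adding numbers.
plus :: String -> String -> String
plus x y = "(" ++ x ++ "+" ++ y ++ ")"

-- The question's tree, with each label rendered as a string.
symbolic :: Tree String
symbolic = fmap show tree
  where
    tree :: Tree Int
    tree = Node 1
      [ Node 20 [Node 300 [Node 400 [], Node 500 []], Node 310 []]
      , Node 30 []
      , Node 40 []
      ]

-- Hypothetical helper: follow child indices down from the root.
labelAt :: [Int] -> Tree t -> t
labelAt [] t = rootLabel t
labelAt (i : is) t = labelAt is (subForest t !! i)
```

Here `labelAt [0,0,0] (scan plus symbolic)` is `"(1+(20+(300+400)))"`: the `1 +` and `20 +` steps are applied separately inside every subtree instead of being computed once and shared.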

The function below doesn't have this problem, but I'm leaving the original one in place because it can still be useful, depending on which function is passed as the first argument.

scan :: (a -> a -> a) -> Tree a -> Tree a
scan f ~(Node r l) = Node r $ map (scan' r) l where
  scan' a ~(Node n b) = let a' = f a n in Node a' $ map (scan' a') b
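A runnable sketch of this second version, assuming `Tree` from `Data.Tree`, with the inner recursive call passing the updated value down to the grandchildren. It checks the numeric result and, using the same hypothetical string-building `plus` and `labelAt` helpers as an illustration, shows that the sums now associate to the left and are shared.

```haskell
import Data.Tree (Tree (..))

-- Second version from this answer, with a type signature added; the inner
-- recursive call passes the updated value acc' on to the grandchildren.
scan :: (t -> t -> t) -> Tree t -> Tree t
scan f ~(Node r l) = Node r $ map (scan' r) l
  where
    scan' acc ~(Node n b) = let acc' = f acc n in Node acc' $ map (scan' acc') b

-- The example tree from the question, and the expected result.
tree, expected :: Tree Int
tree = Node 1
  [ Node 20 [Node 300 [Node 400 [], Node 500 []], Node 310 []]
  , Node 30 []
  , Node 40 []
  ]
expected = Node 1
  [ Node 21 [Node 321 [Node 721 [], Node 821 []], Node 331 []]
  , Node 31 []
  , Node 41 []
  ]

-- Hypothetical combining function that records each application.
plus :: String -> String -> String
plus x y = "(" ++ x ++ "+" ++ y ++ ")"

-- Hypothetical helper: follow child indices down from the root.
labelAt :: [Int] -> Tree t -> t
labelAt [] t = rootLabel t
labelAt (i : is) t = labelAt is (subForest t !! i)
```

`scan (+) tree == expected` is `True`, and `labelAt [0,0,0] (scan plus (fmap show tree))` is `"(((1+20)+300)+400)"`, so "(1+20)" is computed once and reused by every node beneath it.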
Answer (score 5):

Update: this produces different results from those requested, but it shows a useful general approach that can help where it applies, and it is helpful here as well, as a counterpoint.

This sort of traversal can be done generically using only base and GHC. The class you are looking for is Traversable, together with the mapAccumL/mapAccumR functions and fst or snd.
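As a small illustration of what `mapAccumL` and `mapAccumR` (from `Data.Traversable`) do, here they are on a plain list: each threads an accumulator through the structure, left to right or right to left, and returns the final accumulator paired with the rebuilt structure. The names `runningL` and `runningR` are made up for this sketch.

```haskell
import Data.Traversable (mapAccumL, mapAccumR)

-- Running sums from the left: the new accumulator is also stored
-- in place of each element.
runningL :: [Int] -> (Int, [Int])
runningL = mapAccumL (\acc e -> (acc + e, acc + e)) 0

-- The same step function, but the accumulator travels from the right.
runningR :: [Int] -> (Int, [Int])
runningR = mapAccumR (\acc e -> (acc + e, acc + e)) 0
```

`runningL [1,2,3]` is `(6,[1,3,6])` and `runningR [1,2,3]` is `(6,[6,5,3])`; `snd` keeps just the rebuilt structure.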

Let's avoid writing our own instances:

{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}

Now we can derive the necessary parts:

import Data.Traversable
import Data.Foldable

data Tree a = Node a [Tree a]
            deriving (Functor, Traversable, Foldable, Show)

Using it is then easy. If you don't want the final accumulator, just use snd.

main :: IO ()
main = print $ mapAccumL (\acc e -> (acc+e,acc+e)) 0 demo

demo :: Tree Int
demo =
   Node 1 [Node 20
            [Node 300 [Node 400 []
                      , Node 500 []]
            , Node 310 []
            ]
          , Node 30 [], Node 40 []
          ]
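To see why this differs from the requested result, the same idea can be run on the `Tree` from `Data.Tree` (package `containers`, bundled with GHC), which already has a Traversable instance, so nothing needs deriving. This sketch assumes that type in place of the hand-written one above; `runningTotals` is a made-up name.

```haskell
import Data.Traversable (mapAccumL)
import Data.Tree (Tree (..), flatten)

-- Data.Tree's Traversable instance visits labels in preorder: root first,
-- then the subtrees left to right, the same order the derived instance
-- above uses. The accumulator is a single running total across the tree.
runningTotals :: Tree Int -> (Int, Tree Int)
runningTotals = mapAccumL (\acc e -> (acc + e, acc + e)) 0

demo :: Tree Int
demo = Node 1
  [ Node 20 [Node 300 [Node 400 [], Node 500 []], Node 310 []]
  , Node 30 []
  , Node 40 []
  ]
```

Here `flatten (snd (runningTotals demo))` is `[1,21,321,721,1221,1531,1561,1601]` and the final accumulator is `1601`. Because the total carries over between sibling subtrees, the 500 leaf gets 1221 (it also includes 400) rather than the requested 821.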