I'm trying the "defunctionalize the continuation" technique on some recursive functions, to see if I can get a good iterative version to pop out. Following along with The Best Refactoring You've Never Heard Of (in Lua just for convenience; this seems to be mostly language-agnostic), I did this:
-- original
function printTree(tree)
  if tree then
    printTree(tree.left)
    print(tree.content)
    printTree(tree.right)
  end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)
-- make tail-recursive with CPS
function printTree(tree, kont)
  if tree then
    printTree(tree.left, function()
      print(tree.content)
      printTree(tree.right, kont)
    end)
  else
    kont()
  end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree, function() end)
-- defunctionalize the continuation
function apply(kont)
  if kont then
    print(kont.tree.content)
    printTree(kont.tree.right, kont.next)
  end
end
function printTree(tree, kont)
  if tree then
    printTree(tree.left, { tree = tree, next = kont })
  else
    apply(kont)
  end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)
-- inline apply
function printTree(tree, kont)
  if tree then
    printTree(tree.left, { tree = tree, next = kont })
  elseif kont then
    print(kont.tree.content)
    printTree(kont.tree.right, kont.next)
  end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)
-- perform tail-call elimination
function printTree(tree, kont)
  while true do
    if tree then
      kont = { tree = tree, next = kont }
      tree = tree.left
    elseif kont then
      print(kont.tree.content)
      tree = kont.tree.right
      kont = kont.next
    else
      return
    end
  end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)
Then I tried the same technique on the factorial function:
-- original
function factorial(n)
  if n == 0 then
    return 1
  else
    return n * factorial(n - 1)
  end
end
print(factorial(6))
-- make tail-recursive with CPS
function factorial(n, kont)
  if n == 0 then
    return kont(1)
  else
    return factorial(n - 1, function(x)
      return kont(n * x)
    end)
  end
end
print(factorial(6, function(x) return x end))
-- defunctionalize the continuation
function apply(kont, x)
  if kont then
    return apply(kont.next, kont.n * x)
  else
    return x
  end
end
function factorial(n, kont)
  if n == 0 then
    return apply(kont, 1)
  else
    return factorial(n - 1, { n = n, next = kont })
  end
end
print(factorial(6))
Here's where things start to go wrong. The next step is to inline apply, but I can't do that since apply calls itself recursively. To keep going, I tried doing tail-call elimination on it.
-- perform tail-call elimination
function apply(kont, x)
  while kont do
    x = kont.n * x
    kont = kont.next
  end
  return x
end
function factorial(n, kont)
  if n == 0 then
    return apply(kont, 1)
  else
    return factorial(n - 1, { n = n, next = kont })
  end
end
print(factorial(6))
Okay, now we seem to be back on track.
-- inline apply
function factorial(n, kont)
  if n == 0 then
    local x = 1
    while kont do
      x = kont.n * x
      kont = kont.next
    end
    return x
  else
    return factorial(n - 1, { n = n, next = kont })
  end
end
print(factorial(6))
-- perform tail-call elimination
function factorial(n, kont)
  while n ~= 0 do
    kont = { n = n, next = kont }
    n = n - 1
  end
  local x = 1
  while kont do
    x = kont.n * x
    kont = kont.next
  end
  return x
end
print(factorial(6))
Okay, we got a fully iterative implementation of factorial, but it's a pretty bad one. I was hoping to end up with something like this instead:
function factorial(n)
  local x = 1
  while n ~= 0 do
    x = n * x
    n = n - 1
  end
  return x
end
print(factorial(6))
Is there any modification to the steps I followed that will let me mechanically end up with a function that looks more like this one?
First, congrats on having followed all the steps correctly. I've watched many people attempt defunctionalization exercises, and I've seen this rarely.
The 3-step process of CPS, defunctionalization, and tail-call elimination is not a standalone recipe that can produce clean code in all contexts without modification. You saw this when you needed to TCE the apply function itself. (Which, by the way, is also necessary if you want to use parentheses when printing trees, making that version much more complicated than the example I used in the blog post.)
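To make the parenthetical about parentheses concrete, here is a minimal sketch of a defunctionalized tree printer that wraps every non-empty subtree in parentheses. It is not the version from the blog post: the names printTreeParens, applyParens, emit, and out, and the tag field on continuation records, are all hypothetical choices for illustration. Output is collected in a table instead of printed so it is easy to inspect. The point is only that the "after the right subtree" continuation has to emit the closing parenthesis and then resume the next continuation, so the apply function ends up calling itself.

```lua
-- Hypothetical sketch: defunctionalized in-order printer with parentheses.
out = {}
function emit(s)
  out[#out + 1] = s
end

function printTreeParens(tree, kont)
  if tree then
    emit("(")
    -- Continuation kind 1: what to do after the left subtree is done.
    return printTreeParens(tree.left, { tag = "afterLeft", tree = tree, next = kont })
  else
    return applyParens(kont)
  end
end

function applyParens(kont)
  if not kont then
    return
  elseif kont.tag == "afterLeft" then
    emit(tostring(kont.tree.content))
    -- Continuation kind 2: what to do after the right subtree is done.
    return printTreeParens(kont.tree.right, { tag = "afterRight", next = kont.next })
  else
    -- kont.tag == "afterRight": close the parenthesis, then resume the rest.
    emit(")")
    return applyParens(kont.next) -- apply now calls itself
  end
end

printTreeParens({ left = { content = 1 }, content = 2, right = { content = 3 } })
print(table.concat(out))
```

Because applyParens is self-recursive, it cannot simply be inlined into printTreeParens; as with factorial, it would first need its own tail-call elimination.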
In this case, there is an extra step needed to get the clean factorial function you desire. The function you wrote computes 5! as 5*(4*(3*(2*1))), and so the final continuation is function(x) return 5*(4*(3*(2*x))) end. This involves 4 multiplications and, more generally, a continuation of this form involves an unbounded number of multiplications. This is not possible to handle without a second loop, UNLESS you note that that continuation is equivalent to function(x) return 120*x end.
The key property is the associativity of multiplication. Indeed, compare: let's say we defined a factorial analogue with a non-associative operator like exponentiation instead of multiplication. Call it powerorial, such that powerorial(5) = 5^(4^(3^(2^1))). Then it would not be possible to simplify away the loop.
Here, you need to show that apply(kont, x) is equivalent to acc*x, where acc is the partial factorial contained in kont. E.g., for the defunctionalized continuation kont = {n=3, next={n=4, next={n=5, next=nil}}}, show that apply(kont, x) is equivalent to 60*x. This is a simple inductive proof using associativity.
The insight that associativity in its various forms is a key step to getting clean iterative functions with accumulators applies to many problems, enough that Jeremy Gibbons wrote a paper on this topic called Continuation-Passing Style, Defunctionalization, Accumulations, and Associativity, with factorial as his first example!
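As a rough sketch of how that extra step could play out in Lua: once apply(kont, x) is known to equal acc*x, the linked list kont can be replaced by the single number acc. The empty continuation corresponds to acc = 1, and pushing { n = n, next = kont } corresponds to acc*n, since apply on the new record gives acc*(n*x) = (acc*n)*x by associativity. The names applyKont, factorialAcc, and factorialLoop are hypothetical, chosen so the three versions can coexist in one file.

```lua
-- The defunctionalized apply from the question, renamed for this sketch.
function applyKont(kont, x)
  while kont do
    x = kont.n * x
    kont = kont.next
  end
  return x
end

-- Spot-check the claim on the example continuation: apply(kont, x) == 60 * x.
local kont = { n = 3, next = { n = 4, next = { n = 5, next = nil } } }
assert(applyKont(kont, 7) == 60 * 7)

-- Replace the continuation by its summary acc.
function factorialAcc(n, acc)
  acc = acc or 1                          -- nil continuation <-> acc = 1
  if n == 0 then
    return acc                            -- was: apply(kont, 1) == acc * 1
  else
    return factorialAcc(n - 1, acc * n)   -- was: { n = n, next = kont }
  end
end

-- Tail-call elimination on factorialAcc gives the single-loop version.
function factorialLoop(n)
  local acc = 1
  while n ~= 0 do
    acc = acc * n
    n = n - 1
  end
  return acc
end

print(factorialAcc(6), factorialLoop(6))
```

The spot check is not a proof, only a sanity test of the equivalence that the inductive argument is meant to establish.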
This is all part of the larger area of mechanical manipulation of programs, a topic practiced more deeply in our Advanced Software Design Course. If you want to go deeper than that, you can also start reading papers on program manipulation and equational reasoning, including the works of Jeremy Gibbons, Richard Bird, Zohar Manna, and Bill Scherlis.