Haskell: Code running too slow

I have some code that computes Motzkin numbers:

module Main where

-- Program execution begins here
main :: IO ()
main = interact (unlines . map show . map wave . map read . words)

-- Compute Motzkin number
wave :: Integer -> Integer
wave 0 = 1
wave 1 = 1
wave n = ((3 * n - 3) * wave (n - 2) + (2 * n + 1) * wave (n - 1)) `div` (n + 2)

But even for a small input such as 30, the output takes a while to appear.

Any optimization ideas?

Answer by Ingo

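With n = 30, you need to compute wave 29 and wave 28, which in turn need wave 28, wave 27 (twice) and wave 26, and so forth. The number of calls grows exponentially, much like the Fibonacci numbers: it is already in the millions for n = 30 and reaches the billions a little further on.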
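To put a number on this, the sketch below counts how many calls the naive wave n makes: one for the call itself plus all the calls made by its two recursive subcalls. The name callCounts is made up for this illustration, and it uses a lazy list so that the counting does not itself blow up. For n = 30 it gives 2,692,537 calls for a single wave 30.

```haskell
-- Hypothetical helper: callCounts !! n is the number of calls the naive
-- `wave n` performs (base cases wave 0 and wave 1 cost one call each).
callCounts :: [Integer]
callCounts = 1 : 1 : zipWith (\c2 c1 -> 1 + c2 + c1) callCounts (tail callCounts)

main :: IO ()
main = mapM_ report [10, 20, 30, 40]
  where
    report n = putStrLn (show n ++ ": " ++ show (callCounts !! n))
```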

You can use the same trick that is used to compute the Fibonacci numbers:

wave :: Integer -> Integer
wave 0 = 1
wave 1 = 1
wave n = helper 1 1 2
    where
       -- x = wave (k-2), y = wave (k-1), z = wave k
       helper x y k | k < n     = helper y z (k+1)
                    | otherwise = z
                    where z = ((3*k-3) * x + (2*k+1) * y) `div` (k+2)

This takes a linear number of steps: for every k, the helper already has the values of wave (k-2) and wave (k-1) at hand, so each value is computed only once.
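One possible refinement, offered as a sketch rather than part of the answer above: because the accumulators are passed lazily, the helper may build a chain of unevaluated thunks when compiled without optimization (as in GHCi). The variant below forces each new value with seq before recursing. The name waveStrict is made up for this example.

```haskell
-- Iterative Motzkin numbers with the new value forced at each step
-- (the `seq` is an addition to the helper shown above).
waveStrict :: Integer -> Integer
waveStrict n
  | n < 2     = 1
  | otherwise = go 1 1 2
  where
    -- x = wave (k-2), y = wave (k-1), z = wave k
    go x y k
      | k < n     = z `seq` go y z (k + 1)
      | otherwise = z
      where
        z = ((3 * k - 3) * x + (2 * k + 1) * y) `div` (k + 2)

main :: IO ()
main = print (map waveStrict [0 .. 10])
-- prints [1,1,2,4,9,21,51,127,323,835,2188]
```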

Answer by karakfa

Here is a memoized version:

wave :: Int -> Integer
wave = ((1:1:map waveCalc [2..]) !!)
    where waveCalc n = let m = toInteger n  -- Int index from (!!), Integer arithmetic
                       in ((2*m+1)*wave (n-1) + (3*m-3)*wave (n-2)) `div` (m+2)
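A possible alternative, swapped in here and not part of the answer above: the list version looks up each earlier value with (!!), which walks the list from the front, so computing wave n costs O(n²) list steps in total (ignoring the Integer arithmetic). If an upper bound is known in advance, a lazy boxed array from Data.Array (the array package, distributed with GHC) gives constant-time lookups. The name wavesUpTo is hypothetical.

```haskell
import Data.Array

-- Memoize Motzkin numbers 0..n in a lazy boxed array; each element is a
-- thunk that refers back to earlier elements of the same array.
wavesUpTo :: Int -> Array Int Integer
wavesUpTo n = table
  where
    table = listArray (0, n) (map f [0 .. n])
    f 0 = 1
    f 1 = 1
    f k = ((3 * m - 3) * table ! (k - 2) + (2 * m + 1) * table ! (k - 1)) `div` (m + 2)
      where
        m = toInteger k

main :: IO ()
main = print (wavesUpTo 30 ! 30)
-- prints 1697385471211
```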
Answer by Daniel Wagner

There is a standard trick for computing the Fibonacci numbers that can easily be adapted to your problem. The naive definition for Fibonacci numbers is:

fibFunction :: Int -> Integer
fibFunction 0 = 1
fibFunction 1 = 1
fibFunction n = fibFunction (n-2) + fibFunction (n-1)

However, this is very costly: since all the leaves of the recursion are 1, if fibFunction x = y, then the recursion has y leaves, so we must perform at least y recursive calls! Since the Fibonacci numbers grow exponentially, this is a bad state of affairs to be in. But with dynamic programming, we can share the computations needed in the two recursive calls. The pleasing one-liner for this is:

fibList :: [Integer]
fibList = 1 : 1 : zipWith (+) fibList (tail fibList)

This may look a bit puzzling at first. Here the fibList argument to zipWith plays the role of the recursive call two indices back, while the tail fibList argument plays the role of the call one index back, which gives us both the fib (n-2) and fib (n-1) values. The two 1s at the beginning are of course the base cases. There are other good questions here on Stack Overflow that explain this technique in further detail, and you should study this code and those answers until you feel you understand how it works and why it is very fast.
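To see the alignment concretely, this small program prints the two lists that zipWith walks through in lockstep, followed by the list of their sums, which is fibList from its third element onward.

```haskell
fibList :: [Integer]
fibList = 1 : 1 : zipWith (+) fibList (tail fibList)

main :: IO ()
main = do
  print (take 8 fibList)          -- fib (n-2): [1,1,2,3,5,8,13,21]
  print (take 8 (tail fibList))   -- fib (n-1): [1,2,3,5,8,13,21,34]
  print (take 8 (drop 2 fibList)) -- fib n:     [2,3,5,8,13,21,34,55]
```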

If necessary, one can recover a function of type Int -> Integer from this list using (!!).
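A minimal sketch of that, with fibFast as a made-up name: because fibList is a top-level value, its elements are computed once and shared between calls, so repeated lookups only pay for walking the list.

```haskell
fibList :: [Integer]
fibList = 1 : 1 : zipWith (+) fibList (tail fibList)

-- Hypothetical name: index into the shared list to get a function.
fibFast :: Int -> Integer
fibFast = (fibList !!)

main :: IO ()
main = print (map fibFast [0 .. 10])
-- prints [1,1,2,3,5,8,13,21,34,55,89]
```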

Let's try to apply this technique to your function. As with the Fibonacci numbers, you need the previous two values, and additionally the current index. That can be done by passing [2..] as an extra list argument, using zipWith3 instead of zipWith. Here's how it would look:

waves :: [Integer]
waves = 1 : 1 : zipWith3 thisWave [2..] waves (tail waves) where
    thisWave n back2 back1 = ((3 * n - 3) * back2 + (2 * n + 1) * back1) `div` (n + 2)

As before, one can recover the function version with (!!), or with genericIndex from Data.List if one really needs Integer indices. We can confirm in ghci that it computes the same function (but faster, and using less memory):

> :set +s
> map wave [0..30]
[1,1,2,4,9,21,51,127,323,835,2188,5798,15511,41835,113634,310572,853467,2356779,6536382,18199284,50852019,142547559,400763223,1129760415,3192727797,9043402501,25669818476,73007772802,208023278209,593742784829,1697385471211]
(6.00 secs, 3,334,097,776 bytes)
> take 31 waves
[1,1,2,4,9,21,51,127,323,835,2188,5798,15511,41835,113634,310572,853467,2356779,6536382,18199284,50852019,142547559,400763223,1129760415,3192727797,9043402501,25669818476,73007772802,208023278209,593742784829,1697385471211]
(0.00 secs, 300,696 bytes)
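For the Integer-indexed function version mentioned above, a sketch using genericIndex from Data.List might look like this; waveAt is a made-up name, and the waves definition is repeated so the block compiles on its own.

```haskell
import Data.List (genericIndex)

waves :: [Integer]
waves = 1 : 1 : zipWith3 thisWave [2..] waves (tail waves) where
    thisWave n back2 back1 = ((3 * n - 3) * back2 + (2 * n + 1) * back1) `div` (n + 2)

-- Hypothetical name: same signature as the original `wave`.
waveAt :: Integer -> Integer
waveAt = genericIndex waves

main :: IO ()
main = print (waveAt 30)
-- prints 1697385471211
```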
Answer by Zubin Kadva

Thanks everyone for your responses. Based on my understanding of memoization, I have rewritten the code as:

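-- Results are Integer: with Int the values overflow for larger n, and
-- 10^100 :: Int wraps around to 0, so `mod` would fail with "divide by zero".
mwave :: Int -> Integer
mwave = (map wave [0..] !!)
  where wave 0 = 1
        wave 1 = 1
        wave n = let m = toInteger n
                 in ((3 * m - 3) * mwave (n - 2) + (2 * m + 1) * mwave (n - 1)) `div` (m + 2)

digits :: Int -> Integer
digits n = (mwave n) `mod` 10^(100::Int)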

Any thoughts on how to output the answer modulo 10^100?
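A sketch of one way to do this, assuming the input format of the original program (whitespace-separated indices on stdin, one result per line). The names motzkins and lastDigits are hypothetical. The reduction modulo 10^100 is done only at the end: the recurrence divides exactly by (n + 2), which in general is not invertible modulo 10^100 (it shares a factor with 10 whenever it is even or a multiple of 5), so the intermediate values cannot simply be kept reduced.

```haskell
-- Motzkin numbers via the lazy-list recurrence, kept as full Integers.
motzkins :: [Integer]
motzkins = 1 : 1 : zipWith3 step [2 ..] motzkins (tail motzkins)
  where
    step n back2 back1 = ((3 * n - 3) * back2 + (2 * n + 1) * back1) `div` (n + 2)

-- The n-th Motzkin number modulo 10^100 (its last 100 decimal digits).
lastDigits :: Int -> Integer
lastDigits n = (motzkins !! n) `mod` (10 ^ (100 :: Int))

-- Read whitespace-separated indices from stdin, print one result per line.
main :: IO ()
main = interact (unlines . map (show . lastDigits . read) . words)
```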