Tail-recursive Knuth–Morris–Pratt algorithm

762 Views Asked by At

I've created a simple implementation of the Knuth–Morris–Pratt algorithm in Scala. Now I want to get fancy and do the same thing in a tail-recursive manner. My gut feeling says it shouldn't be too difficult (both the table and the search parts), yet that same feeling also tells me that this must have been done already by someone, probably smarter than myself. Hence the question: do you know of any tail-recursive implementation of the Knuth–Morris–Pratt algorithm?

object KnuthMorrisPrattAlgorithm {

  /** Finds the first occurrence of the pattern `w` in the text `s` using the
    * Knuth–Morris–Pratt algorithm.
    *
    * @param s the text to search in
    * @param w the pattern to look for
    * @return the index in `s` where the first occurrence of `w` starts,
    *         0 if `w` is empty, or -1 if `w` does not occur in `s`
    */
  def search(s: String, w: String): Int =
    if (w.isEmpty) 0
    else {
      val t = table(w)
      var m = 0      // index in `s` where the current candidate match starts
      var i = 0      // index in `w` of the next character to compare
      var found = -1 // start of the first full match; -1 while none found

      while (found < 0 && m + i < s.length) {
        if (w(i) == s(m + i)) {
          if (i == w.length - 1) found = m
          else i += 1
        } else if (t(i) > -1) {
          // The shift must be computed from the current `i` BEFORE `i` is
          // overwritten; updating `i` first breaks the invariant
          // s(m until m + i) == w(0 until i) and can report false matches.
          m += i - t(i)
          i = t(i)
        } else {
          m += 1
          i = 0
        }
      }

      found
    }

  /** Builds the KMP partial-match ("failure") table for `w`.
    *
    * Entry `k` (for `k >= 1`) is the length of the longest proper prefix of
    * `w.take(k)` that is also a suffix of it; entry 0 is the sentinel -1.
    * The table always has at least two entries, so for patterns shorter than
    * two characters it is `Seq(-1, 0)`.
    *
    * @param w the pattern
    * @return the failure table, one entry per pattern position (minimum 2)
    */
  def table(w: String): Seq[Int] = {
    val t = Array.fill(math.max(w.length, 2))(0)
    t(0) = -1
    var pos = 2 // table entry currently being computed
    var cnd = 0 // length of the current candidate prefix/suffix

    while (pos < w.length) {
      if (w(pos - 1) == w(cnd)) {
        cnd += 1
        t(pos) = cnd
        pos += 1
      } else if (cnd > 0) {
        // Fall back to the next shorter candidate border.
        cnd = t(cnd)
      } else {
        t(pos) = 0
        pos += 1
      }
    }

    // Explicit conversion instead of the implicit Array -> Seq view,
    // which is deprecated in Scala 2.13.
    t.toSeq
  }
}
1

There is 1 best solution below

7
Best answer

I don't know what that algorithm does, but here are your functions rewritten in tail-recursive form:

object KnuthMorrisPrattAlgorithm {

  /** Finds the first occurrence of the pattern `w` in the text `s` using the
    * Knuth–Morris–Pratt algorithm, with the scan written as a tail-recursive
    * local function instead of a `while` loop.
    *
    * @param s the text to search in
    * @param w the pattern to look for
    * @return the index in `s` where the first occurrence of `w` starts,
    *         0 if `w` is empty, or -1 if `w` does not occur in `s`
    */
  def search(s: String, w: String): Int =
    if (w.isEmpty) 0
    else {
      val t = table(w)

      // m: index in `s` where the current candidate match starts;
      // i: index in `w` of the next character to compare.
      // Invariant: s.substring(m, m + i) == w.substring(0, i).
      @scala.annotation.tailrec
      def loop(m: Int, i: Int): Int =
        if (m + i >= s.length) -1
        else if (w(i) == s(m + i)) {
          if (i == w.length - 1) m else loop(m, i + 1)
        } else if (t(i) > -1) {
          // Both arguments are evaluated with the current `i`, so the shift
          // and the new pattern index stay consistent with each other.
          loop(m + i - t(i), t(i))
        } else loop(m + 1, 0)

      loop(0, 0)
    }

  /** Builds the KMP partial-match ("failure") table for `w` without mutation.
    *
    * Entry `k` (for `k >= 1`) is the length of the longest proper prefix of
    * `w.take(k)` that is also a suffix of it; entry 0 is the sentinel -1.
    * The table always has at least two entries, so for patterns shorter than
    * two characters it is `Seq(-1, 0)`.
    *
    * @param w the pattern
    * @return the failure table, one entry per pattern position (minimum 2)
    */
  def table(w: String): Seq[Int] = {
    // pos: table entry being computed (always equals acc.length);
    // cnd: length of the current candidate prefix/suffix (always < pos,
    //      so acc(cnd) is already filled in).
    @scala.annotation.tailrec
    def loop(pos: Int, cnd: Int, acc: Vector[Int]): Vector[Int] =
      if (pos >= w.length) acc
      else if (w(pos - 1) == w(cnd)) loop(pos + 1, cnd + 1, acc :+ (cnd + 1))
      else if (cnd > 0) loop(pos, acc(cnd), acc) // fall back to a shorter border
      else loop(pos + 1, 0, acc :+ 0)

    loop(2, 0, Vector(-1, 0))
  }
}