How to deal with double letters in Wagner-Fischer spellchecking


I tested my Wagner-Fischer distance calculator by entering baana, which is one letter away from banana. I got banana with a distance of 1, which is correct, but also cats with a distance of 3, which makes no sense because the distance should be 4. Looking at the matrix, it seems that the three 'a's in baana have somehow thrown off the result. Where did I go wrong?

from pprint import pprint

def calculate_distance(word, destination):
    word = '_' + word
    destination = '_' + destination
    print(word, destination)
    wordlen, destinationlen = len(word), len(destination)
    distance_matrix = [[None] * destinationlen for _ in range(wordlen)]
    for i in range(len(distance_matrix)):
        distance_matrix[i][0] = i
    for i in range(len(distance_matrix[0])):
        distance_matrix[0][i] = i
    pprint(distance_matrix)
    for i, line in enumerate(distance_matrix):
        for j, d in enumerate(line):
            if d is None:
                distance_matrix[i][j] = min(distance_matrix[i-1][j-1], distance_matrix[i][j-1], distance_matrix[i-1][j])
                if word[i] != destination[j]:
                    distance_matrix[i][j] += 1

    return distance_matrix[-1][-1]
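To make the effect concrete, here is a minimal sketch (not the asker's code) that fills the same '_'-padded matrix with either rule. The names `edit_distance` and `mismatch_after_min` are hypothetical, chosen only for this illustration. With the question's rule, stepping down the column for the 'a' in cats from the first 'a' of baana to the second costs nothing, so that 'a' is effectively deleted for free and the result drops from 4 to 3.

```python
def edit_distance(word, destination, mismatch_after_min=False):
    """Wagner-Fischer edit distance on '_'-padded strings.

    mismatch_after_min=True reproduces the rule from the question: take the
    minimum of the three neighbours first, then add 1 only if the letters
    differ. mismatch_after_min=False uses the standard rule.
    """
    word, destination = '_' + word, '_' + destination
    rows, cols = len(word), len(destination)
    d = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        d[i][0] = i  # delete the first i letters of word
    for j in range(cols):
        d[0][j] = j  # insert the first j letters of destination
    for i in range(1, rows):
        for j in range(1, cols):
            mismatch = int(word[i] != destination[j])
            if mismatch_after_min:
                # Question's rule: insertion and deletion become free when the letters match.
                d[i][j] = min(d[i - 1][j - 1], d[i][j - 1], d[i - 1][j]) + mismatch
            else:
                # Standard rule: only the diagonal (match/substitution) step is conditional.
                d[i][j] = min(d[i - 1][j - 1] + mismatch,
                              d[i][j - 1] + 1,
                              d[i - 1][j] + 1)
    return d[-1][-1]

print(edit_distance('baana', 'cats', mismatch_after_min=True))  # 3
print(edit_distance('baana', 'cats'))                           # 4
```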


TheHungryCub answered:

Try replacing the nested loop with:

    for i in range(1, wordlen):
        for j in range(1, destinationlen):
            distance_matrix[i][j] = min(
                distance_matrix[i-1][j-1] + (word[i] != destination[j]),  # match/substitution
                distance_matrix[i][j-1] + 1,  # insertion
                distance_matrix[i-1][j] + 1,  # deletion
            )

    return distance_matrix[-1][-1]
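For reference, here is the standard Wagner-Fischer recurrence, writing D for distance_matrix, w_i for word[i] and t_j for destination[j] (both padded with '_'), and [P] for 1 if P holds and 0 otherwise. Insertions and deletions are charged unconditionally; only the diagonal step depends on whether the letters match:

```latex
\begin{aligned}
D_{i,0} &= i, \qquad D_{0,j} = j, \\
D_{i,j} &= \min\bigl(D_{i-1,j-1} + [\,w_i \neq t_j\,],\; D_{i,j-1} + 1,\; D_{i-1,j} + 1\bigr) \quad (i, j \ge 1).
\end{aligned}
```

The question's code instead computes min(D_{i-1,j-1}, D_{i,j-1}, D_{i-1,j}) + [w_i ≠ t_j], which lets the horizontal and vertical steps cost 0 whenever the letters match.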
Sash Sinha answered:

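The issue lies in how you fill distance_matrix. Insertion and deletion always cost 1; only substitution depends on the characters, costing 0 when they match and 1 when they differ. Your code takes the minimum of the three neighbouring cells first and only then adds 1 if the characters differ, so whenever the letters match, a horizontal (insertion) or vertical (deletion) step is free. That is why the repeated 'a's in baana throw off the result: in the column for the 'a' in cats, the second 'a' of baana is deleted at no cost. Add each cost to its own candidate before taking the minimum: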

def calculate_distance(word, destination):
  word = '_' + word
  destination = '_' + destination
  print(word, destination)
  wordlen, destinationlen = len(word), len(destination)
  distance_matrix = [[0] * destinationlen for _ in range(wordlen)]
  for i in range(wordlen):
    distance_matrix[i][0] = i
  for j in range(destinationlen):
    distance_matrix[0][j] = j
  for i in range(1, wordlen):
    for j in range(1, destinationlen):
      distance_matrix[i][j] = min(
          distance_matrix[i - 1][j - 1] + (0 if word[i] == destination[j] else 1),  # substitution
          distance_matrix[i][j - 1] + 1,  # insertion
          distance_matrix[i - 1][j] + 1)  # deletion
  return distance_matrix[-1][-1]

print(calculate_distance('baana', 'banana'))
print(calculate_distance('baana', 'cats'))

Output:

_baana _banana
1
_baana _cats
4
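Since the goal is spellchecking, here is a hedged sketch of how the corrected distance might be used to rank suggestions. The helper names `levenshtein` and `suggest`, the `max_distance` cutoff of 2, and the word list are all hypothetical. As a design choice, this version keeps only two rows of the matrix instead of the full table, because each row depends only on the one above it; the recurrence is the same as in the answer above.

```python
def levenshtein(a, b):
    """Standard Wagner-Fischer edit distance, keeping only two rows."""
    prev = list(range(len(b) + 1))  # row 0: distance from '' to each prefix of b
    for i, ca in enumerate(a, start=1):
        cur = [i] + [0] * len(b)    # column 0: delete the first i letters of a
        for j, cb in enumerate(b, start=1):
            cur[j] = min(prev[j - 1] + (ca != cb),  # match / substitution
                         cur[j - 1] + 1,            # insertion
                         prev[j] + 1)               # deletion
        prev = cur
    return prev[-1]

def suggest(typo, dictionary, max_distance=2):
    """Return dictionary words within max_distance of typo, closest first."""
    scored = ((levenshtein(typo, w), w) for w in dictionary)
    return [w for d, w in sorted(scored) if d <= max_distance]

words = ['banana', 'bandana', 'cats', 'bana']
print(suggest('baana', words))  # ['bana', 'banana', 'bandana']
```

With the corrected distance, cats scores 4 and falls outside the cutoff, while banana and bana are both one edit away.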