Aho-Corasick implementation is taking too long

333 views

I'm trying to solve a HackerRank problem that others in the discussion say they solved using the Aho-Corasick (AC) algorithm. My implementation builds the trie and computes the suffix links fairly quickly, but the actual string matching takes a long time. Is there something I'm missing? Using bisection for insertion and searching sped things up, but I'm not sure what else to improve (maybe a case of not knowing what I don't know). I suspect it has something to do with how I've built the trie, but I'm not sure. Here's my implementation:

#!/bin/python3
import sys
import bisect


class node:
    '''
    One state of the gene trie.

    Outgoing edges are stored as two parallel lists kept sorted by character,
    so lookups can use binary search. Matched genes are stored the same way:
    two parallel lists kept sorted by gene index.
    '''

    def __init__(self):
        self.child_names = []  # sorted edge characters leaving this node
        self.child_idxs = []   # child_idxs[k] is the node index reached via child_names[k]
        self.gene_idxs = []    # sorted gene numbers that end at this node
        self.healths = []      # healths[k] is the health paired with gene_idxs[k]
        self.suffix = 0        # index of the suffix-link target (0 until assigned)

    @staticmethod
    def _insert_pair(keys, values, key, value):
        '''Insert key/value into parallel lists, keeping keys sorted (leftmost slot).'''
        pos = bisect.bisect_left(keys, key)
        keys.insert(pos, key)
        values.insert(pos, value)

    def add_child(self, child_name, child_idx):
        '''Record an edge labelled child_name leading to node index child_idx.'''
        self._insert_pair(self.child_names, self.child_idxs, child_name, child_idx)

    def add_health(self, gene_idx, health):
        '''Record that gene gene_idx (worth health) is matched at this node.'''
        self._insert_pair(self.gene_idxs, self.healths, gene_idx, health)

    def has_name(self, name):
        '''
        Look up the edge labelled name.

        Returns (True, child_index) when the edge exists, otherwise (False, 0).
        '''
        names = self.child_names
        pos = bisect.bisect_left(names, name)
        if pos < len(names) and names[pos] == name:
            return True, self.child_idxs[pos]
        return False, 0


# Flat table of every trie node; edges and suffix links refer to entries by
# index. Entry 0 is the root (the starting point for add_gene and match).
nodes = []
nodes.append(node())


def add_gene(gene, gene_idx, health):
    '''
    Insert gene into the trie, creating any missing nodes along its path,
    and attach (gene_idx, health) to the node where the gene ends.
    '''
    cur = 0
    for ch in gene:
        found, nxt = nodes[cur].has_name(ch)
        if found:
            cur = nxt
            continue
        # No edge for ch yet: append a fresh node and link it from cur.
        nodes.append(node())
        nxt = len(nodes) - 1
        nodes[cur].add_child(ch, nxt)
        cur = nxt
    nodes[cur].add_health(gene_idx, health)


def get_longest_suffix():
    '''
    Compute every node's suffix (failure) link and merge the matched genes.

    Nodes are visited breadth-first, so when a level is processed the suffix
    links and gene lists of all shallower nodes are already final.

    For a child reached from ``parent`` by character ``c``, the suffix link
    is found by walking the parent's suffix chain until a node with a ``c``
    edge is found. If even the root (index 0) has no such edge, the suffix
    stays 0. The gene/health pairs of the suffix target are then copied into
    the child, so reaching a node yields every gene ending there, including
    genes that are proper suffixes of the node's path.
    '''
    # Depth-1 nodes keep suffix 0 (the root). Copy the root's child list so it
    # is never aliased by the level bookkeeping below.
    level = list(nodes[0].child_idxs)
    while level:
        next_level = []
        for parent_idx in level:
            parent = nodes[parent_idx]
            for child_name, child_idx in zip(parent.child_names, parent.child_idxs):
                # Checking only parent.suffix is not enough: with genes "xab",
                # "ac", "b", the suffix of "xa" is "a", which has no 'b' edge,
                # yet "xab" must still link to "b". Keep following links.
                candidate = parent.suffix
                while True:
                    found, target = nodes[candidate].has_name(child_name)
                    if found or candidate == 0:
                        break
                    candidate = nodes[candidate].suffix
                if found:
                    child = nodes[child_idx]
                    child.suffix = target
                    donor = nodes[target]
                    for gene_idx, health in zip(donor.gene_idxs, donor.healths):
                        child.add_health(gene_idx, health)
            next_level.extend(parent.child_idxs)
        level = next_level


def find_next_node(current_node_idx, d_char):
    '''
    Return the node reached from current_node_idx after reading d_char.

    Suffix links are followed until a node with a d_char edge is found, and
    that edge's target is returned. If even the root (index 0) has no such
    edge, the character is skipped and the root is returned.

    Implemented as a loop rather than recursion: a mismatch deep in a long
    gene can require more suffix hops than Python's recursion limit allows.
    '''
    idx = current_node_idx
    while True:
        found, child_idx = nodes[idx].has_name(d_char)
        if found:
            return child_idx
        if idx == 0:
            return 0
        idx = nodes[idx].suffix


def match(d_string, first, last):
    '''
    Return the total health of strand d_string, counting only genes whose
    index lies in the inclusive range [first, last].

    The strand is scanned one character at a time through the automaton.
    Each node keeps its matched genes sorted by index, so the in-range slice
    is located with two binary searches.
    '''
    node_idx = 0
    total = 0
    for d_char in d_string:
        node_idx = find_next_node(node_idx, d_char)
        current = nodes[node_idx]
        gene_idxs = current.gene_idxs
        start = bisect.bisect_left(gene_idxs, first)
        stop = bisect.bisect_right(gene_idxs, last)
        if start < stop:
            # sum() over a slice runs in C; the per-index Python loop that
            # used to be here dominated the profile.
            total += sum(current.healths[start:stop])
    return total


if __name__ == '__main__':
    # Input layout:
    #   line 0: gene count
    #   line 1: genes
    #   line 2: healths
    #   line 3: strand count s
    #   next s lines: "first last strand"
    # Read from the file named on the command line or, when none is given,
    # from standard input.
    if len(sys.argv) > 1:
        with open(sys.argv[1]) as f:
            lines = f.readlines()
    else:
        lines = sys.stdin.readlines()

    genes = lines[1].split()
    health = list(map(int, lines[2].split()))

    for gene_idx, gene in enumerate(genes):
        add_gene(gene, gene_idx, health[gene_idx])

    get_longest_suffix()

    # Collect every strand's health, then take min/max. A fixed sentinel
    # such as 10**9 is wrong because strand totals can exceed it.
    s = int(lines[3])
    strand_healths = []
    for line in lines[4:4 + s]:
        first, last, d_string = line.split()[:3]
        strand_healths.append(match(d_string, int(first), int(last)))

    print(min(strand_healths, default=0), max(strand_healths, default=0))

Here is a simple input file I'm using for testing (this one runs quickly; the inputs that take a long time are around 2 MB):

6
a b c aa d b
1 2 3 4 5 6
3
1 5 caaab
0 4 xyz
2 4 bcdybc

And here is some of the profiling output. The match function is taking a long time:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    44850   12.989    0.000   16.044    0.000 dna_health.py:93(match)
  2418760    1.030    0.000    1.030    0.000 {built-in method _bisect.bisect_left}
  1600902    0.993    0.000    1.559    0.000 dna_health.py:24(has_name)
1300650/717600    0.705    0.000    1.997    0.000 dna_health.py:76(find_next_node)
   717600    0.492    0.000    0.492    0.000 {built-in method _bisect.bisect_right}
        1    0.170    0.170   16.813   16.813 dna_health.py:2(<module>)
  1601673    0.138    0.000    0.138    0.000 {built-in method builtins.len}
   100000    0.125    0.000    0.519    0.000 dna_health.py:35(add_gene)
   100000    0.063    0.000    0.127    0.000 dna_health.py:19(add_health)
1

There is 1 answer below.

Answer by fdermishin (score: 1)

You can store cumulative (prefix-sum) health values for each node. They can be precomputed beforehand:

import itertools

# Prefix sums of each node's healths: cumulativeHealth[k] is the sum of the
# first k entries of healths, so sum(healths[start:stop]) equals
# cumulativeHealth[stop] - cumulativeHealth[start].
# The loop variable is not called `node`, because that would rebind the
# trie-node class name and break later `node()` calls.
for trie_node in nodes:
    trie_node.cumulativeHealth = [0] + list(itertools.accumulate(trie_node.healths))

The total health of a node's in-range genes can then be computed as:

# Health of genes start..stop-1 at this node in O(1), using the prefix sums.
d_health += node.cumulativeHealth[stop] - node.cumulativeHealth[start]

With these modifications, and by avoiding computing the same value several times, the `match` function can be rewritten like this:

def match(d_string, first, last):
    '''
    Return the total health of strand d_string over genes first..last.

    Relies on each node carrying cumulativeHealth (prefix sums of healths),
    so every character costs two binary searches plus one subtraction.
    '''
    state = 0
    total = 0
    for ch in d_string:
        state = find_next_node(state, ch)
        current = nodes[state]
        lo = bisect.bisect_left(current.gene_idxs, first)
        hi = bisect.bisect_right(current.gene_idxs, last)
        prefix = current.cumulativeHealth
        total += prefix[hi] - prefix[lo]
    return total