Python: Cluster connected elements with an n-to-m relationship


This is not a homework task (please see my profile). I do not have a computer science background, and this question came up in an applied machine learning problem. I am pretty sure I am not the first person to run into it, so I am looking for an elegant solution. I would prefer a solution that uses a Python library over a raw implementation.

Assume we have, as input, a dictionary connecting letters to numbers:

connected = {
    'A': [1, 2, 3],
    'B': [3, 4],
    'C': [5, 6],
}

Each letter can be connected to multiple numbers, and each number can be connected to multiple letters. However, a letter can be connected to a given number only once.
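One way to make this n-to-m relationship explicit is to flatten the dictionary into a set of (letter, number) pairs, which are the edges of a bipartite graph. A set also enforces the "only once" rule. A minimal sketch, where the helper name `to_edges` is hypothetical and the input is assumed to be the dictionary above:

```python
def to_edges(connected):
    """Flatten {letter: [numbers]} into a set of (letter, number) edges (hypothetical helper)."""
    edges = set()
    for letter, numbers in connected.items():
        for number in numbers:
            if (letter, number) in edges:
                # The question assumes a letter lists each number at most once.
                raise ValueError(f"duplicate pair: {letter!r}, {number!r}")
            edges.add((letter, number))
    return edges


connected = {
    'A': [1, 2, 3],
    'B': [3, 4],
    'C': [5, 6],
}
edges = to_edges(connected)
print(sorted(edges))
# [('A', 1), ('A', 2), ('A', 3), ('B', 3), ('B', 4), ('C', 5), ('C', 6)]
```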

Looking at the dictionary, we see that the number 3 is connected to both 'A' and 'B', so 'A' and 'B' belong to the same cluster. None of the numbers of 'C' appear under any other letter, so 'C' forms a cluster on its own. The expected output is therefore:
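The observation that number 3 links 'A' and 'B' can be checked mechanically by inverting the dictionary into a number-to-letters index: any number listed under more than one letter merges those letters. A short standard-library sketch, assuming the example input:

```python
from collections import defaultdict

connected = {
    'A': [1, 2, 3],
    'B': [3, 4],
    'C': [5, 6],
}

# Invert the mapping: for each number, collect the letters that list it.
letters_by_number = defaultdict(list)
for letter, numbers in connected.items():
    for number in numbers:
        letters_by_number[number].append(letter)

# Numbers shared by several letters are what merge letters into one cluster.
shared = {n: ls for n, ls in letters_by_number.items() if len(ls) > 1}
print(shared)  # {3: ['A', 'B']}
```

Note that this only finds direct links. If 'B' also shared a number with a fourth letter, that letter would belong in the same cluster as 'A' through a chain of shared numbers, which is why the full problem needs a transitive grouping such as connected components.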

cluster = {
    '1': {
        'letters': ['A', 'B'],
        'numbers': [1, 2, 3, 4], 
    },
    '2': {
        'letters': ['C'],
        'numbers': [5, 6],
    }
}

I think this is related to graph algorithms and connected subgraphs (connected components), but I do not know where to start.
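The connected-components intuition is correct: treat letters and numbers as nodes of a bipartite graph, and each cluster is one connected component. If a dependency is acceptable, networkx's `connected_components` on an `nx.Graph` does this directly. As a library-free sketch, a breadth-first search over an adjacency map produces the requested output format; the function name `cluster_connected` and the ('L', ...)/('N', ...) node tags are assumptions for illustration:

```python
from collections import defaultdict, deque


def cluster_connected(connected):
    """Group letters and numbers into connected components via BFS (hypothetical helper)."""
    # Undirected bipartite adjacency map. Nodes are tagged ('L', letter) or
    # ('N', number) so a letter and a number with equal values never collide.
    adjacency = defaultdict(set)
    for letter, numbers in connected.items():
        for number in numbers:
            adjacency[('L', letter)].add(('N', number))
            adjacency[('N', number)].add(('L', letter))

    seen = set()
    clusters = {}
    for letter in connected:
        start = ('L', letter)
        if start in seen:
            continue  # already part of an earlier cluster
        # Breadth-first search collects every node reachable from `start`.
        seen.add(start)
        queue = deque([start])
        letters, numbers = [], []
        while queue:
            node = queue.popleft()
            kind, value = node
            (letters if kind == 'L' else numbers).append(value)
            for neighbour in adjacency[node]:
                if neighbour not in seen:
                    seen.add(neighbour)
                    queue.append(neighbour)
        clusters[str(len(clusters) + 1)] = {
            'letters': sorted(letters),
            'numbers': sorted(numbers),
        }
    return clusters


connected = {'A': [1, 2, 3], 'B': [3, 4], 'C': [5, 6]}
print(cluster_connected(connected))
# {'1': {'letters': ['A', 'B'], 'numbers': [1, 2, 3, 4]}, '2': {'letters': ['C'], 'numbers': [5, 6]}}
```

Sorting the letters and numbers makes the output independent of set iteration order; it assumes all letters are mutually comparable, and likewise all numbers.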

Best answer:

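Using a union-find (disjoint-set) structure, you can solve this efficiently: with path compression and union by size, the running time is nearly linear in the total number of letter-number pairs. The key idea is to union each letter with every number in its list. Once you have done this for all letters, the resulting sets are exactly the desired clusters, because two letters end up in the same set whenever they are linked by a chain of shared numbers.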

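class UnionFind:
    """Disjoint-set forest with path compression and union by size."""

    def __init__(self):
        self.id = {}    # element -> parent element (a root is its own parent)
        self.size = {}  # root -> number of elements in its set

    def find(self, a):
        # Walk up to the root, remembering the nodes on the way ...
        cur = a
        path = []
        while self.id[cur] != cur:
            path.append(cur)
            cur = self.id[cur]
        # ... then point each of them directly at the root (path compression).
        for x in path:
            self.id[x] = cur
        return cur

    def union(self, a, b):
        # Register unseen elements as singleton sets.
        if a not in self.id:
            self.id[a] = a
            self.size[a] = 1
        if b not in self.id:
            self.id[b] = b
            self.size[b] = 1

        roota, rootb = self.find(a), self.find(b)
        if roota != rootb:
            # Union by size: attach the smaller tree under the larger one.
            if self.size[roota] > self.size[rootb]:
                roota, rootb = rootb, roota
            self.id[roota] = rootb
            self.size[rootb] += self.size[roota]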

if __name__ == "__main__":
    from collections import defaultdict

    uf = UnionFind()
    connected = {
        'A': [1, 2, 3],
        'B': [3, 4],
        'C': [5, 6],
    }
    for letter, numbers in connected.items():
        for number in numbers:
            uf.union(letter, number)
    
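    # Group elements by their root. Call find() instead of reading uf.id
    # directly: uf.id[key] may still point to an intermediate parent
    # rather than the root of its set.
    clusters = defaultdict(list)
    for key in list(uf.id):
        clusters[uf.find(key)].append(key)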
    
    formatted_clusters = {}
    for i, cluster_elements in enumerate(clusters.values()):
        letters = [e for e in cluster_elements if isinstance(e, str)]
        numbers = [e for e in cluster_elements if not isinstance(e, str)]
        key = str(i+1)
        formatted_clusters[key] = {
            "letters": letters,
            "numbers": numbers
        }
    print(formatted_clusters)

Output:

{'1': {'letters': ['A', 'B'], 'numbers': [1, 2, 3, 4]}, '2': {'letters': ['C'], 'numbers': [5, 6]}}
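The split into letters and numbers above relies on `isinstance(e, str)`, which only works while letters are strings and numbers are not. A variant that drops that assumption tags each node by its role and reads the numbers back from the input dictionary. This is a sketch, not the answer's code: it re-implements a compact union-find with path halving instead of full path compression, and the names `DisjointSet` and `cluster` are hypothetical:

```python
from collections import defaultdict


class DisjointSet:
    """Minimal union-find with path halving and union by size (hypothetical name)."""

    def __init__(self):
        self.parent = {}
        self.size = {}

    def add(self, x):
        if x not in self.parent:
            self.parent[x] = x
            self.size[x] = 1

    def find(self, x):
        # Path halving: point every other node on the path at its grandparent.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a, b):
        self.add(a)
        self.add(b)
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        if self.size[ra] > self.size[rb]:
            ra, rb = rb, ra
        # Attach the smaller tree under the larger one.
        self.parent[ra] = rb
        self.size[rb] += self.size[ra]


def cluster(connected):
    ds = DisjointSet()
    for letter, numbers in connected.items():
        ds.add(('L', letter))  # keeps letters whose number list is empty
        for number in numbers:
            # Tag nodes so that a letter and a number can never be confused.
            ds.union(('L', letter), ('N', number))

    # Clusters are numbered in the order their first letter appears in the input.
    groups = defaultdict(lambda: {'letters': [], 'numbers': set()})
    for letter, numbers in connected.items():
        group = groups[ds.find(('L', letter))]
        group['letters'].append(letter)
        group['numbers'].update(numbers)

    return {
        str(i): {'letters': g['letters'], 'numbers': sorted(g['numbers'])}
        for i, g in enumerate(groups.values(), start=1)
    }


connected = {'A': [1, 2, 3], 'B': [3, 4], 'C': [5, 6]}
print(cluster(connected))
# {'1': {'letters': ['A', 'B'], 'numbers': [1, 2, 3, 4]}, '2': {'letters': ['C'], 'numbers': [5, 6]}}
```

Because every group is looked up through `find()`, the result does not depend on how far path compression has progressed.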