Implementation of textbook Union-Find algorithm doesn't work

111 Views Asked by At

I have two implementations of union-find: one that I've come up on my own (it works) and another that is based on the textbook explanation (surprisingly, does not work). While I am debugging the faulty implementation as I type this, perhaps someone will be able to point out an error that so far escapes me.

I've been using the following:

https://www.cl.cam.ac.uk/teaching/1415/Algorithms/disjointsets.pdf

More can be found at the following:

https://en.wikipedia.org/wiki/Disjoint-set_data_structure

https://cp-algorithms.com/data_structures/disjoint_set_union.html

Implementation code:

from collections import Counter

def edges2vertices(edges):
    return sorted(list({vertex for edge in edges for vertex in edge}))

class UnionFindTextbook:
    def __init__(self, vertices):
        self._parents = {vertex: vertex for vertex in vertices}

    def find(self, s):
        if s == self._parents[s]:
            return s
        self._parents[s] = self.find(self._parents[s])
        return self._parents[s]

    def union(self, a, b):
        a = self.find(a)
        b = self.find(b)
        if a != b:
            self._parents[b] = a

    @property
    def partitions(self):
        return Counter(self._parents.values()).values()

class UnionFindOwn:
    def __init__(self, vertices):
        self._lookup = {v: idx for idx, v in enumerate(vertices)}
        self._forest = [[vertex] for vertex in vertices]

    def union(self, a, b):
        a = self.find(a)
        b = self.find(b)
        if self._should_merge(a, b) and self._can_merge(a, b):
            self._forest[a].extend(self._forest[b])
            for i in self._forest[b]:
                self._lookup[i] = a
            self._forest[b] = []

    def find(self, a):
        return self._lookup[a]

    def _should_merge(self, a, b):
        return a != b

    def _can_merge(self, a, b):
        return self._forest[a] and self._forest[b]

    @property
    def partitions(self):
        return [len(tree) for tree in self._forest if tree]

def unionfind_min_max(edges, strategy):
    vertices = edges2vertices(edges)
    uf = strategy(vertices)
    for a, b in edges:
        uf.union(a, b)
    return min(uf.partitions), max(uf.partitions)

Test code:

import unittest
from typing import List, Tuple

import ddt

from unionfind import unionfind_min_max, UnionFindOwn as STRATEGY

DATA_UNIONFIND_MIN_MAX = [
    {  # 1.
       # 5
       # 3-8
       # 4-9
       # 1-6-2-7
        "edges": [
            (1, 6),
            (2, 7),
            (3, 8),
            (4, 9),
            (2, 6),
        ],
        "strategy": STRATEGY,
        "expected": (2, 4),
    },
]

@ddt.ddt
class TestUnionFind(unittest.TestCase):
    @ddt.unpack
    @ddt.data(*DATA_UNIONFIND_MIN_MAX)
    def test_unionfind_min_max(self, edges: List[Tuple[int]], strategy, expected: Tuple[int]) -> None:
        actual = unionfind_min_max(edges, strategy)
        self.assertEqual(expected, actual)

Can someone point out the error in UnionFindTextbook, please?

Minimal reproducible example:

from collections import Counter

def edges2vertices(edges):
    return sorted(list({vertex for edge in edges for vertex in edge}))

class UnionFindTextbook:
    def __init__(self, vertices):
        self._parents = {vertex: vertex for vertex in vertices}

    def find(self, s):
        if s == self._parents[s]:
            return s
        self._parents[s] = self.find(self._parents[s])
        return self._parents[s]

    def union(self, a, b):
        a = self.find(a)
        b = self.find(b)
        if a != b:
            self._parents[b] = a

    @property
    def partitions(self):
        return Counter(self._parents.values()).values()

def unionfind_min_max(edges, strategy=UnionFindTextbook):
    vertices = edges2vertices(edges)
    uf = strategy(vertices)
    for a, b in edges:
        uf.union(a, b)
    return min(uf.partitions), max(uf.partitions)

edges = [
    (1, 6),
    (2, 7),
    (3, 8),
    (4, 9),
    (2, 6),
]

assert unionfind_min_max(edges) == (2, 4)
1

There are 1 best solutions below

0
On

I have managed to modify the code appropriately, but I am unsure why all of the sources do not mention that the data structure will not be in the desired state after the union() operations.

@property
def partitions(self):
    return Counter(self.find(i) for i in self._parents).values()

Any ideas?