Non-recursive version of Tarjan's algorithm

3.3k Views Asked by At

I have the following (recursive) implementation of Tarjan's algorithm to find strongly connected components in a graph and it works fine:

/// <summary>
/// Finds the strongly connected components of a directed graph with a recursive,
/// depth-first variant of Tarjan's algorithm.
/// </summary>
public class StronglyConnectedComponents
{
    // Next preorder number to hand out.
    private int preCount;

    // Per vertex: its preorder number, later lowered to the smallest value seen via its edges.
    private int[] low;

    // Per vertex: true once DFS has entered it.
    private bool[] visited;

    private Graph graph;

    // Result list, filled in the order in which component roots finish.
    private List<List<int>> stronglyConnectedComponents = new List<List<int>>();

    // Vertices that were entered but not yet assigned to a component.
    private Stack<int> stack = new Stack<int>();

    /// <summary>
    /// Convenience entry point: runs <see cref="Tarjan"/> on a fresh instance.
    /// </summary>
    /// <param name="graph">Graph to analyse.</param>
    /// <returns>One list of vertex ids per strongly connected component.</returns>
    public static List<List<int>> Search(Graph graph)
    {
        return new StronglyConnectedComponents().Tarjan(graph);
    }

    /// <summary>
    /// Starts a depth-first search from every vertex that has not been visited yet.
    /// </summary>
    /// <param name="graph">Graph to analyse.</param>
    /// <returns>The components collected by all searches.</returns>
    public List<List<int>> Tarjan(Graph graph)
    {
        this.graph = graph;
        low = new int[graph.VertexCount];
        visited = new bool[graph.VertexCount];

        int root = 0;
        while (root < graph.VertexCount)
        {
            if (!visited[root])
            {
                DFS(root);
            }

            root++;
        }

        return stronglyConnectedComponents;
    }

    /// <summary>
    /// Visits <paramref name="v"/> and, recursively, every unvisited vertex reachable from it.
    /// When no edge leads to a smaller <c>low</c> value, <paramref name="v"/> is the root of
    /// a component and that component is popped off the vertex stack.
    /// </summary>
    /// <param name="v">Vertex to enter; expected to be unvisited.</param>
    public void DFS(int v)
    {
        low[v] = preCount++;
        visited[v] = true;
        stack.Push(v);

        int lowest = low[v];
        for (int k = 0, outgoing = graph.OutgoingEdgeCount(v); k < outgoing; k++)
        {
            int next = graph.OutgoingEdge(v, k).Target;
            if (!visited[next])
            {
                DFS(next);
            }

            if (low[next] < lowest)
            {
                lowest = low[next];
            }
        }

        // Something reachable has a smaller number: v is not a root, so it stays on the stack.
        if (lowest < low[v])
        {
            low[v] = lowest;
            return;
        }

        List<int> members = new List<int>();
        while (true)
        {
            int top = stack.Pop();
            members.Add(top);

            // VertexCount is above every preorder number, so finished vertices can no
            // longer lower anyone's minimum.
            low[top] = graph.VertexCount;

            if (top == v)
            {
                break;
            }
        }

        stronglyConnectedComponents.Add(members);
    }
}

But on large graphs, obviously, the recursive version will throw a StackOverflowException. Therefore I want to make the algorithm non-recursive.

I tried to replace the function DFS with the following (non-recursive) one, but the algorithm doesn't work anymore. Can anybody help?

// Attempted iterative replacement for the recursive DFS: sweeps everything reachable from
// `vertex` with an explicit stack and then tries to emit a single component.
// NOTE(review): this is the variant reported as not working. The comments below mark the
// places where it differs from the recursive DFS.
private void DFS2(int vertex)
{
    // NOTE(review): when this method lives in StronglyConnectedComponents, these two locals
    // hide the fields of the same name, so the field-level visited flags and vertex stack
    // are never updated from here.
    bool[] visited = new bool[graph.VertexCount];
    Stack<int> stack = new Stack<int>();
    stack.Push(vertex);
    // low[vertex] is read as-is; unlike the recursive DFS, no preorder number
    // (low[v] = preCount++) is assigned first.
    int min = low[vertex];

    // Reachability sweep: a popped vertex is not kept anywhere, so there is no point at
    // which "all children of v are finished" can be handled per vertex.
    while (stack.Count > 0)
    {
        int v = stack.Pop();
        if (visited[v]) continue;
        visited[v] = true;

        int edgeCount = graph.OutgoingEdgeCount(v);
        for (int i = 0; i < edgeCount; i++)
        {
            int target = graph.OutgoingEdge(v, i).Target;
            stack.Push(target);
            // One shared minimum for the whole sweep, not one per vertex as in the recursion.
            if (low[target] < min) min = low[target];
        }
    }

    if (min < low[vertex])
    {
        low[vertex] = min;
        return;
    }

    List<int> component = new List<int>();

    int w;
    do
    {
        // The loop above only ends once the local stack is empty, so this Pop() runs on an
        // empty Stack<int> (which throws InvalidOperationException).
        w = stack.Pop();
        component.Add(w);
        low[w] = graph.VertexCount;
    } while (w != vertex);
    stronglyConnectedComponents.Add(component);
}

The following code shows the test:

// Builds an 8-vertex directed graph and checks that exactly three components are
// reported, in the order {5,6}, {7,3,2}, {4,1,0}.
public void CanFindStronglyConnectedComponents()
{
    // Directed edges as { source, target } rows, in insertion order.
    int[,] edges =
    {
        { 0, 1 }, { 1, 2 }, { 2, 3 }, { 3, 2 }, { 3, 7 }, { 7, 3 }, { 2, 6 },
        { 7, 6 }, { 5, 6 }, { 6, 5 }, { 1, 5 }, { 4, 5 }, { 4, 0 }, { 1, 4 }
    };

    // Expected members of each component, indexed like the result list.
    int[][] expected =
    {
        new[] { 5, 6 },
        new[] { 7, 3, 2 },
        new[] { 4, 1, 0 }
    };

    Graph graph = new Graph(8);
    for (int e = 0; e < edges.GetLength(0); e++)
    {
        graph.AddEdge(edges[e, 0], edges[e, 1]);
    }

    var scc = StronglyConnectedComponents.Search(graph);

    Assert.AreEqual(3, scc.Count);
    for (int c = 0; c < expected.Length; c++)
    {
        Assert.IsTrue(SetsEqual(Set(expected[c]), scc[c]));
    }
}

// Shorthand for writing an expected vertex set inline; hands back the argument array itself.
private IEnumerable<int> Set(params int[] set)
{
    return set;
}

// True when both sequences have the same length and the number of distinct values they
// share equals the length of set1. Order of elements is ignored.
private bool SetsEqual(IEnumerable<int> set1, IEnumerable<int> set2)
{
    return set1.Count() == set2.Count()
        && set1.Intersect(set2).Count() == set1.Count();
}
2

There are 2 best solutions below

4
On BEST ANSWER

Here is a direct non-recursive translation of the original recursive implementation (assuming the original is correct):

/// <summary>
/// Iterative form of the recursive component search: the call stack is replaced by two
/// explicit stacks holding, per open vertex, the suspended enumerator of the level above
/// and the running minimum of that vertex.
/// </summary>
/// <param name="graph">Graph to analyse.</param>
/// <returns>One list of vertex ids per strongly connected component.</returns>
public static List<List<int>> Search(Graph graph)
{
    List<List<int>> components = new List<List<int>>();
    int[] low = new int[graph.VertexCount];
    bool[] seen = new bool[graph.VertexCount];
    int nextIndex = 0;

    // Vertices entered but not yet assigned to a component.
    Stack<int> open = new Stack<int>();
    // Running minimum of every vertex whose edges are still being walked.
    Stack<int> pendingMins = new Stack<int>();
    // Enumerators of the enclosing levels; each one's Current is the vertex being expanded.
    Stack<IEnumerator<int>> suspended = new Stack<IEnumerator<int>>();
    // Outermost level enumerates all vertices, inner levels enumerate edge targets.
    IEnumerator<int> cursor = Enumerable.Range(0, graph.VertexCount).GetEnumerator();

    for (;;)
    {
        if (!cursor.MoveNext())
        {
            // Current level exhausted: "return" to the level above, if there is one.
            if (suspended.Count == 0)
            {
                break;
            }

            cursor = suspended.Pop();
            int finished = cursor.Current;
            int best = pendingMins.Pop();

            if (best < low[finished])
            {
                // Not a root: keep the vertex on the open stack with its lowered value.
                low[finished] = best;
            }
            else
            {
                List<int> members = new List<int>();
                while (true)
                {
                    int top = open.Pop();
                    members.Add(top);
                    low[top] = graph.VertexCount;
                    if (top == finished)
                    {
                        break;
                    }
                }

                components.Add(members);
            }

            // Let the parent (if any) see the value the finished vertex ended up with.
            if (pendingMins.Count > 0 && low[finished] < pendingMins.Peek())
            {
                pendingMins.Pop();
                pendingMins.Push(low[finished]);
            }

            continue;
        }

        int vertex = cursor.Current;
        if (seen[vertex])
        {
            // Already entered earlier: only its low value can influence the current minimum.
            if (pendingMins.Count > 0 && low[vertex] < pendingMins.Peek())
            {
                pendingMins.Pop();
                pendingMins.Push(low[vertex]);
            }

            continue;
        }

        // First visit: number the vertex and descend into its outgoing edges.
        low[vertex] = nextIndex++;
        seen[vertex] = true;
        open.Push(vertex);
        pendingMins.Push(low[vertex]);
        suspended.Push(cursor);
        cursor = Enumerable.Range(0, graph.OutgoingEdgeCount(vertex))
            .Select(i => graph.OutgoingEdge(vertex, i).Target)
            .GetEnumerator();
    }

    return components;
}

As usual for such direct translations, you need an explicit stack to store the state that must be restored after "returning" from a recursive call. In this case, that state is the current level's vertex enumerator and the min variable.

Note that the existing stack variable cannot be reused for this purpose: although the vertex being processed is pushed onto it, that vertex is not always popped on exit (see the early return in the recursive implementation), and leaving it there is a specific requirement of this algorithm.

0
On

Below is a Python version I had to implement for Codeforces 427C, since they don't support increasing the Python stack size.

The code uses an extra call stack whose entries hold the current node together with the index of the next child to visit.

The actual algorithm follows closely the pseudocode on Wikipedia.

N = # number of vertices
es = # list of edges, [(0,1), (2,4), ...]

class Node:
    """Per-vertex bookkeeping record for the iterative Tarjan loop."""

    def __init__(self, name):
        # Identifier reported in the resulting components.
        self.name = name
        # Both stay None until the vertex is reached for the first time.
        self.index = self.lowlink = None
        # Successor Node objects, filled in from the edge list.
        self.adj = list()
        # True while the vertex sits on the component stack.
        self.on_stack = False

# Iterative Tarjan SCC over the N vertices and the edge list `es` defined above.
# One Node per vertex id 0..N-1; each (v, w) in es becomes a directed edge v -> w.
vs = [Node(i) for i in range(N)]
for v, w in es:
    vs[v].adj.append(vs[w])

i = 0            # next index to assign to a newly reached node
stack = []       # nodes reached so far whose component has not been emitted yet
call_stack = []  # simulated recursion: (node, pi) where pi is the position in node.adj to look at next
comps = []       # result: one list of node names per component
for v in vs:
    if v.index is None:
        call_stack.append((v,0))
        while call_stack:
            # Note: v is rebound here; the outer for loop re-assigns it on its next iteration.
            v, pi = call_stack.pop()
            # If this is first time we see v
            if pi == 0:
                v.index = i
                v.lowlink = i
                i += 1
                stack.append(v)
                v.on_stack = True
            # If we just recursed on something
            # (the frame was re-pushed with pi+1, so the child just finished is adj[pi-1])
            if pi > 0:
                prev = v.adj[pi-1]
                v.lowlink = min(v.lowlink, prev.lowlink)
            # Find the next thing to recurse on
            # (children that already have an index are handled in place, without a new frame)
            while pi < len(v.adj) and v.adj[pi].index is not None:
                w = v.adj[pi]
                if w.on_stack:
                    v.lowlink = min(v.lowlink, w.index)
                pi += 1
            # If we found something with index=None, recurse
            if pi < len(v.adj):
                w = v.adj[pi]
                # Resume frame for v goes below the child's frame, so the child is popped first.
                call_stack.append((v,pi+1))
                call_stack.append((w,0))
                continue
            # If v is the root of a connected component
            if v.lowlink == v.index:
                comp = []
                # Pop everything down to and including v; those nodes form one component.
                while True:
                    w = stack.pop()
                    w.on_stack = False
                    comp.append(w.name)
                    if w is v:
                        break
                comps.append(comp)

Alternatively, following the "simplified" Tarjan's algorithm, we can do the following translation:

def scc2(graph):
    """Return the strongly connected components of ``graph`` without recursion.

    ``graph`` is indexed by vertex and yields that vertex's list of successors;
    iterating over it yields the vertices. Each component is returned as a list
    of vertices in the order they are popped off the vertex stack, and the
    components appear in the order in which their roots finish.
    """
    components = []
    open_vertices = []   # vertices reached but not yet assigned to a component
    low = {}             # vertex -> order number, lowered via its successors
    frames = []          # simulated call stack of (vertex, next child position, order number)

    for root in graph:
        frames.append((root, 0, len(low)))
        while frames:
            node, child_pos, order = frames.pop()

            if child_pos == 0:
                # Fresh frame: ignore it if the vertex was already reached elsewhere.
                if node in low:
                    continue
                low[node] = order
                open_vertices.append(node)
            else:
                # Resumed frame: fold in the child that was just processed.
                low[node] = min(low[node], low[graph[node][child_pos - 1]])

            successors = graph[node]
            if child_pos < len(successors):
                # Resume frame first, child frame on top so it runs next.
                frames.append((node, child_pos + 1, order))
                frames.append((successors[child_pos], 0, len(low)))
            elif order == low[node]:
                # Nothing reachable lowered the value: node is the root of a component.
                members = []
                while True:
                    top = open_vertices.pop()
                    members.append(top)
                    # len(graph) serves as the "finished" marker for emitted vertices.
                    low[top] = len(graph)
                    if top == node:
                        break
                components.append(members)

    return components