Flattening nested generator expressions


I'm trying to flatten a nested generator of generators but I'm getting an unexpected result:

>>> import itertools
>>> g = ((3*i + j for j in range(3)) for i in range(3))
>>> list(itertools.chain(*g))
[6, 7, 8, 6, 7, 8, 6, 7, 8]

I expected the result to look like this:

[0, 1, 2, 3, 4, 5, 6, 7, 8]

I think I'm getting the unexpected result because the inner generators are not evaluated until the outer generator has been fully iterated over, which leaves i set to 2. I can hack together a solution by forcing evaluation of the inner generators, using a list comprehension instead of a generator expression:

>>> g = ([3*i + j for j in range(3)] for i in range(3))
>>> list(itertools.chain(*g))
[0, 1, 2, 3, 4, 5, 6, 7, 8]
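A minimal way to see that late binding in isolation: an inner generator reads i only when it is first advanced, so advancing the outer generator before consuming an inner one changes what the inner one produces. The names g1, g2 and first are just for illustration.

```python
# Same nested generator expression as in the question.
g = ((3*i + j for j in range(3)) for i in range(3))

g1 = next(g)      # first inner generator, created while i == 0
g2 = next(g)      # advancing the outer generator rebinds i to 1
first = list(g1)  # g1 only now evaluates 3*i + j, with i == 1
print(first)      # [3, 4, 5], not [0, 1, 2]
```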

Ideally, I would like a solution that's completely lazy and doesn't force evaluation of the inner nested elements until they're used.

Is there a way to flatten nested generator expressions of arbitrary depth (maybe using something other than itertools.chain)?

Edit:

No, my question is not a duplicate of Variable Scope In Generators In Classes. I honestly can't tell how these two questions are related at all. Maybe the moderator could explain why he thinks this is a duplicate.

Also, both answers to my question are correct in that they can be used to write a function that flattens nested generators correctly.

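import collections.abc
import itertools

def flattened1(iterable):
    # Peek at the first element via a tee'd copy, so iter2 still yields it.
    iter1, iter2 = itertools.tee(iterable)
    # next(..., None) avoids StopIteration on empty input; note that str is
    # also Iterable, so string leaves would recurse until RecursionError.
    if isinstance(next(iter1, None), collections.abc.Iterable):
        return flattened1(x for y in iter2 for x in y)
    else:
        return iter2

def flattened2(iterable):
    iter1, iter2 = itertools.tee(iterable)
    if isinstance(next(iter1, None), collections.abc.Iterable):
        return flattened2(itertools.chain.from_iterable(iter2))
    else:
        return iter2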

As far as I can tell with timeit, they both perform about the same:

>>> timeit(test1, setup1, number=1000000)
18.173431718023494
>>> timeit(test2, setup2, number=1000000)
17.854709611972794

I'm not sure which one is better from a style standpoint either, since x for y in iter2 for x in y is a bit of a brain twister, but arguably more elegant than itertools.chain.from_iterable(iter2). Input is appreciated.
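One way to untwist it: the for clauses in a generator expression nest left to right, exactly like nested for statements, so x for y in iter2 for x in y behaves like the generator function below (one_level is a hypothetical name used only for illustration). The itertools documentation gives essentially this same nested loop as the rough pure-Python equivalent of chain.from_iterable.

```python
import itertools


def one_level(iterable):
    # Equivalent to (x for y in iterable for x in y): outer loop first, inner loop second.
    for y in iterable:
        for x in y:
            yield x


g = ((3*i + j for j in range(3)) for i in range(3))
h = ((3*i + j for j in range(3)) for i in range(3))
a = list(one_level(g))
b = list(itertools.chain.from_iterable(h))
print(a)  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
print(b)  # same result
```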

Regrettably, I was only able to mark one of the two equally good answers correct.
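One caveat about flattened1 and flattened2: both peek only at the first element, so they assume every branch has the same nesting depth. A minimal sketch of the difference, not from this thread: flatten_each (a hypothetical name) decides per element instead, and treating str and bytes as leaves is an assumption.

```python
import collections.abc
import itertools


def flattened2(iterable):
    # The question's flattened2, using collections.abc and a default for next().
    iter1, iter2 = itertools.tee(iterable)
    if isinstance(next(iter1, None), collections.abc.Iterable):
        return flattened2(itertools.chain.from_iterable(iter2))
    return iter2


def flatten_each(iterable):
    # Hypothetical alternative: decide per element whether to descend.
    for item in iterable:
        # Assumption: str/bytes are leaves (otherwise "a" would recurse until RecursionError).
        if isinstance(item, collections.abc.Iterable) and not isinstance(item, (str, bytes)):
            yield from flatten_each(item)  # finishes this branch before the next item
        else:
            yield item


mixed = [1, [2, [3, 4]], 5]
print(list(flattened2(mixed)))    # [1, [2, [3, 4]], 5]: first element is not iterable
print(list(flatten_each(mixed)))  # [1, 2, 3, 4, 5]

g = ((3*i + j for j in range(3)) for i in range(3))
print(list(flatten_each(g)))      # [0, 1, 2, 3, 4, 5, 6, 7, 8]
```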

There are 3 answers below.

Accepted answer:

Instead of using chain(*g), you can use chain.from_iterable:

>>> g = ((3*i + j for j in range(3)) for i in range(3))
>>> list(itertools.chain(*g))
[6, 7, 8, 6, 7, 8, 6, 7, 8]
>>> g = ((3*i + j for j in range(3)) for i in range(3))
>>> list(itertools.chain.from_iterable(g))
[0, 1, 2, 3, 4, 5, 6, 7, 8]
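The difference is that chain(*g) has to unpack, and therefore exhaust, g before chain is even called, while chain.from_iterable takes g itself and only requests the next inner generator after the current one is used up, so each inner generator runs while i still holds its own value. That also keeps it lazy. A minimal sketch with an infinite outer generator (using itertools.count is an assumption for illustration), where chain(*g) would never return:

```python
import itertools

# Infinite outer generator: i = 0, 1, 2, ...
g = ((3*i + j for j in range(3)) for i in itertools.count())
flat = itertools.chain.from_iterable(g)      # nothing is consumed yet
first_five = list(itertools.islice(flat, 5))  # pulls only two inner generators
print(first_five)  # [0, 1, 2, 3, 4]
```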

Second answer:

How about this:

[x for y in g for x in y]

Which yields:

[0, 1, 2, 3, 4, 5, 6, 7, 8]
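Note that the square brackets make this a list comprehension, so it builds the whole list at once; it gives the right answer because it, too, finishes each inner generator before pulling the next one from g. Swapping the brackets for parentheses keeps the same ordering but evaluates lazily, which is closer to what the question asks for. A minimal sketch (flat, head and rest are illustrative names):

```python
g = ((3*i + j for j in range(3)) for i in range(3))

flat = (x for y in g for x in y)  # nothing is computed yet
head = next(flat)                 # pulls the first inner generator, yields 0
rest = list(flat)                 # continues from where next() stopped
print(head, rest)                 # 0 [1, 2, 3, 4, 5, 6, 7, 8]
```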

Third answer:

Guess you already have your answer, but here's another perspective.

The problem is that each inner generator's value-generating expression closes over the outer variable i, so when an inner generator finally starts producing values, it uses the "current" value of i rather than the value i had when that inner generator was created. If the outer generator has been fully consumed, that value is i=2, and that is exactly the case with chain(*g): the *g unpacking exhausts the outer generator before chain is actually called.

The following devious trick will work around the problem:

g = ((3*i1 + j for i1 in [i] for j in range(3)) for i in range(3))

Note that these inner generators aren't closed over i, because the iterable in the first for clause of a generator expression is evaluated immediately, when the generator is created. So the singleton list [i] is built right away and its value is "frozen", no matter how i changes afterwards.
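Another way to get the same freezing, not from this thread and offered only as a sketch, is to build each inner generator inside a helper function so that i becomes a parameter with its own binding per call; make_inner is a hypothetical name.

```python
import itertools


def make_inner(i):
    # i is a parameter here, so each inner generator closes over its own i.
    return (3*i + j for j in range(3))


g = (make_inner(i) for i in range(3))
result = list(itertools.chain(*g))  # correct even with chain(*g)
print(result)  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
```

The trade-off is an extra named function in exchange for not needing the for i1 in [i] idiom.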

This approach has an advantage over the from_iterable answer: it's a little more general if you want to use the inner generators outside a chain.from_iterable call. It always produces the "correct" inner generators, whether the outer generator is partially or fully consumed before the inner generators are used. For example, in the following code:

g = ((3*i1 + j for i1 in [i] for j in range(3)) for i in range(3))
g1 = next(g)
g2 = next(g)
g3 = next(g)

you can insert the lines:

list(g1)
list(g2)
list(g3)

in any order at any point after the respective inner generator has been defined, and you'll get the correct results.
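A runnable version of that check, pulling all three inner generators from the outer one first and then consuming them out of order:

```python
g = ((3*i1 + j for i1 in [i] for j in range(3)) for i in range(3))
g1 = next(g)
g2 = next(g)
g3 = next(g)

# Consume out of order; each inner generator kept its own snapshot of i.
results = [list(g3), list(g1), list(g2)]
print(results)  # [[6, 7, 8], [0, 1, 2], [3, 4, 5]]
```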