I'm trying to flatten a nested generator of generators but I'm getting an unexpected result:
>>> g = ((3*i + j for j in range(3)) for i in range(3))
>>> list(itertools.chain(*g))
[6, 7, 8, 6, 7, 8, 6, 7, 8]
I expected the result to look like this:
[0, 1, 2, 3, 4, 5, 6, 7, 8]
I think I'm getting the unexpected result because the inner generators are not evaluated until the outer generator has already been fully iterated over, which leaves i set to 2. I can hack together a solution by forcing evaluation of the inner generators, using a list comprehension instead of a generator expression:
>>> g = ([3*i + j for j in range(3)] for i in range(3))
>>> list(itertools.chain(*g))
[0, 1, 2, 3, 4, 5, 6, 7, 8]
Ideally, I would like a solution that's completely lazy and doesn't force evaluation of the inner nested elements until they're used.
Is there a way to flatten nested generator expressions of arbitrary depth (maybe using something other than itertools.chain)?
Edit:
No, my question is not a duplicate of Variable Scope In Generators In Classes. I honestly can't tell how these two questions are related at all. Maybe the moderator could explain why he thinks this is a duplicate.
Also, both answers to my question are correct in that they can be used to write a function that flattens nested generators correctly.
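import collections.abc
import itertools

def flattened1(iterable):
    # Peek at the first element through one copy; keep the other copy intact.
    iter1, iter2 = itertools.tee(iterable)
    if isinstance(next(iter1), collections.abc.Iterable):
        return flattened1(x for y in iter2 for x in y)
    else:
        return iter2

def flattened2(iterable):
    # Same approach, using chain.from_iterable to remove one nesting level.
    iter1, iter2 = itertools.tee(iterable)
    if isinstance(next(iter1), collections.abc.Iterable):
        return flattened2(itertools.chain.from_iterable(iter2))
    else:
        return iter2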
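As a quick check of the approach, here is a sketch that runs the flattened2 variant on a generator nested three levels deep. It assumes a Python version where Iterable lives in collections.abc (the collections.Iterable alias was removed in Python 3.10); the names flattened and nested are made up for this example:

```python
import itertools
from collections.abc import Iterable

def flattened(iterable):
    # Same idea as flattened2: peek at the first element via a tee'd copy,
    # then strip one nesting level at a time until the elements are scalars.
    iter1, iter2 = itertools.tee(iterable)
    if isinstance(next(iter1), Iterable):
        return flattened(itertools.chain.from_iterable(iter2))
    return iter2

# Hypothetical test input: three levels of nested generator expressions.
nested = (((4 * i + 2 * j + k for k in range(2)) for j in range(2)) for i in range(2))
print(list(flattened(nested)))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Two caveats of this sketch: an empty input raises StopIteration from the peek, and a string element recurses without end because each character of a str is itself an iterable str.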
As far as I can tell with timeit, they both perform essentially identically.
>>> timeit(test1, setup1, number=1000000)
18.173431718023494
>>> timeit(test2, setup2, number=1000000)
17.854709611972794
I'm not sure which one is better from a style standpoint either, since x for y in iter2 for x in y is a bit of a brain twister, but arguably more elegant than itertools.chain.from_iterable(iter2). Input is appreciated.
Regrettably, I was only able to mark one of the two equally good answers correct.
Instead of using chain(*g), you can use chain.from_iterable: