Generator that is based on another generator


My task is actually quite simple, but I cannot figure out how to achieve it. I intend to use this in my ML algorithm, but let's simplify the example. Suppose there is a generator like the following:

nums = ((i+1) for i in range(4))

The above yields 1, 2, 3, and 4.

Suppose that the above generator returns individual "samples". I want to write a generator function that batches them up. Suppose the batch size is 2. So the new function would look something like this:

def batch_generator(batch_size):
    # do something on nums
    # yield batches of size batch_size
    ...

The output of this batch generator would then be 1 and 2, followed by 3 and 4. Whether the batches are tuples or lists does not matter. What matters is how to return these batches. I found the yield from syntax that was introduced in Python 3.3, but it does not seem useful in my case.

And obviously, if we had 5 nums instead of 4 and batch_size were 2, we would drop the last value yielded by the first generator, since it cannot fill a batch.

3

There are 3 best solutions below

4
BcK On BEST ANSWER

My own solution for this would be:

nums = (i+1 for i in range(4))

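def giveBatch(gen, numOfItems):
    # Pull numOfItems values from gen. Once gen runs out, return None;
    # values already pulled for an incomplete batch are dropped.
    try:
        return [next(gen) for _ in range(numOfItems)]
    except StopIteration:
        return None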

giveBatch(nums, 2)
# [1, 2]
giveBatch(nums, 2)
# [3, 4]
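To illustrate what happens once the source runs dry, here is a minimal sketch with five samples and a batch size of 2. The function body repeats the logic above with an explicit return None. The iter(callable, sentinel) loop at the end is just one possible way to drive it and is not part of the original answer.

```python
def giveBatch(gen, numOfItems):
    # Same logic as above: a full batch, or None when the source is exhausted.
    try:
        return [next(gen) for _ in range(numOfItems)]
    except StopIteration:
        return None

nums = (i + 1 for i in range(5))
first = giveBatch(nums, 2)   # [1, 2]
second = giveBatch(nums, 2)  # [3, 4]
third = giveBatch(nums, 2)   # None: 5 was consumed, then the source ran out

# Two-argument iter() keeps calling the lambda until it returns the sentinel (None).
nums = (i + 1 for i in range(5))
batches = list(iter(lambda: giveBatch(nums, 2), None))
print(batches)  # [[1, 2], [3, 4]]
```

Note that the leftover sample 5 is silently discarded, which matches the requirement in the question.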

Another solution is to use grouper, as @Bharel mentioned. I compared the time it takes to run both solutions. In this rough comparison there is not much of a difference, so I guess it can be neglected.

from timeit import timeit

def wrapper(func, *args, **kwargs):
    def wrapped():
        return func(*args, **kwargs)
    return wrapped

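nums = (i+1 for i in range(1000000))

wrappedGiveBatch = wrapper(giveBatch, nums, 2)
timeit(wrappedGiveBatch, number=1000000)
# ~ 0.998439

# Note: nums is already exhausted at this point, and calling grouper() only
# builds a lazy iterator without producing any batches, so this measures
# call overhead rather than batching work.
wrappedGrouper = wrapper(grouper, nums, 2)
timeit(wrappedGrouper, number=1000000)
# ~ 0.734342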
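For a comparison in which both approaches actually consume the same data, a sketch along the following lines may be more telling. It assumes a fresh generator per run and drains it completely. The helper names consume_with_give_batch and consume_with_grouper are made up for this illustration, and no timings are quoted because they depend on the machine.

```python
from itertools import zip_longest
from timeit import timeit

def giveBatch(gen, numOfItems):
    try:
        return [next(gen) for _ in range(numOfItems)]
    except StopIteration:
        return None

def grouper(iterable, n, fillvalue=None):
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

def consume_with_give_batch(total, n):
    # Hypothetical helper: drain a fresh generator via repeated giveBatch calls.
    nums = (i + 1 for i in range(total))
    count = 0
    while giveBatch(nums, n) is not None:
        count += 1
    return count

def consume_with_grouper(total, n):
    # Hypothetical helper: drain a fresh generator by iterating over grouper.
    nums = (i + 1 for i in range(total))
    count = 0
    for _ in grouper(nums, n):
        count += 1
    return count

t_give = timeit(lambda: consume_with_give_batch(100_000, 2), number=3)
t_group = timeit(lambda: consume_with_grouper(100_000, 2), number=3)
print(f"giveBatch: {t_give:.4f}s  grouper: {t_group:.4f}s")
```

Both helpers return the number of batches produced, so it is easy to confirm that they do the same amount of work before comparing the times.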
0
Bharel On

The itertools documentation has a recipe which does just that:

from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

Instead of calling a function for every batch, you get a single iterator that yields the batches. This is more efficient, and it handles corner cases such as the data running out in the middle of a batch without losing anything: the final short batch is padded with fillvalue.
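Since the question asks for the incomplete final batch to be dropped rather than padded, the same trick works with the built-in zip, which stops at the shortest input. The following is a small sketch of that variation. The name grouper_drop_tail is made up here and is not part of the itertools recipes.

```python
from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    # Recipe from above: pads the last batch with fillvalue.
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

def grouper_drop_tail(iterable, n):
    # Hypothetical variant: zip() stops as soon as the shared iterator
    # runs out, so an incomplete final batch is discarded.
    args = [iter(iterable)] * n
    return zip(*args)

nums = (i + 1 for i in range(5))
padded = list(grouper(nums, 2))
print(padded)   # [(1, 2), (3, 4), (5, None)]

nums = (i + 1 for i in range(5))
dropped = list(grouper_drop_tail(nums, 2))
print(dropped)  # [(1, 2), (3, 4)]
```

Both versions rely on the list holding n references to the same iterator, so each output tuple pulls n consecutive items from the source.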

2
Naz On

This was exactly what I needed:

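def giveBatch(numOfItems):
    nums = (i+1 for i in range(7))

    while True:
        try:
            batch = [next(nums) for _ in range(numOfItems)]
        except StopIteration:
            # Source exhausted: drop the incomplete batch and stop. Since Python 3.7
            # (PEP 479), an uncaught StopIteration here would become a RuntimeError.
            return
        yield batch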
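A variation that takes the source generator as an argument, rather than creating it inside the function, could look like the sketch below. It uses itertools.islice to pull each batch and returns when a batch comes up short. The name batch_generator follows the question, and the two-parameter signature is an assumption for this example.

```python
from itertools import islice

def batch_generator(iterable, batch_size):
    # Assumed signature: any iterable plus the desired batch size.
    it = iter(iterable)
    while True:
        batch = list(islice(it, batch_size))
        if len(batch) < batch_size:
            # Source exhausted; an incomplete final batch is dropped.
            return
        yield batch

nums = (i + 1 for i in range(7))
result = list(batch_generator(nums, 2))
print(result)  # [[1, 2], [3, 4], [5, 6]]
```

Because islice simply returns fewer items when the source runs out, no StopIteration handling is needed inside the generator.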