Python unittests: how to get an AsyncMock to block until a certain point

I have some async Python code I want to test, which includes something like:

async def task(self):
    while True:
        await self.thing.change()
        ...
        self.run_callbacks()

In my test file, I'm starting the task in the setUp() function like so:

def setUp(self):
    ...
    self.loop = asyncio.get_event_loop()
    self.task = self.loop.create_task(self.object_under_test.task())

and then I want to verify that the callbacks ran:

def test_callback(self):
    callback = mock.MagicMock()
    self.object_under_test.add_callback('cb1', callback)
    # ???
    callback.assert_called_once()

Where I'm stuck is how to mock thing.change() in the task. It needs to block until I tell it to (at the ??? line in the test), at which point it should return. If I mock it after starting the task, it makes no difference, because the task is already waiting for the un-mocked function to return. If I mock it before starting the task, it always returns immediately and never blocks.
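A minimal sketch of the blocking behaviour described above, independent of the question's classes: an AsyncMock can be installed up front and still block, as long as its side_effect is a coroutine function that waits on an asyncio.Event. AsyncMock awaits a coroutine-function side_effect, so the caller stays suspended until the event is set. The names demo, release and blocking_change are illustrative only.

```python
import asyncio
from unittest.mock import AsyncMock


async def demo():
    release = asyncio.Event()

    async def blocking_change():
        # Suspends whoever awaits the mock until release.set() is called
        await release.wait()
        return "changed"

    # The mock exists before anything awaits it, yet it still blocks
    change = AsyncMock(side_effect=blocking_change)

    # Start a caller that awaits the mock, then give it one loop iteration
    waiter = asyncio.create_task(change())
    await asyncio.sleep(0)
    blocked_before = not waiter.done()

    # Release the mock and collect its return value
    release.set()
    result = await waiter
    return blocked_before, result, change.await_count


blocked_before, result, await_count = asyncio.run(demo())
print(blocked_before, result, await_count)  # True changed 1
```

Note that in CPython's implementation await_count is incremented as soon as the mock's coroutine starts running, not when it returns, so a test that only checks await_count cannot tell whether the mock was actually released.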

Any suggestions how I might achieve this?

Best answer (askvictor):

I've worked out how to synchronise the two using an asyncio.Event and a side_effect on the mock:

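# my_task.py
import asyncio


async def async_task():
    while True:
        result = await trigger()
        await asyncio.sleep(0)


async def trigger():
    # Some asynchronous function you want to mock
    await asyncio.sleep(1)
    print('triggered')
    return "Trigger Completed"


# test_my_task.py
import asyncio
import unittest
from unittest.mock import patch, AsyncMock
import my_task


class TestAsyncTask(unittest.TestCase):

    def setUp(self):
        # A private loop per test; asyncio.get_event_loop() with no loop
        # running is deprecated in recent Python versions
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(asyncio.set_event_loop, None)
        self.addCleanup(self.loop.close)

        self.trigger_event = asyncio.Event()

        async def mocked_trigger():
            # Block until the test calls self.trigger_event.set()
            await self.trigger_event.wait()
            return "Mocked Trigger Completed"

        # patch.object restores the real trigger() after each test
        patcher = patch.object(my_task, "trigger",
                               AsyncMock(side_effect=mocked_trigger))
        self.mock_trigger = patcher.start()
        self.addCleanup(patcher.stop)

    def test_async_task_with_mocked_trigger(self):
        task = self.loop.create_task(my_task.async_task())

        # Let the task run until it blocks inside the mocked trigger()
        self.loop.run_until_complete(asyncio.sleep(0))
        self.assertEqual(self.mock_trigger.await_count, 1)

        # Release it; the loop in async_task then calls trigger() again
        self.trigger_event.set()
        self.loop.run_until_complete(asyncio.sleep(0))
        self.assertGreaterEqual(self.mock_trigger.await_count, 2)

        # Stop the infinite loop so no pending task is left behind
        task.cancel()
        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(task)


if __name__ == "__main__":
    unittest.main()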
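The same pattern can be mapped back onto the structure in the question. This is a sketch assuming Python 3.8+ (for unittest.IsolatedAsyncioTestCase); ObjectUnderTest, its thing attribute and its callback registry are hypothetical stand-ins reconstructed from the fragments in the question, not the real class. Clearing the event inside the mock means each set() releases exactly one call to change(), and the callback signals a second Event so the test waits for it rather than guessing how many loop iterations are needed.

```python
import asyncio
import unittest
from unittest import mock


class ObjectUnderTest:
    """Hypothetical stand-in for the question's class (assumed shape)."""

    def __init__(self, thing):
        self.thing = thing
        self.callbacks = {}

    def add_callback(self, name, callback):
        self.callbacks[name] = callback

    def run_callbacks(self):
        for callback in self.callbacks.values():
            callback()

    async def task(self):
        while True:
            await self.thing.change()
            self.run_callbacks()


class TestCallbacks(unittest.IsolatedAsyncioTestCase):

    async def asyncSetUp(self):
        self.release = asyncio.Event()

        async def blocking_change():
            # Each release.set() lets exactly one change() call return
            await self.release.wait()
            self.release.clear()

        thing = mock.Mock()
        thing.change = mock.AsyncMock(side_effect=blocking_change)
        self.obj = ObjectUnderTest(thing)
        self.task = asyncio.create_task(self.obj.task())

    async def asyncTearDown(self):
        # Cancel the infinite loop and wait for the cancellation to finish
        self.task.cancel()
        await asyncio.gather(self.task, return_exceptions=True)

    async def test_callback(self):
        called = asyncio.Event()
        callback = mock.MagicMock(side_effect=called.set)
        self.obj.add_callback('cb1', callback)

        # Let the task block in change(); no callback has run yet
        await asyncio.sleep(0)
        callback.assert_not_called()

        # The "???" step: allow change() to return once
        self.release.set()
        await asyncio.wait_for(called.wait(), timeout=1)
        callback.assert_called_once()
```

Run it with python -m unittest. Because the mock re-blocks after each release, a test can step the loop one iteration at a time by calling self.release.set() repeatedly.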