We currently have an API that creates schedules in AWS EventBridge Scheduler. When a schedule is created, we include the following payload as the target input:
{
    "ScheduleArn": "<aws.scheduler.schedule-arn>",
    "ScheduledTime": "<aws.scheduler.scheduled-time>",
    "ExecutionId": "<aws.scheduler.execution-id>",
    "AttemptNumber": "<aws.scheduler.attempt-number>",
    "ScheduleId": "03737954-943d-4d4c-882f__7573c82a-47b8-42cf-975d__8fb4a8bc",
    "Event": "validation",
    "Args": {
        "dataset_id": "7573c82a-47b8-42cf-975d-bd3a21efe45b",
        "datasource_id": "03737954-943d-4d4c-882f-5fa76394a925"
    }
}
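For context, the schedule itself is created with boto3 roughly like the sketch below. The queue ARN, role ARN, and helper name are placeholders rather than our exact code:

import json
import boto3

scheduler = boto3.client("scheduler")

def create_validation_schedule(schedule_id: str, run_at: str, payload: dict) -> None:
    # Creates a one-off schedule that delivers `payload` to our SQS job queue.
    # The queue ARN and role ARN below are placeholders.
    scheduler.create_schedule(
        Name=schedule_id,
        ScheduleExpression=f"at({run_at})",  # e.g. "at(2024-06-01T12:00:00)"
        FlexibleTimeWindow={"Mode": "OFF"},
        Target={
            "Arn": "arn:aws:sqs:us-east-1:123456789012:job-queue",
            "RoleArn": "arn:aws:iam::123456789012:role/scheduler-to-sqs",
            "Input": json.dumps(payload),
        },
    )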
When the scheduled time arrives, AWS EventBridge Scheduler puts the above payload onto an SQS queue. We have implemented our own worker/consumer, but we want to pivot to Celery because it is more robust and performant.
Our question is: given our current implementation above, where the scheduler (and not a Celery app) is the one putting the task onto the queue, can we set up a Celery worker/consumer to process all tasks off the queue? We'd want a "handler" method that, based on the "Event" field, calls a different method in our app (roughly like the sketch below).
If it is possible to run a Celery worker/consumer without the Celery app producing the messages, and if it's not too much to ask, what would the bare-minimum worker/consumer code look like?
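To make the "handler" idea concrete, this is the kind of dispatch we have in mind; run_dataset_validation and the other names are just placeholders:

import json

def run_dataset_validation(dataset_id: str, datasource_id: str) -> None:
    # Placeholder for our existing validation logic.
    print(f"validating dataset {dataset_id} from datasource {datasource_id}")

# Map the "Event" field in the payload to the app method we want to call.
EVENT_HANDLERS = {
    "validation": lambda args: run_dataset_validation(
        args["dataset_id"], args["datasource_id"]
    ),
}

def handle_message(body: str) -> None:
    # Dispatch one scheduler payload to the handler registered for its Event.
    payload = json.loads(body)
    handler = EVENT_HANDLERS.get(payload["Event"])
    if handler is None:
        raise ValueError(f"Unknown event: {payload['Event']}")
    handler(payload["Args"])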
Appreciate the time and help!
For reference, this is the consumer we have set up so far using Kombu. We run multiple workers as ECS tasks, so the consumer should stop pulling tasks off the queue when no worker processes are available. Tasks can run for up to 45 minutes.
from __future__ import annotations

import json
import multiprocessing
import time

from kombu import Exchange, Queue
from kombu.log import get_logger
from kombu.mixins import ConsumerMixin

logger = get_logger(__name__)

task_exchange = Exchange('tasks', type='direct')
task_queues = [Queue('job-queue', task_exchange)]


def run_dataset_validation(dataset_id: str):
    # Stand-in for the real (long-running) validation work.
    time.sleep(5)
    print(dataset_id)


class Worker(ConsumerMixin):
    def __init__(self, connection):
        self.connection = connection
        # Pool of worker processes that execute the long-running tasks.
        self.pool = multiprocessing.Pool(processes=4)
        self.should_stop = False
        self.all_messages_processed = False

    def get_consumers(self, Consumer, channel):
        return [Consumer(
            queues=task_queues,
            accept=['pickle', 'json'],
            callbacks=[self.process_task],
        )]

    def _num_idle_workers(self) -> int:
        # Relies on private Pool attributes; there is no public API for this.
        idle_workers = self.pool._processes - len(self.pool._cache)
        print(f"Total workers: {self.pool._processes}, idle workers: {idle_workers}")
        return idle_workers

    def process_task(self, body, message):
        try:
            # Kombu may deliver a decoded dict or a raw JSON string depending on
            # the message's content type, so only parse when needed.
            if isinstance(body, (str, bytes)):
                body = json.loads(body)
            if body["Event"] == "validation":
                self.pool.apply_async(
                    func=run_dataset_validation,
                    args=(body["Args"]["dataset_id"],),
                    callback=self._on_task_complete,
                )
            # Block consumption until at least one worker process is idle.
            while self._num_idle_workers() <= 0:
                if self.should_stop:
                    return
                time.sleep(2)
        except Exception as exc:
            logger.error('task raised exception: %r', exc)
        message.ack()

    def _on_task_complete(self, result):
        if self.should_stop and len(self.pool._cache) == 0:
            self.all_messages_processed = True

    def on_exit(self, exc=None):
        # Close the multiprocessing pool when the worker exits.
        self.pool.close()
        self.pool.join()

    def stop(self):
        # Stop consuming messages, then wait for in-flight tasks to finish.
        self.should_stop = True
        while not self.all_messages_processed:
            time.sleep(2)


if __name__ == '__main__':
    from kombu import Connection
    from kombu.utils.debug import setup_logging

    # Set up the root logger.
    setup_logging(loglevel='INFO', loggers=[''])

    with Connection('sqs://') as conn:
        try:
            worker = Worker(conn)
            worker.run()
        except KeyboardInterrupt:
            print('Bye bye')
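Since tasks can run for up to 45 minutes, we assume any solution would also need explicit SQS transport options on the connection so an in-flight message is not redelivered to another worker. A connection configured that way would look something like this (region and timeout values are illustrative, not our production settings):

from kombu import Connection

connection = Connection(
    'sqs://',
    transport_options={
        'region': 'us-east-1',        # illustrative region
        'visibility_timeout': 3600,   # longer than our worst-case 45-minute task
        'wait_time_seconds': 20,      # SQS long polling
        'polling_interval': 1,
    },
)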