Running Celery Worker (consumer) without Celery app

We currently have an API that creates schedules in AWS EventBridge Scheduler. When the schedule is created, we include the following payload:

{
    "ScheduleArn": "<aws.scheduler.schedule-arn>",
    "ScheduledTime": "<aws.scheduler.scheduled-time>",
    "ExecutionId": "<aws.scheduler.execution-id>",
    "AttemptNumber": "<aws.scheduler.attempt-number>",
    "ScheduleId": "03737954-943d-4d4c-882f__7573c82a-47b8-42cf-975d__8fb4a8bc",
    "Event": "validation",
    "Args": {
        "dataset_id": "7573c82a-47b8-42cf-975d-bd3a21efe45b",
        "datasource_id": "03737954-943d-4d4c-882f-5fa76394a925"
    }
}

When the task is scheduled to run, AWS EventBridge Scheduler puts the above payload onto an SQS queue. We have implemented our own worker/consumer, but we want to pivot to Celery as it is far more robust and performant.
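For context, the API creates the schedule with a boto3 call along these lines (everything here is illustrative: the schedule name, fire time, and ARNs are placeholders, and the payload is trimmed to the fields that matter):

import json
import boto3

scheduler = boto3.client('scheduler')

payload = {  # trimmed; the full payload is shown above
    'Event': 'validation',
    'Args': {
        'dataset_id': '7573c82a-47b8-42cf-975d-bd3a21efe45b',
        'datasource_id': '03737954-943d-4d4c-882f-5fa76394a925',
    },
}

scheduler.create_schedule(
    Name='validation-03737954',                     # placeholder name
    ScheduleExpression='at(2024-01-01T00:00:00)',   # placeholder fire time
    FlexibleTimeWindow={'Mode': 'OFF'},
    Target={
        'Arn': 'arn:aws:sqs:us-east-1:123456789012:job-queue',         # placeholder
        'RoleArn': 'arn:aws:iam::123456789012:role/scheduler-to-sqs',  # placeholder
        'Input': json.dumps(payload),
    },
)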

Our question: given our current implementation, where the scheduler (not the Celery app) is what puts the task onto the queue, can we set up a Celery worker/consumer to process all tasks off the queue? We'd want a "handler" method that, based on the "Event", calls a different method in our app.

If it is possible to run the Celery worker/consumer without the Celery app, and if it's not too much to ask, what would the bare-minimum worker/consumer code look like?
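To make the intent concrete, here is a rough sketch of the dispatch we have in mind; the handler table and handle function are illustrative, not code we already have:

import json

def run_dataset_validation(dataset_id: str, datasource_id: str) -> None:
    ...  # the real implementation lives in our app

# Map the payload's "Event" field to the method it should invoke.
EVENT_HANDLERS = {
    "validation": run_dataset_validation,
}

def handle(raw):
    payload = json.loads(raw) if isinstance(raw, str) else raw
    handler = EVENT_HANDLERS[payload["Event"]]  # KeyError -> unknown event
    handler(**payload["Args"])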

Appreciate the time and help!

For reference, this is the consumer we have set up so far using Kombu. We run multiple workers as ECS Tasks, so each consumer should stop pulling tasks off the queue when none of its worker processes are available. Tasks can run for up to 45 minutes.

from __future__ import annotations

import json
from kombu import Exchange, Queue
from kombu.log import get_logger
from kombu.mixins import ConsumerMixin
import time
import multiprocessing

logger = get_logger(__name__)

task_exchange = Exchange('tasks', type='direct')
task_queues = [Queue('job-queue', task_exchange)]


def run_dataset_validation(dataset_id: str):
    time.sleep(5)
    print(dataset_id)


class Worker(ConsumerMixin):
    def __init__(self, connection):
        self.connection = connection
        # Create a pool of worker processes
        self.pool = multiprocessing.Pool(processes=4)
        self.should_stop = False
        self.all_messages_processed = False

    def get_consumers(self, Consumer, channel):
        return [Consumer(
            queues=task_queues,
            accept=['pickle', 'json'],
            callbacks=[self.process_task]
        )]

    def _num_idle_workers(self) -> int:
        # NOTE: _processes and _cache are private Pool internals; there is no
        # public API for pool utilization, so this may break between versions.
        idle_workers = self.pool._processes - len(self.pool._cache)
        print(f"Total workers: {self.pool._processes}, idle workers: {idle_workers}")
        return idle_workers

    def process_task(self, body, message):
        try:
            # Depending on the message content type, kombu may hand us an
            # already-decoded dict; only parse when the body is still a string.
            if isinstance(body, str):
                body = json.loads(body)
            if body["Event"] == "validation":
                self.pool.apply_async(
                    func=run_dataset_validation,
                    args=(body["Args"]["dataset_id"],),
                    callback=self._on_task_complete
                )

            # Block the consumer until a worker process frees up, so we stop
            # pulling new tasks off the queue while the pool is saturated.
            while self._num_idle_workers() <= 0:
                if self.should_stop:
                    # Leave the message unacked so SQS redelivers it.
                    return
                time.sleep(2)

        except Exception as exc:
            logger.error('task raised exception: %r', exc)

        # Ack even on failure so a bad payload is not redelivered forever.
        message.ack()

    def _on_task_complete(self, result):
        if self.should_stop and len(self.pool._cache) == 0:
            self.all_messages_processed = True

    def on_exit(self, exc=None):
        # Close the multiprocessing pool when the worker exits. Note that
        # ConsumerMixin does not call on_exit for us; we invoke it explicitly
        # below (the mixin's own end-of-consume hook is on_consume_end).
        self.pool.close()
        self.pool.join()

    def stop(self):
        # Set should_stop attribute to True to stop consuming messages
        self.should_stop = True
        while not self.all_messages_processed:
            time.sleep(2)


if __name__ == '__main__':
    from kombu import Connection
    from kombu.utils.debug import setup_logging

    # setup root logger
    setup_logging(loglevel='INFO', loggers=[''])

    with Connection('sqs://') as conn:
        worker = Worker(conn)
        try:
            worker.run()
        except KeyboardInterrupt:
            print('Bye bye')
        finally:
            worker.on_exit()
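One detail worth calling out given the 45-minute tasks: the SQS visibility timeout has to exceed the longest task runtime, or in-flight messages will be redelivered to another worker. With kombu this can be passed through transport_options; the 3600 below is an assumed value, not something we've validated:

# Visibility timeout (seconds) must exceed the longest task (~45 min).
with Connection('sqs://', transport_options={'visibility_timeout': 3600}) as conn:
    Worker(conn).run()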