I'm using Celery to execute some (potentially) memory-intensive tasks, and I need a way to identify which tasks fail due to out-of-memory (OOM) errors. Specifically, I'd like to record, in an external database, which tasks have failed, for reporting and analysis purposes.
I have not been able to figure out what information is available both within the task function body and within the on_failure method of the Request object that would let me link the parameters passed to the task (arg1 and arg2 in the code below) to the failure. The approach I have in mind requires some sort of "key" -- a piece of information that is accessible both in the task and in the Request.on_failure method. The process would be:
- When the task starts: store a record for the task in a database that includes the arg1 and arg2 information; a record that can be recovered later via this "key".
- When the task dies: use the "key" to update the record and indicate that the error occurred.
Note that I do not need or want to retry the task; I just need to record which tasks have died for later assessment.
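Concretely, the bookkeeping could be as simple as the following sketch. Everything here is a hypothetical stand-in: the record_start/record_failure helpers, the task_audit table, and the sqlite3 backing store would be replaced by whatever external database is actually used.

import sqlite3

DB_PATH = "task_audit.db"  # hypothetical location for the external store

def record_start(key, arg1, arg2):
    # Insert a row under the task's "key" so it can be found again later.
    with sqlite3.connect(DB_PATH) as conn:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS task_audit "
            "(key TEXT PRIMARY KEY, arg1 INTEGER, arg2 INTEGER, status TEXT)"
        )
        conn.execute(
            "INSERT OR REPLACE INTO task_audit VALUES (?, ?, ?, 'started')",
            (key, arg1, arg2),
        )

def record_failure(key):
    # Flip the status so reporting queries can find the tasks that died.
    with sqlite3.connect(DB_PATH) as conn:
        conn.execute(
            "UPDATE task_audit SET status = 'died' WHERE key = ?", (key,)
        )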
I'm running the Celery workers within a Kubernetes cluster, so they have strict memory constraints.
import logging
import os
import re

import celery
import celery.worker.request
import numpy

logger = logging.getLogger()

# The Celery application instance used by the task decorator below.
agent = celery.Celery(__name__)

class _CatchingRequest(celery.worker.request.Request):
    def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
        # What I want here is some way to know which task died.
        logger.error("In _CatchingRequest.on_failure; pid %s", os.getpid())
        logger.error("exc_info: %s [%s]", exc_info, type(exc_info))
        logger.error("dir(exc_info): %s", dir(exc_info))
        logger.error("traceback: %s %s", type(exc_info.traceback), exc_info.traceback)
        tb = exc_info.traceback
        # Try to extract the billiard job number from a WorkerLostError-style
        # traceback; this is a pool-internal index, not the task's arguments.
        patt = re.compile(r"SIGKILL.*Job\:\s(\d+)\.")
        if m := patt.search(tb):
            job = int(m.groups()[0])
            logger.error("pattern matches %d", job)
        return super().on_failure(exc_info, send_failed_event, return_ok)

class _CatchingTask(celery.Task):
    """Just a stub so we can actually handle the failure."""
    Request = _CatchingRequest

@agent.task(ignore_result=True, acks_late=True, base=_CatchingTask, bind=True)
def api_create_layer(self, arg1, arg2, timestamp):
    logger.error("enter create layer; job id: %s", self.request.id)
    # What I want here is some kind of value/identifier
    # that can be obtained in the on_failure call.
    try:
        _memory_intensive(arg1, arg2, timestamp)
    except Exception as err:
        logger.error("Caught generic memory intensive exception")
        logger.error(str(err))
        # any other record keeping

def _memory_intensive(arg1, arg2, timestamp):
    # Just illustrative of something that could easily use a lot of memory.
    arr = numpy.zeros([arg1, arg2, arg1, arg2], dtype=numpy.float64)
    numpy.save(f"zeros-{arg1}-{arg2}-{timestamp}.npy", arr)