Prefect training sklearn models on a large dataset gives a stream-closing error


I am using Prefect 2.11 and want to train multiple sklearn models in parallel on a large dataset on Kubernetes. However, Prefect closes the stream while the many tasks created for training the models are still running. See the code below for context.

import joblib
from prefect import flow, task, unmapped
from dask.distributed import Client
from dask_kubernetes.operator import KubeCluster
from prefect_dask import DaskTaskRunner
from dask_ml.model_selection import RandomizedSearchCV
from dask_ml.wrappers import ParallelPostFit
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor

@task
def train_single_model(model, X, y):
    # param_grid is defined at module level (omitted here for brevity)
    opt_model = RandomizedSearchCV(
        n_iter=20,
        cv=3,
        estimator=ParallelPostFit(model),
        param_distributions=param_grid,
        scoring='neg_mean_squared_error',
        n_jobs=-1,
        random_state=42)

    # Run the hyperparameter search on the Dask cluster via the joblib backend
    with joblib.parallel_backend("dask"):
        opt_model.fit(X, y)

    return opt_model

@flow
def training_flow():

    models = [RandomForestRegressor(n_jobs=-1), XGBRegressor(n_jobs=-1)] * 10
    # X and y below are Prefect futures from a task that builds the data as pandas DataFrames
    X = PrefectFuture    # shape: (100000, 300)
    y = PrefectFuture    # shape: (100000, 1)

    # Map over the models; the same X and y are passed to every task run
    train_single_model.map(models, unmapped(X), unmapped(y))
    

def main():
    # Get a cluster with 2 workers (256 threads each)
    cluster_specs = get_cluster_specs()    # helper (not shown) that builds the cluster spec
    with KubeCluster(custom_cluster_spec=cluster_specs) as cluster:    # configured Docker container
        cluster.scale(2)
        with Client(cluster) as client:
            training_flow.with_options(
                task_runner=DaskTaskRunner(
                    address=client.scheduler.address,
                )
            )()


if __name__ == "__main__":
    main()
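
For reference, the pieces the snippet only refers to are sketched below. The concrete param_grid values and the load_training_data task are simplified placeholders I am including just for this post (my real grid and data-loading task are larger), so treat them only as an illustration of the shapes involved.

import numpy as np
import pandas as pd
from prefect import task

# Stand-in hyperparameter distributions; keys use the "estimator__" prefix
# because each model is wrapped in ParallelPostFit. My real grid is larger.
param_grid = {
    "estimator__n_estimators": [100, 300, 500],
    "estimator__max_depth": [None, 5, 10, 20],
}

@task
def load_training_data():
    # Placeholder for the task whose Prefect futures feed X and y in the flow;
    # the real task reads roughly 100000 x 300 features from our data store.
    rng = np.random.default_rng(42)
    X = pd.DataFrame(rng.normal(size=(100_000, 300)))
    y = pd.DataFrame(rng.normal(size=(100_000, 1)))
    return X, y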

I have tried reducing the number of models from 30 to 6, but training still failed. I also reduced the cluster from 2 workers to 1 and trained only 6 models, and that failed as well. I suspect that Dask creates more tasks than the scheduler can handle correctly, or that the workers (2 in this case) accumulate so many pending tasks that some of them fail, so Prefect marks them as unfinished and closes the stream. Please see the log below.
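
To put numbers on that suspicion, here is my back-of-the-envelope count of the fits the mapped searches submit at once; the assumption that each candidate/fold fit becomes at least one scheduler-visible Dask task is my guess, not something I measured from dask_ml.

# Rough, order-of-magnitude task count (my own assumption: >= 1 Dask task per
# candidate/fold fit; dask_ml may create more for splitting and scoring).
n_models = 20   # len(models) in training_flow
n_iter = 20     # RandomizedSearchCV candidates per model
cv = 3          # folds per candidate

fits_per_search = n_iter * cv             # 60 fits per mapped task
total_fits = n_models * fits_per_search   # 1200 fits contending for 2 workers
print(fits_per_search, total_fits)        # 60 1200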

10:50:04.098 | INFO    | distributed.core - Event loop was unresponsive in Worker for 5.30s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.

10:51:22.516 | ERROR   | Task run 'Training RandomForestRegressor' - Encountered exception during execution:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/prefect/engine.py", line 1730, in orchestrate_task_run
    result = await call.aresult()
  File "/usr/local/lib/python3.10/site-packages/prefect/_internal/concurrency/calls.py", line 292, in aresult
    return await asyncio.wrap_future(self.future)
  File "/usr/local/lib/python3.10/site-packages/prefect/_internal/concurrency/calls.py", line 316, in _run_sync
    result = self.fn(*self.args, **self.kwargs)
  train_single_model
    opt_model.fit(X, y)
  File "/usr/local/lib/python3.10/site-packages/dask_ml/model_selection/_search.py", line 1279, in fit
    future.retry()
  File "/usr/local/lib/python3.10/site-packages/distributed/client.py", line 393, in retry
    return self.client.retry([self], **kwargs)
  File "/usr/local/lib/python3.10/site-packages/distributed/client.py", line 2609, in retry
    return self.sync(self._retry, futures, asynchronous=asynchronous)
  File "/usr/local/lib/python3.10/site-packages/distributed/utils.py", line 345, in sync
    return sync(
  File "/usr/local/lib/python3.10/site-packages/distributed/utils.py", line 412, in sync
    raise exc.with_traceback(tb)
  File "/usr/local/lib/python3.10/site-packages/distributed/utils.py", line 385, in f
    result = yield future
  File "/usr/local/lib/python3.10/site-packages/tornado/gen.py", line 767, in run
    value = future.result()
  File "/usr/local/lib/python3.10/site-packages/distributed/client.py", line 2593, in _retry
    response = await self.scheduler.retry(keys=keys, client=self.id)
  File "/usr/local/lib/python3.10/site-packages/distributed/core.py", line 1234, in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/distributed/core.py", line 1018, in send_recv
    raise exc.with_traceback(tb)
  File "/usr/local/lib/python3.10/site-packages/distributed/core.py", line 825, in _handle_comm
    result = handler(**msg)
  File "/usr/local/lib/python3.10/site-packages/distributed/scheduler.py", line 4726, in stimulus_retry
    ts = self.tasks[key]
KeyError: "('parallelpostfit-fit-score-1bf7c2c2b9dd10d2ce5467f71771efed', 13, 0)"

10:51:22.521 | ERROR   | Task run 'Training XGBRegressor' - Encountered exception during execution:
Traceback (most recent call last):
  (identical to the traceback above, raised from the same stimulus_retry call)
KeyError: "('parallelpostfit-fit-score-a4d36eb954d2c58fc14d9b9be08fd62b', 4, 2)"

2023-08-28 10:51:24,915 - distributed.worker - INFO - Stopping worker at tcp://10.10.10.333:55555. Reason: scheduler-close
2023-08-28 10:51:24,915 - distributed._signals - INFO - Received signal SIGTERM (15)
2023-08-28 10:51:24,916 - distributed.core - INFO - Received 'close-stream' from tcp://dask-cluster-20230828-124957-scheduler.cluster.local:8786; closing.
2023-08-28 10:51:24,917 - distributed.nanny - INFO - Closing Nanny at 'tcp://10.10.10.333:55555'. Reason: nanny-close
2023-08-28 10:51:24,918 - distributed.nanny - INFO - Nanny asking worker to close. Reason: nanny-close
2023-08-28 10:51:24,921 - distributed.batched - INFO - Batched Comm Closed <TCP (closed) Worker->Scheduler local=tcp://10.10.10.333:50000 remote=tcp://dask-cluster-20230828-124957-scheduler.cluster.local:8786>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/distributed/batched.py", line 115, in _background_send
    nbytes = yield coro
  File "/usr/local/lib/python3.10/site-packages/tornado/gen.py", line 767, in run
    value = future.result()
  File "/usr/local/lib/python3.10/site-packages/distributed/comm/tcp.py", line 269, in write
    raise CommClosedError()
distributed.comm.core.CommClosedError
2023-08-28 10:51:31,322 - distributed.nanny - WARNING - Worker process still alive after 6.3999821472167975 seconds, killing
2023-08-28 10:51:31,491 - distributed.nanny - INFO - Worker process 74 was killed by signal 9
2023-08-28 10:51:31,496 - distributed.dask_worker - INFO - End worker

Does anyone know how I can achieve this orchestration without the tasks failing?
