Stop ForkJoinPool executions when an exception occurs in one task

We have code that walks a tree structure and performs a task for each node.

  • for a leaf node, the task is executed immediately
  • for an inner node (= node with children), the result depends on the result of the children

We use a ForkJoinPool to parallelize the work; the tasks extend RecursiveTask. The task for the root node is scheduled on the pool using pool.execute(task), and the sub-tasks are scheduled using ForkJoinTask.invokeAll (see the full example below).

We now find that if one of the tasks throws an exception, there seems to be no way to quickly propagate that exception to the parent tasks and to the pool itself so that all other tasks stop executing: all other calculations are still performed, and the exception is thrown only at the very end, after all other tasks have finished.

How can we force the ForkJoinPool to stop executing other tasks if one of the tasks fails?

Things I tried to solve the problem:

  • I had expected that passing an UncaughtExceptionHandler to the pool would do the trick, but the exception of a task doesn't even reach that handler.
  • I also experimented with pool.invoke (instead of .execute), but that doesn't help either.
  • I tried checking the cancellation state of a task using isCancelled() at the beginning of the computation (the computeSelf method in the example below), but that flag doesn't appear to be set by the pool.
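
To make the first and third bullets concrete, here is a minimal sketch of where the exception ends up. The names ExceptionCaptureDemo, FailingTask and handlerCalled are hypothetical and not part of the original code. As far as I understand the ForkJoinTask javadoc, an exception thrown from compute() is recorded on the task itself and rethrown by join(), so the pool's UncaughtExceptionHandler is not involved, and a task that failed is not reported as cancelled.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicBoolean;

class ExceptionCaptureDemo {

  // Hypothetical stand-in for a failing leaf task.
  static class FailingTask extends RecursiveTask<Long> {
    @Override
    protected Long compute() {
      throw new IllegalArgumentException("imagine something's wrong here");
    }
  }

  // Set by the pool's UncaughtExceptionHandler, if it is ever called.
  static final AtomicBoolean handlerCalled = new AtomicBoolean(false);

  static ForkJoinTask<Long> runFailingTask() {
    ForkJoinPool pool = new ForkJoinPool(2,
        ForkJoinPool.defaultForkJoinWorkerThreadFactory,
        (t, e) -> handlerCalled.set(true),
        false);
    try {
      ForkJoinTask<Long> task = new FailingTask();
      pool.execute(task);
      task.quietlyJoin(); // wait for completion without rethrowing the exception
      return task;
    } finally {
      pool.shutdown();
    }
  }

  public static void main(String[] args) {
    ForkJoinTask<Long> task = runFailingTask();
    System.out.println("completed abnormally: " + task.isCompletedAbnormally());
    System.out.println("cancelled: " + task.isCancelled());
    System.out.println("exception: " + task.getException());
    System.out.println("handler called: " + handlerCalled.get());
  }
}
```

In this sketch the failure is visible only through the task (isCompletedAbnormally(), getException(), or a rethrow from join()), which would explain why neither the handler nor isCancelled() reacts in the tree example.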

Full example, traversing a binary tree up to a given depth. For each node of the tree, it accumulates how much time all children took, and adds its own timing (faked in this example using Thread.sleep) to that result.

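  import java.util.Collection;
  import java.util.List;
  import java.util.concurrent.ExecutionException;
  import java.util.concurrent.ForkJoinPool; // used by main further down
  import java.util.concurrent.ForkJoinTask;
  import java.util.concurrent.RecursiveTask;
  import java.util.concurrent.atomic.AtomicLong;
  import java.util.stream.Collectors;
  import java.util.stream.IntStream;

  class MyBiTreeTask extends RecursiveTask<Long> {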

    final String name;

    MyBiTreeTask(String aName) {
      name = aName;
    }

    @Override
    protected Long compute() {
      long subTiming = computeChildren();
      long selfTiming = computeSelf();
      return subTiming + selfTiming;
    }

    private long computeChildren() {
      if (name.length() >= 5) {
        return 0;
      }
      List<ForkJoinTask<Long>> subTasks = IntStream.range(0, 2)
                                                   .mapToObj(i -> new MyBiTreeTask(name + i))
                                                   .collect(Collectors.toList());
      Collection<ForkJoinTask<Long>> subResults = ForkJoinTask.invokeAll(subTasks);
      AtomicLong subTimings = new AtomicLong(0);
      subResults.forEach(r -> {
        try {
          subTimings.addAndGet(r.get());
        } catch (InterruptedException | ExecutionException aE) {
          throw new RuntimeException(aE);
        }
      });
      return subTimings.get();
    }

    private long computeSelf() {
      long t0 = System.currentTimeMillis();

      if (name.equals("r1100")) {
        throw new IllegalArgumentException("imagine something's wrong here: " + name);
      }
      try {
        // fake a task taking a long time, in another branch than the one causing the exception above
        Thread.sleep(1000 * (name.equals("r0") ? 10 : 1));
      } catch (InterruptedException aE) {
        throw new RuntimeException(aE);
      }

      long t1 = System.currentTimeMillis();
      return t1 - t0;
    }
  }

This is how the recursion is started:

  public static void main(String[] args) {
    long t0 = System.currentTimeMillis();
    try {
      ForkJoinTask<Long> root = new MyBiTreeTask("r");
      ForkJoinPool pool = new ForkJoinPool(ForkJoinPool.getCommonPoolParallelism(),
                                           ForkJoinPool.defaultForkJoinWorkerThreadFactory,
                                           (t, e) -> {
                                              throw new RuntimeException("exception in thread " + t.getId(), e);
                                           },
                                           false);
      pool.execute(root);
      Long accumulatedTiming = root.join();
      System.out.println("accumulated timing = " + accumulatedTiming);
    } catch (Exception aE) {
      System.err.println("couldn't accumulate timings");
      aE.printStackTrace();
    }
    long t1 = System.currentTimeMillis();
    long realTiming = t1 - t0;
    System.out.println("real timing = " + realTiming);
  }

Because the task will throw an exception for a leaf node, we would expect that the "realTiming" reported is rather short (only a few seconds). However, we notice that all other subtasks are performed before the exception is propagated.

EDIT: added another approach I had already tried.

There is 1 solution below

I found a solution using parallel streams and ForkJoinTask.fork.

  • I'm not really happy with it myself, because relying on whether a stream is parallel sounds rather fragile to me.
  • I also don't really understand why it works: the ForkJoinTask javadoc seems to suggest that calling fork on every task and calling invokeAll for all tasks should have the same effect.

This is the code to replace the computeChildren method in the original example:

    private long computeChildren_parallelStream() {
      if (name.length() >= 5) {
        return 0;
      }
      return IntStream.range(0, 2)
                      .parallel()
                      .mapToObj(i -> new MyBiTreeTask(name + i))
                      .map(ForkJoinTask::fork)
                      .map(t -> {
                        try {
                          return t.get();
                        } catch (InterruptedException | ExecutionException ex) {
                          throw new RuntimeException("child of " + name + " got an exception", ex);
                        }
                      })
                      .mapToLong(l -> l)
                      .sum();
    }

With this code, an exception in a sub-task is propagated to the root immediately.
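
An alternative that does not depend on the stream being parallel is cooperative cancellation through a shared flag. The following is only a sketch under assumptions, not a tested replacement for the code above: FailFastTreeDemo, Node and the shared failure reference are hypothetical names, the tree is smaller (name length 4, 50 ms of fake work per node), and the failing leaf is "r110". Every task checks the shared reference before doing any work and skips itself and its subtree once a failure has been recorded.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicReference;

class FailFastTreeDemo {

  // Hypothetical variant of MyBiTreeTask: all tasks of one run share "failure".
  static class Node extends RecursiveTask<Long> {
    final String name;
    final AtomicReference<Throwable> failure;

    Node(String aName, AtomicReference<Throwable> aFailure) {
      name = aName;
      failure = aFailure;
    }

    @Override
    protected Long compute() {
      // Skip this node and its whole subtree once any task has failed.
      if (failure.get() != null) {
        throw new CancellationException("skipped " + name);
      }
      long subTiming = computeChildren();
      long selfTiming;
      try {
        selfTiming = computeSelf();
      } catch (RuntimeException e) {
        failure.compareAndSet(null, e); // only the first failure is kept
        throw e;
      }
      return subTiming + selfTiming;
    }

    private long computeChildren() {
      if (name.length() >= 4) {
        return 0;
      }
      List<Node> children = new ArrayList<>();
      for (int i = 0; i < 2; i++) {
        children.add(new Node(name + i, failure));
      }
      long sum = 0;
      for (Node child : ForkJoinTask.invokeAll(children)) {
        sum += child.join();
      }
      return sum;
    }

    private long computeSelf() {
      if (name.equals("r110")) {
        throw new IllegalArgumentException("imagine something's wrong here: " + name);
      }
      try {
        Thread.sleep(50); // fake work
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
      return 50;
    }
  }

  // Runs the tree and returns the first recorded failure (null if none).
  static Throwable run() {
    AtomicReference<Throwable> failure = new AtomicReference<>();
    ForkJoinPool pool = new ForkJoinPool(4);
    try {
      pool.invoke(new Node("r", failure));
    } catch (RuntimeException e) {
      // May be the original exception or a CancellationException from a
      // skipped task; the root cause is read from "failure" instead.
    } finally {
      pool.shutdown();
    }
    return failure.get();
  }
}
```

Two caveats apply to this sketch. A task that is already inside its own work (here, Thread.sleep) is not interrupted and still runs to completion; only tasks that have not started yet are skipped. Also, the exception that surfaces at the root may be a CancellationException from a skipped task, which is why run() returns the recorded failure rather than whatever pool.invoke rethrows.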