How to use Spliterator.trySplit to do parallel computation on N cores?


Let's say I have a list of 10000 elements and want to process them on 6 cores. I don't want to use the existing Stream API and want to do it myself from scratch (for the sake of learning). The Spliterator interface seems to fit that purpose well. However, it divides the collection in half whenever it is called. I can get a 5000-5000 split, then split again to get 2500-2500-2500-2500, and then split two of those parts to get 2500-2500-1250-1250-1250-1250, which cuts my initial collection into 6 parts. This seems unbalanced, and I see no way to balance it across 6 cores.

From the Javadoc:

API Note: An ideal trySplit method efficiently (without traversal) divides its elements exactly in half, allowing balanced parallel computation.
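
To make the halving behaviour concrete, here is a small sketch that splits the spliterator of an ArrayList twice and reports the estimated size of each part. The class name SplitDemo and the helper splitSizes are made up for illustration, and other collection types may split differently.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Spliterator;

class SplitDemo {

    // Splits an ArrayList spliterator twice and returns the estimated
    // sizes of the three resulting parts.
    static long[] splitSizes(int n) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            data.add(i);
        }
        Spliterator<Integer> rest = data.spliterator();
        Spliterator<Integer> first = rest.trySplit();   // split-off part #1
        Spliterator<Integer> second = rest.trySplit();  // split-off part #2
        return new long[] {
            first.estimateSize(), second.estimateSize(), rest.estimateSize()
        };
    }

    public static void main(String[] args) {
        long[] sizes = splitSizes(10_000);
        System.out.println(sizes[0] + " / " + sizes[1] + " / " + sizes[2]);
    }
}
```

For a 10000-element ArrayList this prints 5000 / 2500 / 2500: each call hands off half of whatever the spliterator still covers.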

However, Stream.parallel() seems to solve this problem somehow. I tried to read the source, but I am still unable to get the gist of it. Maybe someone can explain the high-level approach to me.

Answered by Alexander Ivanchenko

If you want to reimplement the functionality provided by parallel streams, then apart from dividing the task into subtasks, you need to take care of executing these tasks and joining their results.

Under the hood, parallel streams make use of the Fork/Join framework.

Spliterator is needed only to split the data into subtasks. The order in which worker threads are assigned tasks, and the correct merging of intermediate results, is handled by Fork/Join.

If you want to do it yourself, you can extend the abstract class RecursiveTask and override its compute() method. That would be a "container" for your tasks. (There's also the RecursiveAction class, which is meant to perform an action and doesn't produce a value; since the question is about computations and we need to obtain a result, RecursiveTask is more suitable for this purpose.)

To make it more flexible, you can add a property of type Function or Predicate that is supplied when the task is instantiated, but it would still be no match for the power and flexibility of parallel streams.

While implementing compute(), you need to provide the logic for splitting the task. You can use a Spliterator for that, and if the source allows random access to its elements (like a list or an array), the splitting can also be done manually by index.
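
As an illustration of the manual option, here is a minimal sketch of a RecursiveTask that splits a random-access List by index range instead of using a Spliterator. The names RangeSumTask, THRESHOLD, rangeList and sum, and the cut-off value of 1000, are assumptions made for this example only.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class RangeSumTask extends RecursiveTask<Long> {

    // Hypothetical cut-off: ranges of this size or smaller are summed sequentially.
    private static final int THRESHOLD = 1_000;

    private final List<Integer> data;
    private final int from; // inclusive
    private final int to;   // exclusive

    RangeSumTask(List<Integer> data, int from, int to) {
        this.data = data;
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        if (to - from <= THRESHOLD) {
            long sum = 0;
            for (int i = from; i < to; i++) {
                sum += data.get(i);
            }
            return sum;
        }
        int mid = (from + to) >>> 1;
        RangeSumTask left = new RangeSumTask(data, from, mid);
        left.fork();                                            // left half runs asynchronously
        long right = new RangeSumTask(data, mid, to).compute(); // right half runs in this thread
        return left.join() + right;
    }

    // Builds the list 1..n for the demo.
    static List<Integer> rangeList(int n) {
        List<Integer> list = new ArrayList<>(n);
        for (int i = 1; i <= n; i++) {
            list.add(i);
        }
        return list;
    }

    // Sums the list on a dedicated pool with the given parallelism.
    static long sum(List<Integer> data, int parallelism) {
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        try {
            return pool.invoke(new RangeSumTask(data, 0, data.size()));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(sum(rangeList(10_000), 6)); // 50005000
    }
}
```

The list is only read, never modified, so sharing it between tasks is safe here.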

If you choose to use a Spliterator for dividing the dataset, you can use its trySplit() method, which returns a new Spliterator, or null if the data can't be split further. Hence, if trySplit() yields null, you need to process the remaining elements of the current spliterator. Otherwise, you need to create a new task based on the spliterator returned by trySplit(), call fork() on it, and then merge the result of processing the remaining elements of the current spliterator with the value returned by join() called on the new task.

But note that with Spliterator you'll face an issue when it comes to processing the data. Unlike Iterator, this interface doesn't declare methods that return elements directly; that's not what it's meant for.

Spliterator offers only a couple of methods for dealing with its elements: forEachRemaining() and tryAdvance(). The first is void and the second returns a boolean; both expect a Consumer as an argument. That means you'll be forced to use stateful functions (which is not good practice) in order to return a value from compute().
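
The following sketch shows why the Consumer-based methods push you toward mutable state: the only way to get a value out of the lambda is to write it into a holder declared outside. The class name ConsumerStateDemo and the one-element array used as the holder are illustrative choices, not part of any API.

```java
import java.util.List;
import java.util.Spliterator;

class ConsumerStateDemo {

    // Sums the elements using only the two traversal methods Spliterator offers.
    static int sum(Spliterator<Integer> spliterator) {
        // Mutable holder: a lambda cannot assign to a local variable directly.
        int[] total = {0};

        // tryAdvance handles at most one element and reports whether there was one.
        boolean hadElement = spliterator.tryAdvance(t -> total[0] += t);

        if (hadElement) {
            // forEachRemaining consumes everything that is left and returns nothing.
            spliterator.forEachRemaining(t -> total[0] += t);
        }
        return total[0];
    }

    public static void main(String[] args) {
        System.out.println(sum(List.of(1, 2, 3, 4, 5).spliterator())); // 15
    }
}
```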

want to process them on 6 cores

We can specify the required level of parallelism (maximum number of threads that would be occupied simultaneously) by using one of the parameterized constructors of ForkJoinPool. Alternatively, we can use the factory method newWorkStealingPool() of the Executors class, available since Java 8.
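
Both options can be sketched as follows. The class name PoolDemo and its two helper methods are made up for this example; the constructor and the factory method are the standard ones from java.util.concurrent.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;

class PoolDemo {

    // Option 1: construct the pool directly with the desired parallelism.
    static ForkJoinPool fixedParallelismPool(int parallelism) {
        return new ForkJoinPool(parallelism);
    }

    // Option 2: use the factory method, which is declared to return an ExecutorService.
    static ExecutorService workStealingPool(int parallelism) {
        return Executors.newWorkStealingPool(parallelism);
    }

    public static void main(String[] args) {
        ForkJoinPool pool = fixedParallelismPool(6);
        ExecutorService executor = workStealingPool(6);

        System.out.println(pool.getParallelism()); // 6

        pool.shutdown();
        executor.shutdown();
    }
}
```

Because the factory method is typed as ExecutorService, ForkJoinPool-specific methods such as invoke(ForkJoinTask) are not available on its result without a cast, which is why the examples here use the constructor.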

Parallel processing using Spliterator

RecursiveTask implementation:

import java.util.Spliterator;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BinaryOperator;
import java.util.function.Predicate;

// Nested in the class that also declares main() below.
public static class Task<T> extends RecursiveTask<T> {
    
    private final Spliterator<T> spliterator;
    private final BinaryOperator<T> accumulator;
    private final T identity;
    private final Predicate<T> predicate;
    
    public Task(Spliterator<T> spliterator, BinaryOperator<T> accumulator, T identity) {
        this(spliterator, accumulator, identity, t -> true); // no filtering by default
    }
    
    public Task(Spliterator<T> spliterator, BinaryOperator<T> accumulator, T identity, Predicate<T> predicate) {
        this.spliterator = spliterator;
        this.accumulator = accumulator;
        this.identity = identity;
        this.predicate = predicate;
    }
    
    @Override
    protected T compute() {
        // Hand off part of the data; null means it can't be split further.
        Spliterator<T> newSpliterator = spliterator.trySplit();
        // Mutable holder for the Consumer; only this task's thread touches it.
        AtomicReference<T> result = new AtomicReference<>(identity);
        
        if (newSpliterator != null) {
            Task<T> newTask = new Task<>(newSpliterator, accumulator, identity, predicate);
            newTask.fork(); // process the split-off part asynchronously
            forEachRemaining(spliterator, result);
            // For ordered sources the split-off part precedes the rest, so it goes on the left.
            return accumulator.apply(newTask.join(), result.get());
        }
        forEachRemaining(spliterator, result);
        return result.get();
    }
    
    // Folds every remaining element that passes the predicate into the holder.
    private void forEachRemaining(Spliterator<T> spliterator, AtomicReference<T> result) {
        spliterator.forEachRemaining(t -> {
            if (predicate.test(t)) {
                result.set(accumulator.apply(result.get(), t));
            }
        });
    }
}

main() - let's compute the total of all numbers in the given collection, and then separately the sums of all even and all odd elements.

public static void main(String[] args) {

    ForkJoinPool pool = new ForkJoinPool(6); // required parallelism 6
    
    Set<Integer> test = Set.of(1, 2, 3, 4, 5, 6, 7, 8, 9);
    
    System.out.println(pool.invoke(new Task<>(test.spliterator(), Integer::sum, 0)));
    System.out.println(pool.invoke(new Task<>(test.spliterator(), Integer::sum, 0, t -> t % 2 == 0)));
    System.out.println(pool.invoke(new Task<>(test.spliterator(), Integer::sum, 0, t -> t % 2 != 0)));
}

Output:

45   // total of: 1, 2, 3, 4, 5, 6, 7, 8, 9
20   // total of: 2, 4, 6, 8
25   // total of: 1, 3, 5, 7, 9

Spliterator + Iterator

We can improve the approach shown above by traversing each chunk with an Iterator.

The Spliterator then remains responsible only for splitting the task, while an Iterator obtained from each task's own spliterator via Spliterators.iterator() is used for processing the data. This avoids the stateful function of the previous example. Note that the Iterator must not be shared between tasks: an iterator created from the whole collection ignores the split and is not safe to use from several threads at once.

RecursiveTask implementation:

import java.util.Iterator;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.RecursiveTask;
import java.util.function.BinaryOperator;
import java.util.function.Predicate;

public static class Task<T> extends RecursiveTask<T> {
    
    private final Spliterator<T> spliterator;
    private final BinaryOperator<T> accumulator;
    private final T identity;
    private final Predicate<T> predicate;
    
    public Task(Spliterator<T> spliterator, BinaryOperator<T> accumulator, T identity) {
        this(spliterator, accumulator, identity, t -> true); // no filtering by default
    }
    
    public Task(Spliterator<T> spliterator, BinaryOperator<T> accumulator,
                T identity, Predicate<T> predicate) {
        
        this.spliterator = spliterator;
        this.accumulator = accumulator;
        this.identity = identity;
        this.predicate = predicate;
    }
    
    @Override
    protected T compute() {
        Spliterator<T> newSpliterator = spliterator.trySplit();
        
        if (newSpliterator != null) {
            Task<T> newTask = new Task<>(newSpliterator, accumulator, identity, predicate);
            newTask.fork();
            // Iterate only over the part that this task still owns after the split.
            T result = forEachRemaining(Spliterators.iterator(spliterator));
            // For ordered sources the split-off part precedes the rest, so it goes on the left.
            return accumulator.apply(newTask.join(), result);
        }
        return forEachRemaining(Spliterators.iterator(spliterator));
    }
    
    // Folds every element that passes the predicate into a local result.
    private T forEachRemaining(Iterator<T> iterator) {
        T result = identity;
        while (iterator.hasNext()) {
            T next = iterator.next();
            if (predicate.test(next)) {
                result = accumulator.apply(result, next);
            }
        }
        return result;
    }
}

main() - the same sample data.

public static void main(String[] args) {

    ForkJoinPool pool = new ForkJoinPool(6); // required parallelism 6
    
    Set<Integer> test = Set.of(1, 2, 3, 4, 5, 6, 7, 8, 9);
    
    System.out.println(pool.invoke(new Task<>(test.spliterator(), Integer::sum, 0)));
    System.out.println(pool.invoke(new Task<>(test.spliterator(), Integer::sum, 0, t -> t % 2 == 0)));
    System.out.println(pool.invoke(new Task<>(test.spliterator(), Integer::sum, 0, t -> t % 2 != 0)));
}

Output:

45   // total of: 1, 2, 3, 4, 5, 6, 7, 8, 9
20   // total of: 2, 4, 6, 8
25   // total of: 1, 3, 5, 7, 9
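
Regarding the concern about balancing 10000 elements over 6 cores: one common approach is not to aim for exactly 6 equal parts, but to keep splitting until each part is smaller than some target size, so that there are noticeably more parts than workers and the pool's work stealing evens out the load. The sketch below illustrates this idea. The class name ThresholdSumTask, the helpers rangeList and sum, and the factor of 4 parts per worker are assumptions chosen for this example; parallel streams use a comparable size-based cut-off, but treat the exact numbers as illustrative.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Spliterator;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class ThresholdSumTask extends RecursiveTask<Long> {

    private final Spliterator<Integer> spliterator;
    private final long targetSize; // stop splitting once a part is this small

    ThresholdSumTask(Spliterator<Integer> spliterator, long targetSize) {
        this.spliterator = spliterator;
        this.targetSize = targetSize;
    }

    @Override
    protected Long compute() {
        List<ThresholdSumTask> forked = new ArrayList<>();
        // Keep handing off parts while what remains is still larger than the target.
        while (spliterator.estimateSize() > targetSize) {
            Spliterator<Integer> part = spliterator.trySplit();
            if (part == null) {
                break; // cannot be split any further
            }
            ThresholdSumTask task = new ThresholdSumTask(part, targetSize);
            task.fork();
            forked.add(task);
        }
        long[] sum = {0}; // mutable holder for the Consumer
        spliterator.forEachRemaining(t -> sum[0] += t);
        long total = sum[0];
        // Addition is commutative, so the join order does not affect the result;
        // joining the most recently forked task first suits the work-stealing deques.
        for (int i = forked.size() - 1; i >= 0; i--) {
            total += forked.get(i).join();
        }
        return total;
    }

    // Builds the list 1..n for the demo.
    static List<Integer> rangeList(int n) {
        List<Integer> list = new ArrayList<>(n);
        for (int i = 1; i <= n; i++) {
            list.add(i);
        }
        return list;
    }

    static long sum(List<Integer> data, int parallelism) {
        // Aim for roughly four parts per worker (an illustrative choice).
        long targetSize = Math.max(1, data.size() / (4L * parallelism));
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        try {
            return pool.invoke(new ThresholdSumTask(data.spliterator(), targetSize));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(sum(rangeList(10_000), 6)); // 50005000
    }
}
```

With 10000 elements and a parallelism of 6, the target size is 416, so the data ends up in a few dozen small parts rather than 6 uneven ones, and idle workers pick up whatever parts are still queued.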