How do I reduce over an entire dataset on the GPU?


I'm using WebGPU/WGSL compute shaders to analyze medium-to-large datasets (millions of data points) by slicing or sharding the data so that each shader invocation works on its own slice.

This works fine when data operations are local enough. But what if I need to compute the minimum value over the entire dataset, which is not a local operation (because it needs to read the entire dataset and output a single number)?

Fortunately the minimum function is commutative and associative, so I can have each invocation compute the minimum of its own slice of data and write the result to a specific cell of a temporary storage buffer. But then, how do I detect when all invocations are done and perform the final reduction over that buffer?
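A minimal sketch of that first phase (each invocation writes its slice minimum into a partials buffer). The binding layout, buffer names, and SLICE_SIZE are illustrative assumptions, and I assume the dataset length is a multiple of SLICE_SIZE:

```wgsl
// Hypothetical layout: `data` holds the whole dataset, `partials` has
// one cell per invocation. SLICE_SIZE is illustrative; the dataset
// length is assumed to be a multiple of it.
const SLICE_SIZE : u32 = 1024u;

@group(0) @binding(0) var<storage, read> data : array<f32>;
@group(0) @binding(1) var<storage, read_write> partials : array<f32>;

@compute @workgroup_size(64)
fn reduce_slices(@builtin(global_invocation_id) gid : vec3<u32>) {
    let base = gid.x * SLICE_SIZE;
    var m = data[base];
    for (var i = 1u; i < SLICE_SIZE; i = i + 1u) {
        m = min(m, data[base + i]);
    }
    partials[gid.x] = m; // one slice-level minimum per invocation
}
```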

I'm thinking of putting that code at the end of my shader, after a barrier, and running it only once, for instance in the invocation with local id zero. But as far as I can tell, barriers only synchronize within a workgroup, so I would still need to reduce over the results of all workgroups. Where should that final code go?


Edit: I just thought of using a global atomic operation at the end of my shader to increment a global counter; if I detect that I'm the last invocation to run it, I can reduce over the partial results. Is this correct?

Something like this:

  1. perform the bulk of the work on the invocation's own shard of data;
  2. store the shard-level result in an array in workgroup memory of length workgroup_size, indexed by local_invocation_id;
  3. call workgroupBarrier() to synchronize over that memory;
  4. if we are the workgroup leader (local_invocation_id is zero), reduce the workgroup array to a single workgroup-level result and store it into a global storage buffer of length num_workgroups;
  5. call storageBarrier() (maybe? is this even needed here?);
  6. increment a single global counter with atomicAdd(): if we are the last workgroup leader to do so, perform the final reduction over the global buffer and store the single final result in global storage.
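The six steps above could be sketched roughly like this in WGSL. This is an assumption-laden sketch, not a verified solution: the buffer names, the sizes, and the shard_minimum() helper are hypothetical, and the cross-workgroup visibility question remains open. One detail WGSL forces on us: barriers must appear in uniform control flow, so the storageBarrier() of step 5 has to sit outside the leader-only branch:

```wgsl
// Illustrative sizes; in practice these must match the dispatch.
const WG_SIZE : u32 = 64u;
const NUM_GROUPS : u32 = 256u;

var<workgroup> wg_partials : array<f32, WG_SIZE>;

@group(0) @binding(0) var<storage, read_write> group_results : array<f32, NUM_GROUPS>;
@group(0) @binding(1) var<storage, read_write> counter : atomic<u32>;
@group(0) @binding(2) var<storage, read_write> final_result : f32;

fn shard_minimum(slice : u32) -> f32 {
    // hypothetical helper: step 1, reduce this invocation's own shard
    // of the dataset (details elided)
    return 0.0;
}

@compute @workgroup_size(64)
fn main(@builtin(local_invocation_id) lid : vec3<u32>,
        @builtin(workgroup_id) wid : vec3<u32>,
        @builtin(global_invocation_id) gid : vec3<u32>) {
    // 1-2. shard-level work, stored per-lane in workgroup memory
    wg_partials[lid.x] = shard_minimum(gid.x);

    // 3. wait until every lane has written its shard result
    workgroupBarrier();

    if (lid.x == 0u) {
        // 4. leader folds the workgroup array into one value
        var acc = wg_partials[0];
        for (var i = 1u; i < WG_SIZE; i = i + 1u) {
            acc = min(acc, wg_partials[i]);
        }
        group_results[wid.x] = acc;
    }

    // 5. must be in uniform control flow, hence outside the branch;
    //    note storageBarrier() only orders accesses within this workgroup
    storageBarrier();

    // 6. the last leader to bump the counter does the final pass
    if (lid.x == 0u && atomicAdd(&counter, 1u) == NUM_GROUPS - 1u) {
        var total = group_results[0];
        for (var i = 1u; i < NUM_GROUPS; i = i + 1u) {
            total = min(total, group_results[i]);
        }
        final_result = total;
    }
}
```

The counter buffer would need to be zeroed before each dispatch, and whether the atomicAdd alone makes the other workgroups' stores to group_results visible to the last leader is exactly the open question here.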

Can anybody confirm whether this is correct? I'm not sure whether the atomic counter can guarantee that I can read the results stored by the other workgroups. That's why I added a storageBarrier, but I'm still not sure whether it's a strong enough guarantee.

In my case (max function) I could probably replace all intermediate buffers with atomicMax calls:

  1. work on the invocation's own shard of data;
  2. combine the local maximum into a workgroup-level atomicMax;
  3. call workgroupBarrier() to wait for all the atomicMax calls;
  4. if we are the workgroup leader, atomicLoad the workgroup-level value and combine it into a global atomicMax.
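A sketch of this atomicMax variant, assuming unsigned integer data (WGSL atomics only support u32 and i32, so f32 data would first need an order-preserving integer encoding); shard_maximum() is again a hypothetical helper:

```wgsl
// Workgroup variables are zero-initialized, which works for a u32 max.
var<workgroup> wg_max : atomic<u32>;

@group(0) @binding(0) var<storage, read_write> global_max : atomic<u32>;

fn shard_maximum(slice : u32) -> u32 {
    // hypothetical helper: reduce this invocation's own shard (elided)
    return 0u;
}

@compute @workgroup_size(64)
fn main(@builtin(local_invocation_id) lid : vec3<u32>,
        @builtin(global_invocation_id) gid : vec3<u32>) {
    // 1-2. fold this shard's maximum into the workgroup-level atomic
    atomicMax(&wg_max, shard_maximum(gid.x));

    // 3. wait for every lane's atomicMax
    workgroupBarrier();

    // 4. the leader folds the workgroup result into the global atomic
    if (lid.x == 0u) {
        atomicMax(&global_max, atomicLoad(&wg_max));
    }
}
```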

This should work, if I understand correctly. But what if my reduction function is not max but something more complex? Can I use an atomic counter and intermediate buffers as described above?
