Compute per-warp histogram without shared memory


Problem: Compute a per-warp histogram of a sorted sequence of numbers held by the individual threads of a warp.

Example:

lane: 0123456789...          31
val:  222244455777799999 ..

The result must be held by the N lower threads of the warp (where N is the number of unique values), e.g.:

lane 0: val=2, num=4 (2 occurs 4 times)
lane 1: val=4, num=3 (4 occurs 3 times)
lane 2: val=5, num=2 ...
lane 3: val=7, num=4
lane 4: val=9, num=5
...

Note that the sequence of 'val' does not actually have to be sorted: it is only necessary for equal numbers to be grouped together, e.g.: 99955555773333333...

Possible solution: This can be done quite efficiently with shuffle intrinsics, but my question is whether it can be done without using shared memory at all (shared memory is a scarce resource, and I need it elsewhere)?

For simplicity, I execute this code for a single warp only (so that printf works fine):

__device__ __inline__ void sorted_seq_histogram()
{
    uint32_t tid = threadIdx.x, lane = tid % 32;
    uint32_t val = (lane + 117)* 23 / 97; // sorted sequence of values to be reduced

    printf("%d: val = %d\n", lane, val);
    uint32_t num = 1;

    uint32_t allmsk = 0xffffffffu, shfl_c = 31;
    for(int i = 1; i <= 16; i *= 2) {

#if 1
        uint32_t xval = __shfl_down_sync(allmsk, val, i),
                 xnum = __shfl_down_sync(allmsk, num, i);
        if(lane + i < 32) {
            if(val == xval)
                num += xnum;
        }
#else  // this is a (hopefully) optimized version of the code above
        asm(R"({
          .reg .u32 r0,r1;
          .reg .pred p;
          shfl.sync.down.b32 r0|p, %1, %2, %3, %4;
          shfl.sync.down.b32 r1|p, %0, %2, %3, %4;
          @p setp.eq.s32 p, %1, r0;
          @p add.u32 r1, r1, %0;
          @p mov.u32 %0, r1;
        })"
        : "+r"(num) : "r"(val), "r"(i), "r"(shfl_c), "r"(allmsk));
#endif
    }
    // __shfl_sync takes the source lane modulo the warp size, so for lane 0
    // the read wraps around to lane 31; lane 0 must be forced to be a leader,
    // otherwise val[0] == val[31] would misclassify it
    bool leader = lane == 0 || val != __shfl_sync(allmsk, val, lane - 1);
    auto OK = __ballot_sync(allmsk, leader); // find delimiter threads
    auto total = __popc(OK); // the total number of unique numbers found

    auto lanelt = (1u << lane) - 1; // unsigned literal: 1 << 31 would overflow a signed int
    auto idx = __popc(OK & lanelt);

    printf("%d: val = %d; num = %d; total: %d; idx = %d; leader: %d\n", lane, val, num, total, idx, leader);

    __shared__ uint32_t sh[64];
    if(leader) {   // here we need shared memory :(
        sh[idx] = val;
        sh[idx + 32] = num;
    }
    __syncthreads();

    if(lane < total) {
        val = sh[lane], num = sh[lane + 32];
    } else {
        val = 0xDEADBABE, num = 0;
    }
    printf("%d: final val = %d; num = %d\n", lane, val, num);
}

Here is my GPU output:

0: val = 27
1: val = 27
2: val = 28
3: val = 28
4: val = 28
5: val = 28
6: val = 29
7: val = 29
8: val = 29
9: val = 29
10: val = 30
11: val = 30
12: val = 30
13: val = 30
14: val = 31
15: val = 31
16: val = 31
17: val = 31
18: val = 32
19: val = 32
20: val = 32
21: val = 32
22: val = 32
23: val = 33
24: val = 33
25: val = 33
26: val = 33
27: val = 34
28: val = 34
29: val = 34
30: val = 34
31: val = 35
0: val = 27; num = 2; total: 9; idx = 0; leader: 1
1: val = 27; num = 1; total: 9; idx = 1; leader: 0
2: val = 28; num = 4; total: 9; idx = 1; leader: 1
3: val = 28; num = 3; total: 9; idx = 2; leader: 0
4: val = 28; num = 2; total: 9; idx = 2; leader: 0
5: val = 28; num = 1; total: 9; idx = 2; leader: 0
6: val = 29; num = 4; total: 9; idx = 2; leader: 1
7: val = 29; num = 3; total: 9; idx = 3; leader: 0
8: val = 29; num = 2; total: 9; idx = 3; leader: 0
9: val = 29; num = 1; total: 9; idx = 3; leader: 0
10: val = 30; num = 4; total: 9; idx = 3; leader: 1
11: val = 30; num = 3; total: 9; idx = 4; leader: 0
12: val = 30; num = 2; total: 9; idx = 4; leader: 0
13: val = 30; num = 1; total: 9; idx = 4; leader: 0
14: val = 31; num = 4; total: 9; idx = 4; leader: 1
15: val = 31; num = 3; total: 9; idx = 5; leader: 0
16: val = 31; num = 2; total: 9; idx = 5; leader: 0
17: val = 31; num = 1; total: 9; idx = 5; leader: 0
18: val = 32; num = 5; total: 9; idx = 5; leader: 1
19: val = 32; num = 4; total: 9; idx = 6; leader: 0
20: val = 32; num = 3; total: 9; idx = 6; leader: 0
21: val = 32; num = 2; total: 9; idx = 6; leader: 0
22: val = 32; num = 1; total: 9; idx = 6; leader: 0
23: val = 33; num = 4; total: 9; idx = 6; leader: 1
24: val = 33; num = 3; total: 9; idx = 7; leader: 0
25: val = 33; num = 2; total: 9; idx = 7; leader: 0
26: val = 33; num = 1; total: 9; idx = 7; leader: 0
27: val = 34; num = 4; total: 9; idx = 7; leader: 1
28: val = 34; num = 3; total: 9; idx = 8; leader: 0
29: val = 34; num = 2; total: 9; idx = 8; leader: 0
30: val = 34; num = 1; total: 9; idx = 8; leader: 0
31: val = 35; num = 1; total: 9; idx = 8; leader: 1
0: final val = 27; num = 2
1: final val = 28; num = 4
2: final val = 29; num = 4
3: final val = 30; num = 4
4: final val = 31; num = 4
5: final val = 32; num = 5
6: final val = 33; num = 4
7: final val = 34; num = 4
8: final val = 35; num = 1
9: final val = -559039810; num = 0
10: final val = -559039810; num = 0
11: final val = -559039810; num = 0
12: final val = -559039810; num = 0
13: final val = -559039810; num = 0
14: final val = -559039810; num = 0
15: final val = -559039810; num = 0
16: final val = -559039810; num = 0
17: final val = -559039810; num = 0
18: final val = -559039810; num = 0
19: final val = -559039810; num = 0
20: final val = -559039810; num = 0
21: final val = -559039810; num = 0
22: final val = -559039810; num = 0
23: final val = -559039810; num = 0
24: final val = -559039810; num = 0
25: final val = -559039810; num = 0
26: final val = -559039810; num = 0
27: final val = -559039810; num = 0
28: final val = -559039810; num = 0
29: final val = -559039810; num = 0
30: final val = -559039810; num = 0
31: final val = -559039810; num = 0

Question: Is it possible to do this without using shared memory? Somehow, I cannot figure it out with all these brain-twisting shuffle intrinsics...

Accepted answer

I think I found the solution: as paleonix also pointed out, the problem boils down to finding the position of the Nth set bit.

There is actually a pretty interesting PTX instruction called fns.b32 which does exactly that. However, on my SM30 architecture it maps to something crazy when I run the disassembler.

Anyway, the GPU also has a fast popcount intrinsic, which can be used to compute the position of the Nth set bit in logarithmic time. Below is the complete code, which no longer requires shared memory at all:

EDITED: interestingly enough, apart from NVIDIA, AMD seems to provide a so-called "warp_permute" instruction which is the opposite of __shfl_sync, in the sense that the threads of a warp write to some destination lane: AMD warp_permute.

EDITED: small optimization using the BFE intrinsic.

#define PRINTZ(fmt, ...) printf(fmt"\n", ##__VA_ARGS__)

// extracts bitfield from src of length 'width' starting at startIdx
__device__ __forceinline__ uint32_t bfe(uint32_t src, uint32_t startIdx, uint32_t width)
{
    uint32_t bit;
    asm volatile("bfe.u32 %0, %1, %2, %3;" : "=r"(bit) : "r"(src), "r"(startIdx), "r"(width));
    return bit;
}

__device__ __inline__ void sorted_seq_histogram()
{
    uint32_t tid = threadIdx.x, lane = tid % 32;
    uint32_t val = (lane + 117)* 23 / 97; // sorted sequence of values to be reduced

    PRINTZ("%d: val = %d", lane, val);
    uint32_t num = 1;

    const uint32_t allmsk = 0xffffffffu, shfl_c = 31;

    // __shfl_sync takes the source lane modulo the warp size, so for lane 0
    // the read wraps around to lane 31; lane 0 must be forced to be a leader,
    // otherwise val[0] == val[31] would misclassify it
    bool leader = lane == 0 || val != __shfl_sync(allmsk, val, lane - 1);
    auto OK = __ballot_sync(allmsk, leader); // find delimiter threads
    uint32_t pos = 0, N = lane+1; // each thread searches Nth bit set in 'OK' (1-indexed)

    for(int i = 1; i <= 16; i *= 2) {

        uint32_t j = 16 / i;
        uint32_t mval = bfe(OK, pos, j); // extract j bits starting at pos from OK
        auto dif = N - __popc(mval);
        if((int)dif > 0) {
            N = dif, pos += j;
        }

#if 0
        uint32_t xval = __shfl_down_sync(allmsk, val, i),
                 xnum = __shfl_down_sync(allmsk, num, i);
        if(lane + i < 32) {
            if(val == xval)
                num += xnum;
        }
#else  // this is a (hopefully) optimized version of the code above
        asm(R"({
          .reg .u32 r0,r1;
          .reg .pred p;
          shfl.sync.down.b32 r0|p, %1, %2, %3, %4;
          shfl.sync.down.b32 r1|p, %0, %2, %3, %4;
          @p setp.eq.s32 p, %1, r0;
          @p add.u32 r1, r1, %0;
          @p mov.u32 %0, r1;
        })"
        : "+r"(num) : "r"(val), "r"(i), "r"(shfl_c), "r"(allmsk));
#endif
    }
    num = __shfl_sync(allmsk, num, pos); // read from pos-th thread
    val = __shfl_sync(allmsk, val, pos); // read from pos-th thread

    auto total = __popc(OK); // the total number of unique numbers found
    if(lane >= total) {
        num = 0xDEADBABE;
    }
    PRINTZ("%d: final val = %d; num = %d", lane, val, num);
}

And the program output:

0: val = 27
1: val = 27
2: val = 28
3: val = 28
4: val = 28
5: val = 28
6: val = 29
7: val = 29
8: val = 29
9: val = 29
10: val = 30
11: val = 30
12: val = 30
13: val = 30
14: val = 31
15: val = 31
16: val = 31
17: val = 31
18: val = 32
19: val = 32
20: val = 32
21: val = 32
22: val = 32
23: val = 33
24: val = 33
25: val = 33
26: val = 33
27: val = 34
28: val = 34
29: val = 34
30: val = 34
31: val = 35
0: final val = 27; num = 2;
1: final val = 28; num = 4;
2: final val = 29; num = 4;
3: final val = 30; num = 4;
4: final val = 31; num = 4;
5: final val = 32; num = 5;
6: final val = 33; num = 4;
7: final val = 34; num = 4;
8: final val = 35; num = 1;
9: final val = 35; num = -559039810;
10: final val = 35; num = -559039810;
11: final val = 35; num = -559039810;
12: final val = 35; num = -559039810;
13: final val = 35; num = -559039810;
14: final val = 35; num = -559039810;
15: final val = 35; num = -559039810;
16: final val = 35; num = -559039810;
17: final val = 35; num = -559039810;
18: final val = 35; num = -559039810;
19: final val = 35; num = -559039810;
20: final val = 35; num = -559039810;
21: final val = 35; num = -559039810;
22: final val = 35; num = -559039810;
23: final val = 35; num = -559039810;
24: final val = 35; num = -559039810;
25: final val = 35; num = -559039810;
26: final val = 35; num = -559039810;
27: final val = 35; num = -559039810;
28: final val = 35; num = -559039810;
29: final val = 35; num = -559039810;
30: final val = 35; num = -559039810;
31: final val = 35; num = -559039810;
Another answer

One can find the lane from which each thread needs to shuffle and then just use __shfl_sync. The only problem/annoyance is that I know of no way to do this without a loop.

The needed operation is to find the "index" of the nth set bit in OK, where n is the lane of a thread. The SO question "Given a binary number, how to find the nth set bit from the right in O(1) time?" is about exactly this problem, but its answers only show iterative solutions. As that question is not concerned with any particular programming language or intrinsics, however, there may be some way to use integer intrinsics cleverly here.

Either way, the following works for me:

    // ... second printf
    auto src = lane;
    auto cnt = -1;
    for (int i = 0; i < warpSize; ++i) {
        if (((OK >> i) & 0x1) == 0x1) { // parentheses required: == binds tighter than &
            ++cnt;
            if (cnt == lane) {
                src = i;
                break;
            }
        }
    }
    val = __shfl_sync(allmsk, val, src);
    num = __shfl_sync(allmsk, num, src);
    if (lane >= total) {
        val = 0xDEADBABE;
        num = 0;
    }
    // third printf ...

I don't know how it compares in terms of performance (it should be measured without the print statements).