How do you do signed 32-bit widening multiplication on SSE2?


This question came up when reviewing the WebAssembly SIMD proposal for extended multiplication.

To support older hardware, we need to target SSE2, and its only vector multiply for 32-bit integers is pmuludq, an unsigned widening multiply. (The signed pmuldq was only added in SSE4.1.)

(Non-widening is relatively easy: shuffle to feed 2x pmuludq and take the low 32 bits of the 4 results to emulate SSE4.1 pmulld; see the sketch below.)
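
For reference, that non-widening emulation might look like the sketch below (SSE2 only; mullo_epi32_sse2 is just an illustrative name, not anything from the proposal):

#include <emmintrin.h>

/* Emulate SSE4.1 pmulld (_mm_mullo_epi32) with SSE2: pmuludq only multiplies
   elements 0 and 2, so handle the even and odd element pairs separately and
   re-interleave the low 32 bits of the four 64-bit products. */
static __m128i mullo_epi32_sse2(__m128i a, __m128i b) {
    __m128i evens = _mm_mul_epu32(a, b);                     /* a0*b0, a2*b2 */
    __m128i odds  = _mm_mul_epu32(_mm_srli_si128(a, 4),
                                  _mm_srli_si128(b, 4));     /* a1*b1, a3*b3 */
    evens = _mm_shuffle_epi32(evens, _MM_SHUFFLE(0, 0, 2, 0));
    odds  = _mm_shuffle_epi32(odds,  _MM_SHUFFLE(0, 0, 2, 0));
    return _mm_unpacklo_epi32(evens, odds);   /* low halves, back in order */
}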

There are 2 answers below.

BEST ANSWER (score 9)

The high half of a signed 32x32-bit product can be recovered from the unsigned one:

mulhs(a, b) = mulhu(a, b) - (a < 0 ? b : 0) - (b < 0 ? a : 0)
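
The same correction applied to the full 64-bit unsigned product gives the full signed product, which is what the vector code below computes. A scalar sketch of the idea, just for illustration:

#include <stdint.h>
#include <stdio.h>

/* Interpreting a negative 32-bit value as unsigned adds 2^32, so the unsigned
   64-bit product overshoots by b << 32 when a < 0 and by a << 32 when b < 0
   (all mod 2^64); subtracting those corrections recovers the signed product. */
static int64_t widening_mul_via_unsigned(int32_t a, int32_t b) {
    uint64_t p = (uint64_t)(uint32_t)a * (uint32_t)b;
    if (a < 0) p -= (uint64_t)(uint32_t)b << 32;
    if (b < 0) p -= (uint64_t)(uint32_t)a << 32;
    return (int64_t)p;
}

int main(void) {
    printf("%lld %lld\n",
           (long long)widening_mul_via_unsigned(-5000, -5000),
           -5000LL * -5000);            /* both print 25000000 */
    return 0;
}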

Using that, two signed double-width products can be computed like this:

#include <emmintrin.h>  // SSE2 intrinsics
#include <limits.h>     // INT_MIN

// Widening signed multiply of elements 0 and 1 of a and b,
// producing two signed 64-bit products.
__m128i mul_epi32(__m128i a, __m128i b) {
    // Move elements 0 and 1 into the even positions pmuludq reads.
    a = _mm_shuffle_epi32(a, _MM_SHUFFLE(3, 1, 1, 0));
    b = _mm_shuffle_epi32(b, _MM_SHUFFLE(3, 1, 1, 0));
    __m128i unsignedProduct = _mm_mul_epu32(a, b);
    // Signed compare against {0, INT_MIN, 0, INT_MIN} yields the sign bit
    // of elements 0 and 2 as a mask, and zero in elements 1 and 3.
    __m128i threshold = _mm_set_epi32(INT_MIN, 0, INT_MIN, 0);
    __m128i signA = _mm_cmplt_epi32(a, threshold);
    __m128i signB = _mm_cmplt_epi32(b, threshold);
    // Corrections (a < 0 ? b : 0) and (b < 0 ? a : 0), moved into the high
    // 32 bits of each 64-bit lane; their low halves are zero, so no borrow.
    __m128i x = _mm_shuffle_epi32(_mm_and_si128(signA, b), _MM_SHUFFLE(2, 3, 0, 1));
    __m128i y = _mm_shuffle_epi32(_mm_and_si128(signB, a), _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_sub_epi32(_mm_sub_epi32(unsignedProduct, x), y);
}

That saves a couple of operations over the other answer, but it's very close, and it now needs a constant load (the threshold vector), which could hurt if this code is cold and the constant isn't in cache.
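
A quick way to try it (a sketch, assuming the function above is in scope):

#include <emmintrin.h>
#include <stdio.h>

int main(void) {
    /* Elements 0 and 1 of each input are the ones that get multiplied. */
    __m128i a = _mm_setr_epi32(-5, 500, 0, 0);
    __m128i b = _mm_setr_epi32(10, -100, 0, 0);
    __m128i p = mul_epi32(a, b);
    long long out[2];
    _mm_storeu_si128((__m128i *)out, p);
    printf("%lld %lld\n", out[0], out[1]);   /* expect -50 -50000 */
    return 0;
}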

ANSWER (score 8)

Big shout out to @GeDaMo on ##asm for helping come up with this solution.

Godbolt

C++ (GCC vector extensions):

#include <xmmintrin.h>
#include <stdint.h>
#include <tmmintrin.h>
#include <smmintrin.h>
#include <cstdio>

typedef int32_t int32x4_t __attribute__((vector_size(16))) __attribute__((aligned(16)));
typedef int64_t int64x2_t __attribute__((vector_size(16))) __attribute__((aligned(16)));

// Widening signed multiply of elements 0 and 1: take absolute values,
// multiply unsigned with pmuludq, then re-apply the product's sign.
int64x2_t multiply32_low_s(int32x4_t a, int32x4_t b) {
    auto aSigns = a >> 31;            // all-ones where the element is negative
    auto bSigns = b >> 31;
    auto aInt = a ^ aSigns;           // abs(x) = (x ^ sign) - sign
    aInt -= aSigns;
    auto bInt = b ^ bSigns;
    bInt -= bSigns;
    // Duplicate elements 0 and 1 so they land in the even positions pmuludq reads.
    const auto shuffleMask = _MM_SHUFFLE(1,1,0,0);
    auto absProd = _mm_mul_epu32(_mm_shuffle_epi32((__m128i)aInt, shuffleMask), _mm_shuffle_epi32((__m128i)bInt, shuffleMask));
    auto aSignsInt = _mm_shuffle_epi32((__m128i)aSigns, shuffleMask);
    auto bSignsInt = _mm_shuffle_epi32((__m128i)bSigns, shuffleMask);
    auto prodSigns = aSignsInt ^ bSignsInt;   // sign of each 64-bit product
    absProd ^= prodSigns;             // conditionally negate the 64-bit products
    absProd -= prodSigns;
    return (int64x2_t)absProd;
}

// Same idea for elements 2 and 3; only the shuffle mask differs.
int64x2_t multiply32_high_s(int32x4_t a, int32x4_t b) {
    auto aSigns = a >> 31;
    auto bSigns = b >> 31;
    auto aInt = a ^ aSigns;
    aInt -= aSigns;
    auto bInt = b ^ bSigns;
    bInt -= bSigns;
    const auto shuffleMask = _MM_SHUFFLE(3,3,2,2);
    auto absProd = _mm_mul_epu32(_mm_shuffle_epi32((__m128i)aInt, shuffleMask), _mm_shuffle_epi32((__m128i)bInt, shuffleMask));
    auto aSignsInt = _mm_shuffle_epi32((__m128i)aSigns, shuffleMask);
    auto bSignsInt = _mm_shuffle_epi32((__m128i)bSigns, shuffleMask);
    auto prodSigns = aSignsInt ^ bSignsInt;
    absProd ^= prodSigns;
    absProd -= prodSigns;
    return (int64x2_t)absProd;
}


int main(int argc, char* argv[]) {
    int32x4_t a{-5,500,-5000,50000};
    int32x4_t b{10,-100,-5000,500000000};
    auto c = multiply32_low_s(a,b);
    auto d = multiply32_high_s(a,b);
    printf("%ld %ld\n", c[0],c[1]);
    printf("%ld %ld\n", d[0],d[1]);
}
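
For reference, on an LP64 target (where %ld matches int64_t) this should print:

-50 -50000
25000000 25000000000000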

Assembly:

multiply32_low_s(int __vector(4), int __vector(4)):
 movdqa xmm3,xmm0
 movdqa xmm2,xmm1
 psrad  xmm3,0x1f
 psrad  xmm2,0x1f
 pxor   xmm0,xmm3
 pxor   xmm1,xmm2
 psubd  xmm1,xmm2
 psubd  xmm0,xmm3
 pshufd xmm2,xmm2,0x50
 pshufd xmm1,xmm1,0x50
 pshufd xmm0,xmm0,0x50
 pshufd xmm3,xmm3,0x50
 pmuludq xmm0,xmm1
 pxor   xmm2,xmm3
 pxor   xmm0,xmm2
 psubq  xmm0,xmm2
 ret    
 nop    WORD PTR [rax+rax*1+0x0]
multiply32_high_s(int __vector(4), int __vector(4)):
 movdqa xmm3,xmm0
 movdqa xmm2,xmm1
 psrad  xmm3,0x1f
 psrad  xmm2,0x1f
 pxor   xmm0,xmm3
 pxor   xmm1,xmm2
 psubd  xmm1,xmm2
 psubd  xmm0,xmm3
 pshufd xmm2,xmm2,0xfa
 pshufd xmm1,xmm1,0xfa
 pshufd xmm0,xmm0,0xfa
 pshufd xmm3,xmm3,0xfa
 pmuludq xmm0,xmm1
 pxor   xmm2,xmm3
 pxor   xmm0,xmm2
 psubq  xmm0,xmm2
 ret    
 nop    WORD PTR [rax+rax*1+0x0]