Dot product of 3D vectors in WebAssembly


I want to calculate the dot product of two vectors of 3 elements. Looking through the wasm vector instructions, I found only a single dot instruction:

i32x4.dot_i16x8_s

According to the wasm specification, its type takes two v128 values as input and returns a v128.

According to Wikipedia:

the dot product is an algebraic operation that takes two equal-length sequences of numbers (usually coordinate vectors), and returns a single number

but the name of the wasm instruction hints that it expects the two input vectors to be organized as i32x4 and i16x8.

This does not really make sense to me, since the input vectors would not have the same number of elements. In addition, I don't understand how the returned v128 is organized: is it i32x4, i16x8, or something else?

Also, what is the correct way to calculate the dot product of two 3D vectors using 4D calculations? Can I pad the vectors with 1's and perhaps manipulate the answer in some way?

If I had to guess at the WAT code, it would be something like:

(module
  (func $my_function (result i32)
    v128.const i32x4 1 3 -5 0
    v128.const i16x8 4 -2 -1 0 0 0 0 0
    i32x4.dot_i16x8_s
    i32x4.extract_lane 0
  )
)

The Wikipedia article has an example:

dot([1 3 -5], [4 -2 -1]) returns 3

What is the proper way to do that using webassembly?


There are 2 best solutions below

0
Aki Suihkonen On BEST ANSWER

The dot product in wasm must be interpreted with i32x4 as the output and i16x8 as the input. This corresponds to Intel's pmaddwd, i.e. pair-wise multiply-add of words into double words. The instruction also has one overflow corner case: when both 16-bit pairs are all -32768, the sum (-32768)^2 * 2 = 2^31 does not fit in int32_t.
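To make the lane layout concrete, here is a minimal scalar sketch of what the instruction computes. The names I16x8, I32x4, dot_i16x8_s and check_wikipedia_example are hypothetical helpers for illustration, not part of any wasm API, and the wrap-around in the overflow case reflects my reading of the semantics (it also matches pmaddwd).

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for a v128 viewed as 8 x i16 and as 4 x i32.
using I16x8 = std::array<std::int16_t, 8>;
using I32x4 = std::array<std::int32_t, 4>;

// Scalar model: output lane i = a[2i]*b[2i] + a[2i+1]*b[2i+1].
I32x4 dot_i16x8_s(const I16x8& a, const I16x8& b) {
    I32x4 out{};
    for (int i = 0; i < 4; ++i) {
        // Compute in 64 bits so the single overflow case is not undefined behaviour.
        const std::int64_t sum =
            static_cast<std::int64_t>(a[2 * i]) * b[2 * i] +
            static_cast<std::int64_t>(a[2 * i + 1]) * b[2 * i + 1];
        // Reduce modulo 2^32 (assumed wrap-around); the final conversion wraps
        // on mainstream compilers and is guaranteed to do so since C++20.
        const std::uint32_t wrapped = static_cast<std::uint32_t>(sum);
        out[i] = static_cast<std::int32_t>(wrapped);
    }
    return out;
}

// The example from the question: [1 3 -5] . [4 -2 -1], zero-padded to 8 lanes.
void check_wikipedia_example() {
    const I16x8 a{{1, 3, -5, 0, 0, 0, 0, 0}};
    const I16x8 b{{4, -2, -1, 0, 0, 0, 0, 0}};
    const I32x4 r = dot_i16x8_s(a, b);
    assert(r[0] == -2);        // 1*4 + 3*(-2)
    assert(r[1] == 5);         // (-5)*(-1) + 0*0
    assert(r[0] + r[1] == 3);  // the 3D dot product
}
```

The 3-element dot product ends up split across output lanes 0 and 1, so one extra addition is still needed after the instruction.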

To compute dot product of two 3-element vectors, one must just unroll it as a[0]*b[0]+a[1]*b[1]+a[2]*b[2].

As suggested in the comments, it's often best to organise the data in a format where one can compute 4 independent dot products vertically at a time.

struct vec3x4 {
     int32_t x[4];
     int32_t y[4];
     int32_t z[4];
};
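As a rough sketch of how this layout is used, the loop below computes four independent dot products, one per lane. The function name dot3x4 and the helper check_dot3x4 are hypothetical; the struct is the one above with explicit std::int32_t. A vectorising compiler could map each line of arithmetic onto i32x4.mul and i32x4.add, though that is not guaranteed.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Structure-of-arrays layout: lane i holds the i-th vector's components.
struct vec3x4 {
    std::int32_t x[4];
    std::int32_t y[4];
    std::int32_t z[4];
};

// Hypothetical helper: result[i] = dot(a_i, b_i) for the four vector pairs.
std::array<std::int32_t, 4> dot3x4(const vec3x4& a, const vec3x4& b) {
    std::array<std::int32_t, 4> result{};
    for (int i = 0; i < 4; ++i) {
        // Three "vertical" multiplies and two adds per lane, no horizontal sum.
        result[i] = a.x[i] * b.x[i] + a.y[i] * b.y[i] + a.z[i] * b.z[i];
    }
    return result;
}

// Lane 0 carries the example from the question.
void check_dot3x4() {
    const vec3x4 a{{1, 1, 0, 2}, {3, 0, 1, 2}, {-5, 0, 0, 2}};
    const vec3x4 b{{4, 1, 0, 3}, {-2, 5, 1, 3}, {-1, 7, 9, 3}};
    const std::array<std::int32_t, 4> r = dot3x4(a, b);
    assert(r[0] == 3);
}
```

The point of the layout is that no lane ever needs data from another lane, so the horizontal reduction disappears entirely.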

Allocating a full v128 for just those three values might still help the JIT optimise, even though there is no integer dot product instruction of this kind on Intel (or Arm64).

struct vec3_plus {
    int x,y,z,w{0};
};

The dot product between two of these values can still benefit from parallel execution of 3 real multiplies and 1 dummy multiply, followed by either a parallel or a scalar reduction of the terms.
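A minimal scalar sketch of that idea, assuming the vec3_plus struct above; the function name dot and the helper check_dot are hypothetical. The four products are independent, so they could be issued as a single i32x4.mul, and the pairwise sum mirrors a two-step reduction.

```cpp
#include <cassert>

struct vec3_plus {
    int x, y, z, w{0};  // w is padding and always stays 0
};

// Hypothetical helper: 3 real multiplies plus 1 dummy multiply, then a reduction.
int dot(const vec3_plus& a, const vec3_plus& b) {
    // Independent lane-wise products; the last one is always 0 * 0.
    const int p[4] = {a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w};
    // Pairwise reduction of the four terms.
    return (p[0] + p[1]) + (p[2] + p[3]);
}

// The example from the question: [1 3 -5] . [4 -2 -1] = 3.
void check_dot() {
    vec3_plus a;
    a.x = 1; a.y = 3; a.z = -5;
    vec3_plus b;
    b.x = 4; b.y = -2; b.z = -1;
    assert(dot(a, b) == 3);
}
```

Because w is zero in both operands, the dummy product contributes nothing and no masking is needed before the reduction.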

4
Jonas On

As @harold commented and @Aki mentioned, the i32x4.dot_i16x8_s WebAssembly instruction seems to correspond to pmaddwd.

I did find a better (compared to the specification) explanation of the wasm instruction in a proposal document:

Integer dot product

i32x4.dot_i16x8_s(a: v128, b: v128) -> v128

Lane-wise multiply signed 16-bit integers in the two input vectors and add adjacent pairs of the full 32-bit results.

Although the instruction cannot be used directly as a dot product operator for two vectors of 3 elements, as I wanted, it is still useful for such an implementation.

The description of pmaddwd is:

Multiplies the individual signed words of the destination operand (first operand) by the corresponding signed words of the source operand (second operand), producing temporary signed, doubleword results. The adjacent double-word results are then summed and stored in the destination operand.

This illustration of pmaddwd is also helpful: [image: pmaddwd]

Since I want to use vectors of 3 elements, I can pad the input vectors with 0's, then extract the sums from the lanes, and finally do the remaining additions.

Dot product implementation for 3D vectors

Considering the example in the Wikipedia article:

[1 3 -5] dot [4 -2 -1] = 3

I implemented a function that takes two 3-element vectors (as six i32 parameters, each of which must fit in 16 bits) and calculates their dot product:

(module
  (func (export "calc_dot") (result i32)
    i32.const 1
    i32.const 3
    i32.const -5

    i32.const 4
    i32.const -2
    i32.const -1

    call $dot3
    return
  )

  (func $dot3 (param i32) (param i32) (param i32)
              (param i32) (param i32) (param i32) (result i32)
    (local v128)

    ;; create vector from first 3 params
    (i16x8.splat (i32.const 0))
    (i16x8.replace_lane 0 (local.get 0))
    (i16x8.replace_lane 1 (local.get 1))
    (i16x8.replace_lane 2 (local.get 2))

    ;; create vector from last 3 params
    (i16x8.splat (i32.const 0))
    (i16x8.replace_lane 0 (local.get 3))
    (i16x8.replace_lane 1 (local.get 4))
    (i16x8.replace_lane 2 (local.get 5))

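    ;; integer dot product:
    ;; lane 0 = a0*b0 + a1*b1, lane 1 = a2*b2 + 0*0, lanes 2 and 3 = 0
    (local.set 6 (i32x4.dot_i16x8_s))

    ;; add the two non-zero lanes to get the final result
    (i32x4.extract_lane 0 (local.get 6))
    (i32x4.extract_lane 1 (local.get 6))
    i32.add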

    return
  )
)