In the TensorFlow Keras implementation of Multi-Head Attention, instead of evaluating the numerator Q Kᵀ first, as in

softmax(Q Kᵀ / √dₖ) V,

they evaluate Q/√dₖ first and add this comment:

> Note: Applying scalar multiply at the smaller end of einsum improves XLA performance, but may introduce slight numeric differences in the Transformer attention head.
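For concreteness, here is a minimal sketch of the two orderings with plain `tf.einsum`; the shapes and the einsum string are simplified assumptions for illustration, not the layer's actual internals:

```python
import tensorflow as tf

# Arbitrary example shapes: batch B, query length T, key length S, heads H, key_dim N.
B, T, S, H, N = 2, 8, 16, 4, 64
q = tf.random.normal([B, T, H, N])
k = tf.random.normal([B, S, H, N])
scale = 1.0 / (N ** 0.5)

# Batched, per-head Q @ K.T (assumed equation, for illustration only).
eq = "aecd,abcd->acbe"   # (key, query) -> attention scores of shape [B, H, T, S]

scores_scaled_first = tf.einsum(eq, k, q * scale)   # scale the smaller query tensor
scores_scaled_last = tf.einsum(eq, k, q) * scale    # scale the larger score tensor

# Mathematically identical; floating point results may differ slightly.
print(tf.reduce_max(tf.abs(scores_scaled_first - scores_scaled_last)).numpy())
```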
How is it faster this way? Wouldn't the division after the einsum be just as fast?

What the comment suggests is that the number of elements in `query` is smaller than the number of elements in `attention_scores` in the einsum that computes `attention_scores` from `key` and `query` via `self._dot_product_equation`.

Given the dimensions T (query sequence length), S (key sequence length), and N (`key_dim`), and assuming that `_dot_product_equation` is simply doing a batched matrix multiplication: if Q is T x N and K is S x N, the product `Q @ K.T` is T x S, so if S > N the number of scalar multiplications is smaller when scaling the left factor Q rather than the product. But either way the scaling should not be the dominant part unless S > T * N (or XLA has a bug).
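As a rough illustration of the element counts involved (the per-head shapes below are arbitrary example values, not taken from the Keras source):

```python
# Per-head element counts for the scalar multiply, ignoring the batch dimension.
T, S, N = 128, 512, 64   # query length, key length, key_dim (arbitrary example values)

scale_query_cost = T * N      # multiply Q by 1/sqrt(d_k) before the matmul
scale_scores_cost = T * S     # multiply Q @ K.T by 1/sqrt(d_k) afterwards
matmul_cost = T * S * N       # multiply-adds in the matmul itself

print(scale_query_cost)   # 8192
print(scale_scores_cost)  # 65536   -> larger whenever S > N
print(matmul_cost)        # 4194304 -> dominates both either way
```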