I've been trying to implement a custom attention function in a standard GPT-2 style transformer model. I replaced the scaled dot product with negative (squared) Euclidean distance, and everything seems to be working except that training is extremely slow. With normal dot-product attention the model trains in a few minutes; with my implementation it looks like it will take at least a day. The dataset I'm using is about 30 MB taken from the Pile. I'm not super familiar with PyTorch, so I don't know whether my implementation is as efficient as it could be.
Below is my attempt at a custom attention function.
import torch as t
from torch import Tensor
from jaxtyping import Float

def CustomAttention(A: Float[Tensor, "batch posn_q n_heads d_head"],
                    B: Float[Tensor, "batch posn_k n_heads d_head"]) -> Float[Tensor, "batch n_heads posn_q posn_k"]:
    # Bring the head dimension forward, then insert singleton dimensions so that
    # the query and key position axes broadcast against each other.
    A_cast = t.permute(A, (0, 2, 1, 3)).unsqueeze(-2)  # batch n_heads posn_q 1 d_head
    B_cast = t.permute(B, (0, 2, 1, 3)).unsqueeze(-3)  # batch n_heads 1 posn_k d_head
    diff = A_cast - B_cast                             # batch n_heads posn_q posn_k d_head
    sq_dist = t.sum(diff ** 2, dim=-1)                 # squared Euclidean distance per query-key pair
    return -sq_dist
Basically, the code relies on broadcasting to calculate the elementwise difference for every query-key pair. All of my training is being done locally on an RTX 3080 Ti, and GPU utilization sits at 100% as expected. Is there anything I can do to make this run faster?
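For reference, the output can be sanity-checked against torch.cdist, which computes pairwise distances over the last dimension; this is just a rough sketch, and the shapes below are placeholders rather than the model's real dimensions:

# Sanity check: the custom scores should equal minus the squared pairwise
# distances from torch.cdist. Shapes here are placeholders, not the model's.
import torch as t

batch, posn, n_heads, d_head = 2, 8, 4, 16
q = t.randn(batch, posn, n_heads, d_head)
k = t.randn(batch, posn, n_heads, d_head)

scores = CustomAttention(q, k)                                   # [batch, n_heads, posn_q, posn_k]
dists = t.cdist(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3),
                compute_mode="donot_use_mm_for_euclid_dist")     # exact pairwise Euclidean distances
print(scores.shape)                                              # torch.Size([2, 4, 8, 8])
print(t.allclose(scores, -dists ** 2, atol=1e-5))                # True, up to float error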
When you say that it trains slowly, do you mean that the time-per-batch has worsened, or that the loss-improvement-per-batch has worsened? If it is the latter, this is not entirely surprising. Negative Euclidean distance attention does not produce the same training dynamics as scaled dot-product attention. A recent preprint on arXiv documents this phenomenon in a different setting:
C. McCarter, “Inverse distance weighting attention.” Associative Memory & Hopfield Networks Workshop @ NeurIPS, 2023.
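For some intuition on why the dynamics can differ: negative squared Euclidean distance (which is what your code computes) expands into an unscaled dot product plus query- and key-norm terms, so the softmax sees extra content-dependent biases even with identical weights. A minimal numerical check of that identity, with arbitrary shapes not taken from your model:

# Check the identity  -||q - k||^2 = 2 <q, k> - ||q||^2 - ||k||^2,
# i.e. negative squared distance = (unscaled) dot product plus query- and
# key-dependent bias terms. Shapes are arbitrary, just to verify the algebra.
import torch as t

q = t.randn(8, 64)
k = t.randn(8, 64)
lhs = -((q.unsqueeze(1) - k.unsqueeze(0)) ** 2).sum(-1)                  # [8, 8] negative squared distances
rhs = 2 * q @ k.T - (q ** 2).sum(-1, keepdim=True) - (k ** 2).sum(-1)    # expanded form
print(t.allclose(lhs, rhs, atol=1e-4))                                   # True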