Custom attention function slow when training


I've been trying to implement a custom attention function in a standard GPT-2-style transformer model. I replaced the scaled dot product with negative Euclidean distance, and everything seems to work except that training is extremely slow. With normal dot-product attention the model trains in a few minutes; with my implementation it looks like it will take at least a day. The dataset I'm using is about 30 MB taken from the Pile. I'm not very familiar with PyTorch, so I don't know whether my implementation is as efficient as it could be.

Below is my attempt at a custom attention function.

import torch as t
from torch import Tensor
from jaxtyping import Float

def CustomAttention(A: Float[Tensor, "batch posn_q n_heads d_head"],
                    B: Float[Tensor, "batch posn_k n_heads d_head"]) -> Float[Tensor, "batch n_heads posn_q posn_k"]:
    # Move the head dimension forward and add singleton dims so that queries and
    # keys broadcast against each other over every (query, key) pair.
    A_cast = t.permute(A, (0, 2, 1, 3)).unsqueeze(-2)  # (batch, n_heads, posn_q, 1, d_head)
    B_cast = t.permute(B, (0, 2, 1, 3)).unsqueeze(-3)  # (batch, n_heads, 1, posn_k, d_head)
    diff = A_cast - B_cast                             # (batch, n_heads, posn_q, posn_k, d_head)
    sq_dist = t.sum(diff ** 2, dim=-1)                 # squared Euclidean distance per pair
    return -sq_dist                                    # negative distance used as the attention score

Basically, the code relies on broadcasting to calculate the elementwise difference for every query-key pair. All of my training is being done locally on an RTX 3080 Ti, and GPU utilization sits at 100% as expected. Is there anything I can do to make this run faster?
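
To make the cost of that broadcast concrete, the intermediate diff tensor has shape (batch, n_heads, posn_q, posn_k, d_head), i.e. a factor of d_head more elements than the attention-score matrix itself. Here is a quick back-of-the-envelope check (the sizes below are purely illustrative, not my actual config):

# Illustrative sizes only; the real batch size and context length may differ.
batch, n_heads, posn, d_head = 8, 12, 256, 64

diff_elems = batch * n_heads * posn * posn * d_head   # 5-D broadcasted difference tensor
score_elems = batch * n_heads * posn * posn           # ordinary attention-score matrix

print(f"diff tensor:  {diff_elems * 4 / 2**20:.0f} MiB in fp32")   # 1536 MiB
print(f"score matrix: {score_elems * 4 / 2**20:.0f} MiB in fp32")  # 24 MiB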

1 Answer

When you say that it trains slowly, do you mean that the time-per-batch has worsened, or that the loss-improvement-per-batch has worsened? If it is the latter, this is not entirely surprising. Negative Euclidean distance attention does not produce the same training dynamics as scaled dot-product attention. A recent preprint on arXiv documents this phenomenon in a different setting:

C. McCarter, "Inverse Distance Weighting Attention," Associative Memory & Hopfield Networks Workshop @ NeurIPS, 2023.
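
One way to tell the two cases apart is to time a single forward/backward pass under each attention variant, synchronizing the GPU around the measurement so the result isn't distorted by asynchronous kernel launches. A minimal sketch, where model, batch, and loss_fn are placeholders for whatever objects your training loop already has:

import time
import torch

def time_one_step(model, batch, loss_fn):
    # model, batch, and loss_fn are placeholders; substitute your own objects.
    torch.cuda.synchronize()     # flush any queued GPU work before starting the clock
    start = time.perf_counter()
    loss = loss_fn(model(batch), batch)
    loss.backward()
    torch.cuda.synchronize()     # wait for the backward pass to actually finish
    return time.perf_counter() - start

If that per-step time is roughly the same for both attention functions, the slowdown is in the optimization dynamics rather than the kernel, which is the regime the paper above describes.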