The original MoCo paper states:
Using a queue can make the dictionary large, but it also makes it intractable to update the key encoder by back-propagation (the gradient should propagate to all samples in the queue).
At first I thought the main reason back-propagation cannot be applied to the key encoder is that the queue operation is not differentiable. But that does not seem to be true: you can compute the gradients involving all samples in the queue, and back-propagation runs without error. See the code at the bottom.
So what is the real reason that back-propagation is intractable for the key encoder? My guess is that it is the large size of the queue (dictionary), which would make the memory requirements explode.
import torch
import torch.nn as nn

q = nn.Linear(768, 128)   # query encoder
k = nn.Linear(768, 128)   # key encoder
bs = 64                   # batch size
ks = 4095                 # queue (dictionary) size
model = nn.ModuleList([q, k])
x = torch.randn(bs, 768)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

def forward(x):
    xq = q(x)                                        # query features
    xk = k(x + 0.1)                                  # key features (perturbed input as a fake augmentation)
    que = torch.rand(ks, 128)                        # stand-in for the queue of negatives
    pos = torch.einsum("nc,nc->n", xq, xk)           # positive logits
    neg = torch.einsum("nc,kc->nk", xq, que)         # negative logits against the queue
    out = torch.cat([pos.unsqueeze(-1), neg], dim=1)
    t = torch.zeros(out.shape[0], dtype=torch.long)  # the positive is always class 0
    return criterion(out, t)

loss = forward(x)
loss.backward()   # gradients reach both q and k without any error
optim.step()
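For instance (an extra check, not part of the snippet above), after loss.backward() you can confirm that the key encoder actually receives gradients:

print(k.weight.grad is not None)   # True: the key encoder got gradients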
I think you are correct: the key reason is indeed the large size of the queue, which makes the memory requirements explode.
In the MoCo framework, the queue holds encoded key representations from many past mini-batches. To back-propagate through the key encoder for every sample in the queue, you would need the intermediate activations (the computation graph) of all those past mini-batches, and storing them for a large queue quickly becomes prohibitively expensive in memory.
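As a rough illustration (a toy example of mine, not MoCo's actual code): if the keys were enqueued without detaching them, every queue entry would keep the computation graph of its own mini-batch alive.

import torch
import torch.nn as nn

k_enc = nn.Linear(768, 128)   # toy key encoder
queue = []
for step in range(3):         # pretend these are three past mini-batches
    xb = torch.randn(64, 768)
    keys = k_enc(xb)          # keys are still attached to their graphs
    queue.append(keys)        # enqueued WITHOUT .detach()

# Every entry still references its own computation graph, so back-propagating
# through the whole queue would require keeping the activations of every past
# mini-batch in memory; the cost grows with the queue length.
print([kq.grad_fn is not None for kq in queue])   # [True, True, True]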
This is why the paper adopts a momentum update strategy for the key encoder's parameters instead of back-propagation.
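For completeness, here is a minimal sketch of that momentum update as described in the paper (theta_k <- m * theta_k + (1 - m) * theta_q, with keys computed under torch.no_grad()); the linear-layer encoders here are just placeholders for illustration:

import copy
import torch
import torch.nn as nn

q_enc = nn.Linear(768, 128)    # query encoder (placeholder)
k_enc = copy.deepcopy(q_enc)   # key encoder starts as a copy of the query encoder
m = 0.999                      # momentum coefficient used in the paper

@torch.no_grad()
def momentum_update():
    # theta_k <- m * theta_k + (1 - m) * theta_q; no gradients ever reach k_enc
    for pk, pq in zip(k_enc.parameters(), q_enc.parameters()):
        pk.data.mul_(m).add_(pq.data, alpha=1.0 - m)

with torch.no_grad():          # keys are computed without building a graph
    keys = k_enc(torch.randn(64, 768))
momentum_update()              # called once per training step after updating q_enc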
MoCo paper: https://arxiv.org/pdf/1911.05722.pdf