How is it possible to use a pre-trained ViT backbone of a masked autoencoder in downstream tasks?


When pre-training a ViT backbone with a masked autoencoder architecture, the input patches are randomly masked and only the unmasked patches are fed to the encoder layers of the ViT, as shown in the Keras tutorial on masked image modeling.
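In terms of shapes, that pre-training input pipeline looks roughly like this (a simplified sketch with illustrative sizes, not the tutorial's exact code):

```python
import tensorflow as tf

# Illustrative sizes (not the tutorial's values): a 160x160 image with
# 16x16 patches gives 10 * 10 = 100 patches, matching the example below.
image_size, patch_size, embed_dim = 160, 16, 64
images = tf.random.normal((1, image_size, image_size, 3))

# Split the image into non-overlapping 16x16 patches.
patches = tf.image.extract_patches(
    images,
    sizes=[1, patch_size, patch_size, 1],
    strides=[1, patch_size, patch_size, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)
patches = tf.reshape(patches, (1, -1, patch_size * patch_size * 3))  # [1, 100, 768]

# Linearly project each flattened patch to the embedding dimension.
patch_embeddings = tf.keras.layers.Dense(embed_dim)(patches)         # [1, 100, embed_dim]

# Randomly keep 20% of the patch positions; only these go into the ViT encoder.
keep_idx = tf.random.shuffle(tf.range(100))[:20]
unmasked_embeddings = tf.gather(patch_embeddings, keep_idx, axis=1)  # [1, 20, embed_dim]
```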

If I understood correctly, the approach is to pre-train the ViT backbone in the masked image modeling paradigm and then use the pre-trained ViT (which includes the transformer encoder blocks) in a downstream model for a specific task, like image classification for example, by using the CLS-token approach or some sort of pooling operation. The question that bothers me is now:

When adopting the pre-trained layers for the downstream task, one usually does not mask anything, so the input shape for the encoder changes. Since the encoder consists of transformer encoder blocks using self-attention (or multi-head attention, actually), whose weight matrices for query, key and value are initialized and optimized during pre-training, how is it possible to use these layers when the input is now the whole image, or respectively the whole sequence of embedded patches?

For example (for simplicity, let's not consider the batch size here): an image gets split into, say, 100 (16x16) patches, so we obtain an input sequence of shape [100, embedding_dimension]. Then 80% of the patches get masked, leaving masked patches of shape [80, embedding_dimension] and unmasked patches of shape [20, embedding_dimension]. We feed the unmasked patches to the encoder layers performing multi-head attention and train the network as proposed in the Keras tutorial (or the original paper, as you like). When switching to the downstream task, let's consider images of the same size again, so the model input again has shape [100, embedding_dimension]. But we're not masking anymore, so the encoder receives the whole sequence of 100 patches.
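To make the scenario concrete, here is a minimal sketch of just the attention part (a bare MultiHeadAttention layer stands in for a full encoder block, and the sizes are again illustrative): the same layer, with the same kernels, is applied once to the 20-patch sequence and once to the 100-patch sequence.

```python
import tensorflow as tf

embed_dim = 64  # illustrative embedding_dimension
attn = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=embed_dim // 4)

# Pre-training case: self-attention over the 20 unmasked patch embeddings.
unmasked_seq = tf.random.normal((1, 20, embed_dim))
out_pretrain = attn(unmasked_seq, unmasked_seq)        # [1, 20, embed_dim]

# Downstream case: the same layer, same kernels, applied to all 100 patches.
full_seq = tf.random.normal((1, 100, embed_dim))
out_downstream = attn(full_seq, full_seq)              # [1, 100, embed_dim]

# Query/key/value kernels have shape [embed_dim, num_heads, key_dim]; the
# sequence length does not appear in any trainable weight.
for w in attn.weights:
    print(w.name, w.shape)
```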

How does that fit the pre-trained weight matrices? What am I getting wrong here? Thanks for your help in advance!

I tried implementing this in TensorFlow and was able to pre-train the masked autoencoder with a masking rate of 0.8, then extract the encoder layers of the pre-trained model and plug them into a downstream model, which gets the whole sequence as input. No errors were thrown ...
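A stripped-down version of that downstream model looks roughly like this (the MultiHeadAttention layers are only stand-ins for my actual pre-trained encoder blocks; names and sizes are placeholders):

```python
import tensorflow as tf

num_patches, embed_dim, num_classes = 100, 64, 10  # placeholder sizes

# Stand-ins for the encoder blocks extracted from the pre-trained masked autoencoder.
pretrained_encoder_blocks = [
    tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=embed_dim // 4)
    for _ in range(2)
]

inputs = tf.keras.Input(shape=(num_patches, embed_dim))  # full, unmasked patch sequence
x = inputs
for block in pretrained_encoder_blocks:
    x = block(x, x)  # each pre-trained block now sees all 100 patches

x = tf.keras.layers.GlobalAveragePooling1D()(x)          # pooling instead of a CLS token
outputs = tf.keras.layers.Dense(num_classes)(x)

downstream_model = tf.keras.Model(inputs, outputs)
downstream_model.summary()  # builds and runs without shape errors
```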
