Referring to the attention-maps-in-ViT example in: https://github.com/huggingface/pytorch-image-models/discussions/1232?sort=old
This code runs perfectly, but I wonder what the parameter x in the my_forward function refers to, and how and where in the code the value of x is passed to my_forward.
def my_forward(x):
    B, N, C = x.shape
    qkv = attn_obj.qkv(x).reshape(
        B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
    # make torchscript happy (cannot use tensor as tuple)
    q, k, v = qkv.unbind(0)
This requires a little code inspection but you can easily find the implementation if you look in the right places. Let us start with your snippet.
The my_forward_wrapper function is a function generator: it defines my_forward and returns it. This implementation overwrites the forward pass of the last block's attention layer, blocks[-1].attn, of the loaded model "deit_small_distilled_patch16_224". The snippet in your question is the beginning of that my_forward; a sketch of the complete wrapper is shown below.
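For reference, here is a sketch of the complete wrapper, roughly as it appears in the linked discussion; the cls_attn_map indexing is taken from that thread and is an assumption here, so it may differ in your version:

```python
def my_forward_wrapper(attn_obj):
    # Function generator: captures the original Attention module (attn_obj)
    # in a closure and returns a replacement forward function.
    def my_forward(x):
        B, N, C = x.shape
        qkv = attn_obj.qkv(x).reshape(
            B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
        attn = attn.softmax(dim=-1)
        attn = attn_obj.attn_drop(attn)
        # Cache the attention map on the module so it can be inspected after inference.
        attn_obj.attn_map = attn
        # Attention of the CLS token to the patch tokens (skipping CLS and
        # distillation tokens); indexing follows the linked discussion.
        attn_obj.cls_attn_map = attn[:, :, 0, 2:]

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = attn_obj.proj(x)
        x = attn_obj.proj_drop(x)
        return x
    return my_forward
```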
What x corresponds to is the output of the previous block (more precisely, that output after the current block's first LayerNorm, norm1). To understand this, you can dive into the source code of timm. The model loaded in the script is deit_small_distilled_patch16_224, which returns a VisionTransformerDistilled instance. The blocks are defined in the VisionTransformer class: there are n = depth blocks applied sequentially. The default block definition is given by Block, in which attn is implemented by Attention. The relevant details are sketched below; the wrapper shown above (the implementation you provided) is what overwrites that Attention forward pass.
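To see where x comes from and what the wrapper replaces, here is a simplified, self-contained paraphrase of the relevant timm code; this is a sketch, not timm's exact implementation (layer scale, drop path, fused attention and other details are omitted, and the exact code depends on your timm version):

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    # Paraphrase of timm's Attention (vision_transformer.py), abridged.
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        # x: (B, N, C) token sequence handed over by Block.forward below.
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(
            B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    # Paraphrase of timm's Block, abridged.
    def __init__(self, dim, num_heads, mlp_ratio=4.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)), nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim))

    def forward(self, x):
        # Here x is the output of the previous block (or the embedded tokens for
        # the first block). The attention layer receives norm1(x) -- this is the
        # x that your my_forward sees once it replaces attn's forward.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
```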
The idea is that the attention map is cached as an attribute of the attention layer with attn_obj.attn_map = attn, such that it can be inspected after inference.
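As a usage sketch (the model name comes from the linked script; the random input tensor is just a placeholder for a real preprocessed image):

```python
import timm
import torch

# my_forward_wrapper is the function generator shown above.
model = timm.create_model('deit_small_distilled_patch16_224', pretrained=True).eval()
model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)

img = torch.randn(1, 3, 224, 224)   # placeholder for a preprocessed input image
with torch.no_grad():
    _ = model(img)                   # this forward pass is what feeds x into my_forward

# The cached attention map of the last block, shape (B, num_heads, N, N).
attn_map = model.blocks[-1].attn.attn_map
print(attn_map.shape)
```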