This code runs perfectly, but I wonder what the parameter 'x' in the my_forward function refers to


Referring to the attention-maps example for ViT transformers in: https://github.com/huggingface/pytorch-image-models/discussions/1232?sort=old

This code runs perfectly, but I wonder what the parameter 'x' in the my_forward function refers to, and how and where in the code the value of x is passed to my_forward.

def my_forward(x):
    B, N, C = x.shape

    qkv = attn_obj.qkv(x).reshape(
        B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
    # make torchscript happy (cannot use tensor as tuple)
    q, k, v = qkv.unbind(0)

There is 1 answer below.

Accepted answer, by Ivan:

This requires a little code inspection but you can easily find the implementation if you look in the right places. Let us start with your snippet.

  • The my_forward_wrapper function is a function factory: it defines my_forward (a closure over attn_obj) and returns it. The returned function overrides the forward method of the attention layer of the last block, blocks[-1].attn, of the loaded "deit_small_distilled_patch16_224" model. A minimal sketch of the wrapper's structure follows the assignment below.

    model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)
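
    As a minimal sketch (not the exact code from the linked discussion), the wrapper captures attn_obj in a closure, so the returned my_forward only needs the input tensor x; because it is assigned to the module's forward attribute, it is what runs whenever this attention module is called during model(input):

    def my_forward_wrapper(attn_obj):
        # attn_obj is the Attention module being patched (here model.blocks[-1].attn);
        # the closure keeps a reference to it so my_forward can reuse its weights.
        def my_forward(x):
            # ... body as shown further below, built from attn_obj.qkv, attn_obj.proj, etc.
            return x
        return my_forward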
    
  • What x corresponds to is the output of the previous block. To see this, you can dive into the timm source code. The model loaded in the script is deit_small_distilled_patch16_224, which returns a VisionTransformerDistilled instance. The blocks are defined in the VisionTransformer class: there are n=depth blocks applied sequentially. The default block definition is given by Block, in which attn is implemented by Attention; the details are given here:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x) \
                  .reshape(B, N, 3, self.num_heads, self.head_dim) \
                  .permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
    
        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
    
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
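
    This also shows where x is handed to the (patched) forward. Each Block normalizes its input and calls the attention module on the result; in recent timm versions, Block.forward looks roughly like the sketch below (the ls1/ls2 LayerScale and drop_path details can differ between versions). The call self.attn(self.norm1(x)) is the point where x reaches my_forward, since that module's forward attribute was reassigned:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # self.attn is the (possibly patched) Attention module; x is the
        # (B, N, C) token sequence produced by the previous block.
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x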
    

    Meanwhile, the my_forward implementation that you provided, which replaces Attention.forward, is:

    def my_forward(x):
        B, N, C = x.shape
        qkv = attn_obj.qkv(x) \
                .reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads) \
                .permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
    
        attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
        attn = attn.softmax(dim=-1)
        attn = attn_obj.attn_drop(attn)
        attn_obj.attn_map = attn  # cache the full attention map on the module
        attn_obj.cls_attn_map = attn[:, :, 0, 2:]  # CLS-token attention to the patch tokens (indices 0 and 1 are the CLS and distillation tokens)
    
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = attn_obj.proj(x)
        x = attn_obj.proj_drop(x)
        return x
    

    The idea is that the attention map is cached as an attribute on the attention layer with attn_obj.attn_map = attn, so that it can be inspected after inference.
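
    As a usage sketch (the exact model-creation and preprocessing code depends on your script; the dummy input below is only for illustration), a single forward pass is enough to populate the cached maps, which you can then read directly off the patched attention module:

    import torch
    import timm

    model = timm.create_model('deit_small_distilled_patch16_224', pretrained=True)
    model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)
    model.eval()

    with torch.no_grad():
        _ = model(torch.randn(1, 3, 224, 224))  # any correctly shaped input triggers my_forward

    attn_obj = model.blocks[-1].attn
    print(attn_obj.attn_map.shape)      # (B, num_heads, N, N): full attention map of the last block
    print(attn_obj.cls_attn_map.shape)  # (B, num_heads, N - 2): CLS attention over the patch tokens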