How can I get the class token out of the output_hidden_states?

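```python
import torch

# ds, feature_extractor, model and device are defined earlier in the notebook
for i in range(len(ds['train'])):
    img = ds['train'][i]['image']
    # Preprocess one image into pixel_values and move it to the model's device
    encoding = feature_extractor(img, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model(**encoding, output_hidden_states=True)
```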

I'm running a multi-class image classification experiment with a pretrained Vision Transformer (ViT) from Hugging Face. When I run the code above, the hidden states come back as a tuple of 13 tensors, each of shape 1x197x768. My question is: where in this output can I find the class token feature vector (the input to the MLP head)?
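
A minimal sketch of the indexing, assuming a ViT-Base/16 checkpoint at 224x224 input (which is consistent with the 1x197x768 shapes above). It uses random NumPy arrays as a stand-in for outputs.hidden_states so no model download is needed; the same [:, 0, :] slicing works on torch tensors. The 197 tokens are one class token prepended at index 0 plus 14 x 14 = 196 patch tokens, and the 13 tensors are the embedding output followed by the outputs of the 12 encoder layers.

```python
import numpy as np

# Stand-in for outputs.hidden_states (assumption: ViT-Base/16, 224x224 input):
# 1 embedding output + 12 encoder layers = 13 arrays of shape (batch, tokens, hidden).
IMAGE_SIZE, PATCH_SIZE, HIDDEN, NUM_LAYERS = 224, 16, 768, 12
num_patches = (IMAGE_SIZE // PATCH_SIZE) ** 2   # 14 * 14 = 196 patch tokens
seq_len = num_patches + 1                        # + 1 class token at index 0 -> 197

rng = np.random.default_rng(0)
hidden_states = tuple(
    rng.standard_normal((1, seq_len, HIDDEN)).astype(np.float32)
    for _ in range(NUM_LAYERS + 1)
)

# Class token after the last encoder layer: position 0 along the token axis.
cls_last = hidden_states[-1][:, 0, :]            # shape (1, 768)

# Patch tokens only (everything after the class token position).
patch_tokens = hidden_states[-1][:, 1:, :]       # shape (1, 196, 768)

# Class token taken from every entry of the tuple, stacked: (batch, 13, hidden).
cls_per_layer = np.stack([h[:, 0, :] for h in hidden_states], axis=1)

print(len(hidden_states), hidden_states[0].shape)
print(cls_last.shape, patch_tokens.shape, cls_per_layer.shape)
```

One caveat, to be checked against your installed transformers version: in the Hugging Face ViTForImageClassification implementation, the classifier is applied to the class token after a final LayerNorm (model.vit.layernorm), and the last entry of outputs.hidden_states is taken before that LayerNorm. So model.vit.layernorm(outputs.hidden_states[-1])[:, 0, :] should match the head's input more closely than hidden_states[-1][:, 0, :] alone.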
