Visualize attention map for vision transformer · huggingface/pytorch-image-models · Discussion #1232
Hi, I want to extract the attention map from a pretrained vision transformer for a specific image.
How can I do that?
Hi @kiashann
This is a toy example that visualizes the whole attention map and the attention map for the class token only (see here for more information).
```python
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from timm.models import create_model
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor, InterpolationMode


def to_tensor(img):
    # 3 == bicubic in the old integer convention; the enum form avoids the deprecated int argument
    transform_fn = Compose([Resize(249, interpolation=InterpolationMode.BICUBIC), CenterCrop(224),
                            ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    return transform_fn(img)


def show_img(img):
    img = np.asarray(img)
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.show()


def show_img2(img1, img2, alpha=0.8):
    img1 = np.asarray(img1)
    img2 = np.asarray(img2)
    plt.figure(figsize=(10, 10))
    plt.imshow(img1)
    plt.imshow(img2, alpha=alpha)
    plt.axis('off')
    plt.show()


def my_forward_wrapper(attn_obj):
    def my_forward(x, **kwargs):  # extra kwargs (if a newer timm version passes any) are ignored
        B, N, C = x.shape
        qkv = attn_obj.qkv(x).reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
        attn = attn.softmax(dim=-1)
        attn = attn_obj.attn_drop(attn)
        attn_obj.attn_map = attn                   # full map: [B, heads, N, N]
        attn_obj.cls_attn_map = attn[:, :, 0, 2:]  # cls query vs. patch keys (skip cls + dist tokens)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = attn_obj.proj(x)
        x = attn_obj.proj_drop(x)
        return x
    return my_forward


img = Image.open('n02102480_Sussex_spaniel.JPEG').convert('RGB')
x = to_tensor(img)

model = create_model('deit_small_distilled_patch16_224', pretrained=True)
model.eval()
model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)

with torch.no_grad():
    y = model(x.unsqueeze(0))
attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(14, 14).detach()

img_resized = x.permute(1, 2, 0) * 0.5 + 0.5
cls_resized = F.interpolate(cls_weight.view(1, 1, 14, 14), (224, 224), mode='bilinear').view(224, 224, 1)

show_img(img)
show_img(attn_map)
show_img(cls_weight)
show_img(img_resized)
show_img2(img_resized, cls_resized, alpha=0.8)
```
Attention map for the last layer (198 x 198 = 196 image tokens + 1 class token + 1 distillation token)
Class attention map for the last layer (14 x 14)
Class attention map overlaid on the image
The attention scores are a bit scattered here; usually the cls token focuses on certain patches, and those are consistent. Are you sure taking the mean is a good idea? It might also be better to roll out the cls token attention over multiple blocks.
So I tried taking the product of the cls token weights over the other blocks, like this, but it raises the following error:
```
RuntimeError                              Traceback (most recent call last)
Cell In[18], line 13
     11 # Forward pass through all blocks
     12 for block in model.blocks:
---> 13     x, attn_map = block.attn.forward(x)
     14     outputs.append(x)
     15     attn_maps.append(attn_map)

Cell In[8], line 4, in my_forward_wrapper.<locals>.my_forward(x)
      2 def my_forward(x):
      3     B, N, C = x.shape
----> 4     qkv = attn_obj.qkv(x).reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
      5     q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
      7     attn = (q @ k.transpose(-2, -1)) * attn_obj.scale

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (672x224 and 384x1152)
```
Code:

```python
model = create_model('deit_small_distilled_patch16_224', pretrained=True)

# Replace forward function in all blocks
for block in model.blocks:
    block.attn.forward = my_forward_wrapper(block.attn)

# Forward pass through the model
outputs = []
attn_maps = []
cls_weights = []

# Forward pass through all blocks
for block in model.blocks:
    x, attn_map = block.attn.forward(x)
    outputs.append(x)
    attn_maps.append(attn_map)
    cls_weights.append(block.attn.cls_attn_map.min(dim=1).values.view(14, 14).detach())

# Combine class scores of all blocks
cls_weight_combined = torch.prod(torch.stack(cls_weights), dim=0)

# Resize input image and class weights
img_resized = x.permute(0, 2, 3, 1) * 0.5 + 0.5
cls_resized = F.interpolate(cls_weight_combined.view(1, 1, 14, 14), (224, 224), mode='bilinear').view(224, 224, 1)

# Visualize
show_img(image)
show_img(attn_maps[-1])  # Attention map from the last block
show_img(cls_weight_combined)  # Combined class weights
show_img(img_resized.squeeze())  # Squeeze the batch dimension
show_img2(img_resized.squeeze(), cls_resized, alpha=0.8)  # Squeeze the batch dimension
```
Hi @arnavs04
That's a good question, and you're right. There may be a better way to merge different attention maps; the code I gave uses only very basic ways of merging them.
The error you posted looks like a dimension mismatch error:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (672x224 and 384x1152)
I suggest applying a patch embedding layer to the input image before passing it into the attention block. You should also ensure that other layers (e.g., normalization, mlp) are applied appropriately.
Thank you.
Hankyul.
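Following up on both suggestions above (run the full model rather than the bare attention blocks, and roll out the cls attention over several blocks), here is a minimal sketch of attention rollout (Abnar & Zuidema, 2020). It assumes you patched every block with `my_forward_wrapper`, ran `model(x.unsqueeze(0))` once so that patch embedding, normalization, and MLP layers all run, and then collected `[block.attn.attn_map for block in model.blocks]`. The function name `attention_rollout` and the 0.5 residual weight are choices made for this sketch, not a timm API; random tensors stand in for the collected maps.

```python
import torch


def attention_rollout(attn_maps, residual_alpha=0.5):
    """Attention rollout over a list of per-block attention maps (a sketch).

    attn_maps: tensors shaped [B, heads, N, N], ordered from the first to the last block.
    Returns [B, N, N]; row i estimates how much each input token flows into token i.
    """
    rollout = None
    for attn in attn_maps:
        a = attn.mean(dim=1)                                  # average heads -> [B, N, N]
        eye = torch.eye(a.size(-1), dtype=a.dtype, device=a.device)
        a = residual_alpha * a + (1 - residual_alpha) * eye    # model the residual connection
        a = a / a.sum(dim=-1, keepdim=True)                    # keep rows summing to 1
        rollout = a if rollout is None else a @ rollout        # later block multiplies on the left
    return rollout


# Synthetic stand-in for 12 blocks of DeiT-small distilled attention (6 heads, N = 198).
maps = [torch.softmax(torch.randn(1, 6, 198, 198), dim=-1) for _ in range(12)]
rollout = attention_rollout(maps)
cls_to_patches = rollout[:, 0, 2:].reshape(-1, 14, 14)  # skip the cls and distillation tokens
print(rollout.shape, cls_to_patches.shape)
```

Multiplying row-normalized maps keeps each row a distribution over input tokens, which is why a product of rolled-out maps stays interpretable, unlike an element-wise product of raw cls weights. For a plain ViT with a single class token, slice with `1:` instead of `2:`.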
Ohh, thank you for the help!!
@kiashann
Thank you for your valuable code. The whole code works fine, but I need to understand how these lines work: `model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)` and `attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()`. So I ran `attn_obj = model.blocks[-1].attn` and `qkv = attn_obj.qkv(x)` in the console, but got this error: `RuntimeError: mat1 and mat2 shapes cannot be multiplied (672x224 and 384x1152)`. I'd like to know why. I need to know whether `x` is the transformed image or some other variable: when I debugged the code, I found that B, N, C from `x.shape` are 1, 384, and 198, which differ from the dimensions of the transformed image.
Hi @mae338
I am happy to help you.
- In `y = model(x.unsqueeze(0))`, `x.unsqueeze(0)` is the transformed image, with shape `(1, 3, 224, 224)`.
- In `model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)`, we replace the original `forward` with the `my_forward` function returned by `my_forward_wrapper`, which saves the attention map as an instance variable of `model.blocks[-1].attn`.
- In `model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()`, we average the attention map over the head dimension (`.mean(dim=1)`) so that we can display the whole attention map.
- To understand it better, I recommend printing out the shape of every tensor, which helps you see the overall workflow of ViT.
Thank you.
Hankyul
Thank you so much, Hankyul. I appreciate your effort to help me. I'd be pleased if you could help me understand this as well:
What about this part?

```python
def my_forward(x):
    B, N, C = x.shape
    qkv = attn_obj.qkv(x)
```

More specifically, what does the `x` parameter refer to? And why did you create the function `my_forward`?
I also found that `my_forward` must return `x`, otherwise the whole program doesn't work. I wonder why?
Hi @mae338
I hope this could help you.
- We replace the original Attention `forward()` method with a `my_forward()` method that only inserts two extra lines for saving the attention maps (`attn_obj.attn_map = attn`, `attn_obj.cls_attn_map = attn[:, :, 0, 2:]`). Thus, everything else (function signature and return value) should be the same as in the original function. `x` is the image tokens (`Batch x N x D`) that are input to attention, not the image itself. Image tokens are produced by a linear projection of the image patches, and their shape depends on the model's size. In our case, DeiT-small/16 splits the whole image into 196 patches (`N=198`, with 2 extra class and distillation tokens) and has 384 channel dimensions (`D=384`).
- Since the attention block passes its output to the next block, such as the MLP block, `my_forward()` should also return its output. If you don't return it, the next block gets a `None` value, which it does not expect, and that raises an error.
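To see why the replacement must keep the original return value, here is a toy sketch in plain PyTorch. `ToyAttention` is a hypothetical single-head attention without projections, standing in for timm's `Attention`; the patched `forward` stores the map and returns the same output, while a version that forgets `return` hands `None` to the next layer.

```python
import torch
from torch import nn


class ToyAttention(nn.Module):
    """Hypothetical single-head self-attention without projections (illustration only)."""

    def forward(self, x):
        attn = (x @ x.transpose(-2, -1)).softmax(dim=-1)
        return attn @ x


def my_forward_wrapper(attn_obj):
    # Same computation and return value as ToyAttention.forward,
    # plus one extra line that stores the attention map on the module.
    def my_forward(x):
        attn = (x @ x.transpose(-2, -1)).softmax(dim=-1)
        attn_obj.attn_map = attn
        return attn @ x
    return my_forward


def broken_wrapper(attn_obj):
    # Same as above but without `return`: Python then returns None implicitly.
    def my_forward(x):
        attn_obj.attn_map = (x @ x.transpose(-2, -1)).softmax(dim=-1)
    return my_forward


torch.manual_seed(0)
model = nn.Sequential(ToyAttention(), nn.Linear(8, 8))  # attention followed by a "next block"
x = torch.randn(1, 5, 8)
reference = model(x)

model[0].forward = my_forward_wrapper(model[0])  # instance attribute shadows the class method
patched = model(x)
print(torch.allclose(reference, patched), tuple(model[0].attn_map.shape))  # True (1, 5, 5)

model[0].forward = broken_wrapper(model[0])
failed = False
try:
    model(x)
except TypeError:  # nn.Linear receives None instead of a tensor
    failed = True
print("next layer failed:", failed)
```

Assigning to `module.forward` works because `nn.Module.__call__` looks up `self.forward`, which finds the instance attribute before the class method.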
Thank you.
Hankyul
Hi @hankyul2,
I'd like to view the outputs of a pretrained ViT model (vit_base_patch16_224), especially the input to the MLP head. I tried the same function, but it didn't work. Any suggestions?
Thank you
mae
Hi @mae338
I hope this can help you.
You can extract the final (pre-logits) features of a pretrained ViT with `y = model.forward_head(x, pre_logits=True)` and visualize them for your purpose. Since it has been a long time since I uploaded the initial code in my first comment, I copied and modified it as follows:
Dependency:

```
!wget https://user-images.githubusercontent.com/31476895/167238573-b0cc3a6d-d3ee-462b-8630-a8f253e69bb2.png
!pip install -Uq fastai timm==0.6.13 huggingface_hub
```

Code:

```python
import numpy as np
from PIL import Image
from timm.models import create_model
from torch import nn
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor


def to_tensor(img):
    transform_fn = Compose([Resize(249), CenterCrop(224), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    return transform_fn(img)


img = Image.open('167238573-b0cc3a6d-d3ee-462b-8630-a8f253e69bb2.png').convert('RGB')
x = to_tensor(img)

model = create_model('deit_small_distilled_patch16_224', pretrained=True)
x = model.forward_features(x.unsqueeze(0))
y = model.forward_head(x, pre_logits=True)
print(y.shape)
```

Output:

```
torch.Size([1, 384])
```
Thank you.
Hankyul
I would like to apply this code to the 'vit_small_patch16_384' model from timm. How should I modify the code for this purpose?
(I understand that '224' in the given code refers to the image size, but how is '14' determined?)
I apologize if this is due to my lack of knowledge.
```python
attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(14, 14).detach()
img_resized = x.permute(1, 2, 0) * 0.5 + 0.5
cls_resized = F.interpolate(cls_weight.view(1, 1, 14, 14), (224, 224), mode='bilinear').view(224, 224, 1)
```
Hi @tomos7231
- `384` is the input image resolution of the ViT model. If you want to extract an attention map using the code above, you should change the input resolution (`224` -> `384`) and the patch grid size (`14x14` -> `24x24`), as in the code below.

```python
attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(24, 24).detach()  # (14->24)

img_resized = x.permute(1, 2, 0) * 0.5 + 0.5
cls_resized = F.interpolate(cls_weight.view(1, 1, 24, 24), (384, 384), mode='bilinear').view(384, 384, 1)  # (14->24), (224->384)
```

- `14` is the number of patches along each spatial dimension (H and W) of the feature map. It is determined by the patch size (`16`): 224 / 16 = 14, so each image is divided into 14 x 14 = 196 patches of size `16x16`.
- If you just want to extract an attention map, I would recommend @rwightman's solution, which is more convenient.
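The numbers above follow from simple token arithmetic. A small sketch, assuming square inputs and non-overlapping patches (the helper names `patch_grid` and `num_tokens` are hypothetical):

```python
def patch_grid(img_size, patch_size):
    """Number of patches per side, for a square image and non-overlapping patches."""
    if img_size % patch_size != 0:
        raise ValueError("image size must be divisible by patch size")
    return img_size // patch_size


def num_tokens(img_size, patch_size, num_prefix_tokens):
    """Sequence length N seen by attention: patch tokens plus prefix (cls/dist) tokens."""
    return patch_grid(img_size, patch_size) ** 2 + num_prefix_tokens


print(patch_grid(224, 16), num_tokens(224, 16, 2))  # 14 198  (deit_small_distilled_patch16_224)
print(patch_grid(384, 16), num_tokens(384, 16, 1))  # 24 577  (vit_small_patch16_384)
```

`num_prefix_tokens` is 2 for the distilled DeiT used at the top of this thread (class + distillation tokens) and 1 for a plain ViT such as `vit_small_patch16_384`.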
Thank you.
Hankyul
Hi @hankyul2
Understood. Thank you for the reply!
@hankyul2
Thank you for such an insightful explanation above.
1. I also tried to apply this code to the 'vit_small_patch16_384' model as shown above, but I encountered the following error:

```
RuntimeError                              Traceback (most recent call last)
Cell In[26], line 24
     20 # y = model(image)
     23 attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
---> 24 cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(24, 24).detach()
     26 img_resized = image2.permute(0, 2, 3, 1) * 0.5 + 0.5
     27 cls_resized = F.interpolate(cls_weight.view(1, 1, 24, 24), (384, 384), mode='bilinear').view(384, 384, 1)

RuntimeError: shape '[24, 24]' is invalid for input of size 575
```

2. If I want to apply this model to a regression problem and estimate multiple parameters with a single model, is it possible to check the attention map for each parameter? In a multi-class classification problem, we can check the attention map when a certain class is predicted to see which areas the model focuses on to recognize that class. However, how should this be done for a regression problem?
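Regarding the shape error in point 1, a likely explanation (reasoned from token counts, not confirmed in this thread): `vit_small_patch16_384` has a 24 x 24 = 576 patch grid plus a single class token, so N = 577. The `attn[:, :, 0, 2:]` slice in `my_forward` was written for the distilled DeiT, which has two prefix tokens; on a model with one prefix token it skips the class token and one real patch, leaving 577 - 2 = 575 values, which cannot be viewed as 24 x 24. Slicing with `1:` should fix it for this model. Below is a sketch of a helper that takes the number of prefix tokens as a parameter; the name `cls_patch_heatmap` is hypothetical, and random attention stands in for a real block's output (recent timm ViT models also appear to expose a `num_prefix_tokens` attribute, but check your version).

```python
import math
import torch


def cls_patch_heatmap(attn, num_prefix_tokens):
    """Class-token attention over patch tokens, reshaped to the square patch grid.

    attn: [B, heads, N, N] attention probabilities from one block.
    num_prefix_tokens: 2 for DeiT distilled (cls + dist), 1 for a plain ViT (cls only).
    Returns [B, grid, grid].
    """
    cls_row = attn[:, :, 0, num_prefix_tokens:].mean(dim=1)  # cls query vs. patch keys, head-averaged
    grid = math.isqrt(cls_row.shape[-1])
    if grid * grid != cls_row.shape[-1]:
        raise ValueError(f"{cls_row.shape[-1]} patch tokens is not a square grid; check num_prefix_tokens")
    return cls_row.reshape(-1, grid, grid)


# Synthetic attention with N = 24 * 24 + 1 = 577 tokens, as for vit_small_patch16_384.
attn = torch.softmax(torch.randn(1, 6, 577, 577), dim=-1)
print(attn[:, :, 0, 2:].shape[-1])                         # 575: the DeiT slice drops one patch
print(cls_patch_heatmap(attn, num_prefix_tokens=1).shape)  # torch.Size([1, 24, 24])
```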
So, this doesn't include the visualization helpers yet, but I have added a simpler extraction helper to get the attention activations via one of two methods, fx or hooks.
It's a WIP, but it can be seen here: https://github.com/huggingface/pytorch-image-models/pull/2168/files#diff-358e0d5feb2c109ff53d21bc4fa8a6af94566be622b0f1167316216b0036b8b3
```python
import timm
import torch
from timm.utils import AttentionExtract

timm.layers.set_fused_attn(False)
mm = timm.create_model('vit_base_patch16_224')
input = torch.randn(2, 3, 224, 224)
ee = AttentionExtract(mm, method='fx')
oo = ee(input)
for n, t in oo.items():
    print(n, t.shape)
```

```
blocks.0.attn.softmax torch.Size([2, 12, 197, 197])
blocks.1.attn.softmax torch.Size([2, 12, 197, 197])
blocks.2.attn.softmax torch.Size([2, 12, 197, 197])
blocks.3.attn.softmax torch.Size([2, 12, 197, 197])
blocks.4.attn.softmax torch.Size([2, 12, 197, 197])
blocks.5.attn.softmax torch.Size([2, 12, 197, 197])
blocks.6.attn.softmax torch.Size([2, 12, 197, 197])
blocks.7.attn.softmax torch.Size([2, 12, 197, 197])
blocks.8.attn.softmax torch.Size([2, 12, 197, 197])
blocks.9.attn.softmax torch.Size([2, 12, 197, 197])
blocks.10.attn.softmax torch.Size([2, 12, 197, 197])
blocks.11.attn.softmax torch.Size([2, 12, 197, 197])
```
Hello @rwightman, thank you for showing us how to extract the attention layers and for maintaining your wonderful timm library. I am using fast.ai alongside timm models, and I've trained a ViT for a classification task on my own dataset. What would be the best way to load the weights of my ViT and visualize the attention activations over my input image using the timm visualization helpers? Thank you.
I have a bit of a trivial question here and slightly off-topic to what @rwightman discussed, but related to the attention map that @hankyul2 posted in their initial visualization:
I'm assuming that the ViT in the OP's question is being trained for a classification task. Is it expected that the attention map contains high activations along the diagonal, similar to the attention maps generated while training seq2seq models? This seems quite unintuitive to me, as I wouldn't expect a classification ViT to self-attend to patches in such a manner. I would instead expect every patch to attend to certain "interesting" parts of the image, almost like having certain columns in the attention map with high activations.
Is my understanding incorrect, or should we expect the attention map to have high activations along the diagonal?
Thank you in advance!
@SarthakJShetty You can see the attention maps after every block for DeiT here. In the initial layers you can see a clear diagonal, and then how the attention maps change. Please check out my notebook; it contains visualizations both with and without attention rollout. I hope it helps.
My assumption, or rather intuition, is that these diagonal patterns are centered around the class token, which I believe is why we don't see information flow immediately: the class token takes in information from the tokens without much change. However, this applies to the shallow layers; I believe it is only in the later, deeper layers that the attention starts to play out.
Thank you for the clarification @arnavs04! This helps. Now that I think about it, the visualizations that I'm observing are almost always like Attention Map 5 and beyond.
Quick question: when you say "the attention starts to play out", you mean when the queries actually start attending to relevant parts of the image (and not just predominantly self-attending in a diagonal fashion), correct? I.e., when the attention actually starts looking non-diagonal, like attention maps 5-12?
@SarthakJShetty
One of the major reasons why vision transformers (ViTs) are thought to be “better” than CNNs is their ability to share global contextual information right from the beginning (i.e., in the shallow layers). Unlike CNNs, which progress from local to global information, ViTs can, in theory, access global information from the outset. However, attention maps show that not much global information is actually shared in the early stages. This has led to many works proposing modifications to optimize the self-attention mechanism in vision transformers, effectively “tuning” them similarly to CNNs.
In the diagram, you can see how information moves gradually from local to global, resembling the behaviour of a CNN (hence the diagonal structure slowly transitioning to a more globally attended pattern). This example might not be fully representative, as there is usually some global information shared even in the early stages of a vision transformer.
Here is a paper that may clear up your doubts: Do Vision Transformers See Like Convolutional Neural Networks? I haven't gone through it completely, as it is still on my reading list.
I am currently working on Vision Transformers. If you have any questions or would like to discuss them with me, let me know! I'll be glad to help!
FYI there's a fix on main for the node/module matching so that outputs will remain in order of traversal (usually matches order of forward pass, at least for timm models) regardless of how many matching names/wildcards are specified.
Hi! @hankyul2 Thanks for your excellent explanation above.
I understood most of it, but I am still confused about `attn_obj.cls_attn_map = attn[:, :, 0, 2:]`.
Why is `cls_attn_map` extracted with the index `attn[:, :, 0, 2:]`?
Thanks!
This is the DeiT architecture, which has both a class token and a distillation token for prediction.
When we index `0`, we are looking at the similarity scores of all tokens with respect to the class token. When we then index `2:`, we only keep the similarity scores of the patch tokens with respect to the class token (excluding the class token and the distillation token themselves).
Thanks for your response @arnavs04! The similarity scores you mentioned are the dot product of two vectors from the query and key matrices. Is that right?
@zichunxx Yup exactly! The (n+2) x (n+2) attention matrix.
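Written out per attention head, this is standard scaled dot-product attention (in timm, `attn_obj.scale` is $d_h^{-1/2}$, with $d_h$ the head dimension). With token 0 the class token, token 1 the distillation token, and tokens $2, \dots, n+1$ the $n$ patches:

$$
A = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_h}}\right) \in \mathbb{R}^{(n+2)\times(n+2)},
\qquad
A_{0j} = \frac{\exp\!\left(q_0^\top k_j / \sqrt{d_h}\right)}{\sum_{m=0}^{n+1} \exp\!\left(q_0^\top k_m / \sqrt{d_h}\right)}
$$

So `attn[:, :, 0, 2:]` is the row $(A_{0,2}, \dots, A_{0,n+1})$ for every image and head: softmax-normalized dot products rather than raw ones. Because the softmax normalizes over all $n+2$ keys, this slice sums to $1 - A_{00} - A_{01}$ rather than to 1.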
Thanks for your generous help, @arnavs04! I have read your notebook, which is very thorough and helpful.
I have noticed that some vision transformers are implemented as an encoder without a cls token. In that case, how do we plot the overlaid image to illustrate which patches receive higher attention weights? Thanks!
I apologize for the delay, I didn't see the reply.
The attribution map is resized with bilinear interpolation to the H x W resolution of the original image. This heatmap is then blended with the image as 0.5 x heat_map + 0.5 x original_image to get the saliency map; you can of course tweak the two 0.5 weights. This is the typical way post-hoc, model-agnostic explainability methods for vision transformers present their results.
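A minimal sketch of the resize-and-blend step described above, in PyTorch. The function names are hypothetical, and the heatmap is kept grayscale for simplicity (a colormap is often applied before blending). `received_attention` is only one possible heuristic for models without a class token, assumed here rather than taken from this thread: it averages, over heads and query tokens, the attention each patch receives.

```python
import torch
import torch.nn.functional as F


def received_attention(attn):
    """Heuristic for models without a cls token: mean attention each key token receives.

    attn: [B, heads, N, N] -> [B, N] (averaged over heads, then over query rows).
    """
    return attn.mean(dim=1).mean(dim=1)


def overlay(image, heatmap, alpha=0.5):
    """Blend a patch-grid heatmap onto an image.

    image: [3, H, W] float tensor in [0, 1]; heatmap: [h, w] patch scores (e.g. 14 x 14).
    Returns alpha * heatmap + (1 - alpha) * image, with the heatmap resized to H x W.
    """
    H, W = image.shape[-2:]
    heat = F.interpolate(heatmap[None, None], size=(H, W), mode="bilinear", align_corners=False)[0]
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)  # rescale to [0, 1]
    return alpha * heat + (1 - alpha) * image  # the 1-channel heatmap broadcasts over RGB


attn = torch.softmax(torch.randn(1, 6, 196, 196), dim=-1)  # e.g. 14 x 14 patches, no prefix tokens
scores = received_attention(attn)[0].reshape(14, 14)
blended = overlay(torch.rand(3, 224, 224), scores)
print(blended.shape)  # torch.Size([3, 224, 224])
```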
Does directly resizing a 14×14 attention map to 224×224 make sense? I've seen this approach used frequently in Visualizing attention rollouts for ViTs, but I'm trying to understand what the resized attention values actually represent. Since the original 14×14 map corresponds to attention over the 16×16 patches of a 224×224 image, does interpolating the attention values introduce artifacts or distort their meaning? Would it make more sense to map each value of the 14×14 attention rollout matrix directly to a fixed 16×16 region instead, giving a 224×224 mask? Additionally, I couldn't find a definitive source explaining why interpolating attention maps is standard practice. Is this done purely for visualization, or is there a theoretical justification behind it? Any insights would be greatly appreciated. Thanks in advance :))
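One way to see the difference this question asks about, as a PyTorch sketch (helper names hypothetical): copying each of the 14 x 14 values into its own 16 x 16 block is exactly nearest-neighbour upsampling by an integer factor, while bilinear interpolation blends neighbouring patch scores near patch borders. Bilinear values always stay within the range of the original scores, so interpolation does not create new extremes; it only smooths the block edges. Seen this way, the choice appears to be mostly about presentation rather than a different meaning of the scores.

```python
import torch
import torch.nn.functional as F


def blocky_upsample(grid_map, patch_size=16):
    """Copy each patch score into its own patch_size x patch_size block."""
    return grid_map.repeat_interleave(patch_size, dim=0).repeat_interleave(patch_size, dim=1)


def smooth_upsample(grid_map, out_size):
    """Bilinear interpolation, as used in the visualization code earlier in this thread."""
    return F.interpolate(grid_map[None, None], size=out_size, mode="bilinear", align_corners=False)[0, 0]


torch.manual_seed(0)
scores = torch.rand(14, 14)                    # e.g. a 14 x 14 rollout map
blocky = blocky_upsample(scores)               # [224, 224], piecewise constant
nearest = F.interpolate(scores[None, None], size=(224, 224), mode="nearest")[0, 0]
smooth = smooth_upsample(scores, (224, 224))   # [224, 224], smooth across patch borders

print(torch.equal(blocky, nearest))            # True: block copy == nearest upsampling
print(bool(smooth.min() >= scores.min() - 1e-6 and smooth.max() <= scores.max() + 1e-6))  # True
```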