Attention mechanisms have significantly improved the performance of deep learning models in computer vision tasks. This article provides an overview of several popular attention modules that can be easily integrated into object detection models like YOLOv5, YOLOv7, YOLOv8, YOLOv9, and YOLOv10.
SE: Squeeze-and-Excitation Attention
Paper: Squeeze-and-Excitation Networks Link: arXiv:1709.01507
The Squeeze-and-Excitation (SE) attention mechanism enhances network performance by learning a weight vector that re-weights feature channels.
Overview
The SE block consists of two main operations:
- Squeeze: Global average pooling compresses each channel's feature map into a single scalar, representing a global descriptor.
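- Excitation: Two fully connected layers form a bottleneck (the first reduces the channel count by a factor `reduction`, followed by a ReLU; the second restores it), and a sigmoid then produces a weight vector with values in (0, 1) indicating each channel's importance.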
These weights are then multiplied with the original feature map to recalibrate channel responses.
Applications and Limitations
SE attention is widely used in image classification, object detection, and segmentation. It is computationally lightweight and effective for channel-wise feature recalibration. However, it only considers channel attention, not spatial attention, which may limit performance in tasks requiring spatial awareness.
Code
import torch
from torch import nn
class SEAttention(nn.Module):
    """Squeeze-and-Excitation channel attention (arXiv:1709.01507).

    Global average pooling squeezes each channel to a scalar; a two-layer
    bottleneck MLP with a sigmoid turns those scalars into per-channel
    weights in (0, 1) that rescale the input.

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck ratio of the excitation MLP. The hidden width
            is ``channel // reduction``, clamped to at least 1 so that small
            channel counts do not produce a zero-width layer.
    """

    def __init__(self, channel=512, reduction=16):
        super(SEAttention, self).__init__()
        # Without the clamp, channel < reduction yields Linear(channel, 0),
        # which makes the gate a constant sigmoid(0) = 0.5 for every input.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale the channels of a 4-D input ``x`` of shape (b, c, h, w).

        Returns a tensor with the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)        # squeeze: (b, c)
        y = self.fc(y).view(b, c, 1, 1)        # excitation: per-channel weights
        # (b, c, 1, 1) broadcasts over the spatial dimensions.
        return x * y
if __name__ == '__main__':
    # Smoke run: SE attention returns a tensor shaped like its input.
    demo_input = torch.randn(64, 512, 20, 20)
    se_block = SEAttention(channel=512, reduction=8)
    recalibrated = se_block(demo_input)
    print(recalibrated.shape)
A2-Nets: Double Attention Networks
Paper: A2-Nets: Double Attention Networks Link: arXiv:1810.11579
A2-Nets capture and distribute global features using a double attention mechanism.
Core Idea
The network first gathers critical features from the entire spatial extent into a compact set using second-order attention pooling. It then adaptively distributes these features to each location using another attention mechanism.
Code
import torch
from torch import nn
from torch.nn import functional as F
class DoubleAttention(nn.Module):
    """A^2 double-attention block (arXiv:1810.11579).

    A first attention step gathers global descriptors from all spatial
    positions; a second step distributes them back to every position.

    Args:
        in_channels: channel count of the input; checked in ``forward``.
        c_m: channels of the projected features that get gathered (``convA``).
        c_n: number of gathering maps / global descriptors (``convB``, ``convV``).
        reconstruct: if True, a 1x1 conv maps the ``c_m``-channel result back
            to ``in_channels``; otherwise the output has ``c_m`` channels.
    """

    def __init__(self, in_channels, c_m=128, c_n=128, reconstruct=True):
        super(DoubleAttention, self).__init__()
        self.in_channels = in_channels
        self.reconstruct = reconstruct
        self.c_m = c_m
        self.c_n = c_n
        self.convA = nn.Conv2d(in_channels, c_m, 1)
        self.convB = nn.Conv2d(in_channels, c_n, 1)
        self.convV = nn.Conv2d(in_channels, c_n, 1)
        if reconstruct:
            self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size=1)

    def forward(self, x):
        """Apply double attention to ``x`` of shape (b, in_channels, h, w)."""
        batch, channels, height, width = x.shape
        assert channels == self.in_channels
        feats = self.convA(x).view(batch, self.c_m, -1)               # (b, c_m, h*w)
        gather_logits = self.convB(x).view(batch, self.c_n, -1)       # (b, c_n, h*w)
        distribute_logits = self.convV(x).view(batch, self.c_n, -1)   # (b, c_n, h*w)
        # Gather: each of the c_n maps is a distribution over positions.
        gather_maps = F.softmax(gather_logits, dim=-1)
        # Distribute: at each position, a distribution over the c_n descriptors.
        distribute_weights = F.softmax(distribute_logits, dim=1)
        descriptors = torch.bmm(feats, gather_maps.transpose(1, 2))   # (b, c_m, c_n)
        z = torch.bmm(descriptors, distribute_weights)                # (b, c_m, h*w)
        z = z.view(batch, self.c_m, height, width)
        return self.conv_reconstruct(z) if self.reconstruct else z
if __name__ == '__main__':
    # Smoke run: with reconstruct=True the output keeps the input's channel count.
    demo_input = torch.randn(64, 512, 20, 20)
    a2_block = DoubleAttention(512)
    attended = a2_block(demo_input)
    print(attended.shape)
BAM: Bottleneck Attention Module
Paper: BAM: Bottleneck Attention Module Link: arXiv:1807.06514
BAM combines channel and spatial attention to improve feature extraction.
Channel and Spatial Attention
- Channel attention: Uses global average pooling followed by fully connected layers to learn channel importance.
- Spatial attention: Uses dilated convolutions to compute a spatial attention map.
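The two attention outputs are broadcast to the input's shape, summed, and passed through a sigmoid to form an attention map M; it is applied to the input through a residual connection, i.e. the output is x · (1 + M).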
Code
import torch
from torch import nn
def autopad(kernel_size, padding=None, dilation=1):
    """Return padding that centres a (possibly dilated) convolution kernel.

    When ``dilation > 1`` the effective kernel size ``dilation * (k - 1) + 1``
    is used. If ``padding`` is given it is returned unchanged; otherwise the
    result is half the effective kernel size (an int for an int kernel, a list
    for a sequence of kernel sizes).
    """
    if dilation > 1:
        if isinstance(kernel_size, int):
            kernel_size = dilation * (kernel_size - 1) + 1
        else:
            kernel_size = [dilation * (k - 1) + 1 for k in kernel_size]
    if padding is not None:
        return padding
    if isinstance(kernel_size, int):
        return kernel_size // 2
    return [k // 2 for k in kernel_size]
class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) one into a single axis."""

    def forward(self, x):
        # ``view`` requires memory-compatible strides; no copy is made.
        batch_size = x.size(0)
        return x.view(batch_size, -1)
class ChannelAttention(nn.Module):
    """BAM channel branch: global average pooling followed by an MLP.

    The MLP has ``num_layers`` hidden Linear -> BatchNorm1d -> ReLU stages of
    width ``in_channels // reduction`` (clamped to at least 1), then a final
    Linear back to ``in_channels``. ``forward`` returns these per-channel
    logits (no sigmoid is applied here) expanded to the input's shape.

    NOTE: BatchNorm1d in training mode needs a batch size greater than 1.

    Args:
        in_channels: number of channels of the input feature map.
        reduction: bottleneck ratio for the hidden width.
        num_layers: number of hidden Linear/BatchNorm1d/ReLU stages.
    """

    def __init__(self, in_channels, reduction=16, num_layers=3):
        super(ChannelAttention, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Clamp so that in_channels < reduction does not build zero-width layers.
        hidden = max(1, in_channels // reduction)
        gate_channels = [in_channels] + [hidden] * num_layers + [in_channels]
        self.ca = nn.Sequential()
        # (b, c, 1, 1) -> (b, c); the built-in flatten keeps the batch dim.
        self.ca.add_module('flatten', nn.Flatten())
        hidden_pairs = zip(gate_channels[:-2], gate_channels[1:-1])
        for i, (c_in, c_out) in enumerate(hidden_pairs):
            self.ca.add_module(f'fc_{i}', nn.Linear(c_in, c_out))
            self.ca.add_module(f'bn_{i}', nn.BatchNorm1d(c_out))
            self.ca.add_module(f'relu_{i}', nn.ReLU())
        self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, x):
        """Return channel logits for ``x`` (b, c, h, w), expanded to x's shape."""
        res = self.avgpool(x)              # (b, c, 1, 1)
        res = self.ca(res)                 # (b, c)
        # Expanded view: the same logit at every spatial position.
        return res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
class SpatialAttention(nn.Module):
    """BAM spatial branch: 1x1 reduce, dilated 3x3 convs, 1x1 to a single map.

    Every intermediate stage uses ``in_channels // reduction`` channels
    (clamped to at least 1) followed by BatchNorm2d and ReLU. ``forward``
    returns the one-channel spatial logits (no sigmoid is applied here)
    expanded to the input's shape.

    Args:
        in_channels: number of channels of the input feature map.
        reduction: channel reduction ratio for the intermediate convs.
        num_layers: number of dilated 3x3 conv stages.
        dilation: dilation of the 3x3 convs.
    """

    def __init__(self, in_channels, reduction=16, num_layers=3, dilation=2):
        super(SpatialAttention, self).__init__()
        # Clamp so that in_channels < reduction does not build zero-channel convs.
        hidden = max(1, in_channels // reduction)
        self.sa = nn.Sequential()
        self.sa.add_module('conv_reduce', nn.Conv2d(kernel_size=1, in_channels=in_channels, out_channels=hidden))
        self.sa.add_module('bn_reduce', nn.BatchNorm2d(hidden))
        self.sa.add_module('relu_reduce', nn.ReLU())
        for i in range(num_layers):
            # autopad(3, None, d) == d, which keeps H and W unchanged for a
            # stride-1 dilated 3x3 conv.
            self.sa.add_module(
                f'conv_{i}',
                nn.Conv2d(kernel_size=3, in_channels=hidden, out_channels=hidden,
                          padding=autopad(3, None, dilation), dilation=dilation),
            )
            self.sa.add_module(f'bn_{i}', nn.BatchNorm2d(hidden))
            self.sa.add_module(f'relu_{i}', nn.ReLU())
        self.sa.add_module('last_conv', nn.Conv2d(hidden, 1, kernel_size=1))

    def forward(self, x):
        """Return spatial logits for ``x`` (b, c, h, w), expanded to x's shape."""
        res = self.sa(x)                   # (b, 1, h, w)
        return res.expand_as(x)
class BAMBlock(nn.Module):
    """Bottleneck Attention Module (arXiv:1807.06514).

    Computes ``x * (1 + sigmoid(spatial_logits + channel_logits))``, where
    both branches return logits already expanded to the shape of ``x``.

    Args:
        in_channels: channel count of the input feature map.
        reduction: reduction ratio shared by both branches.
        dilation: dilation of the spatial branch's 3x3 convs.
    """

    def __init__(self, in_channels=512, reduction=16, dilation=2):
        super().__init__()
        self.ca = ChannelAttention(in_channels=in_channels, reduction=reduction)
        self.sa = SpatialAttention(
            in_channels=in_channels, reduction=reduction, dilation=dilation
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        spatial_logits = self.sa(x)
        channel_logits = self.ca(x)
        attention = self.sigmoid(spatial_logits + channel_logits)
        # Residual form: the attended features are added on top of x.
        return (1 + attention) * x
if __name__ == '__main__':
    # Smoke run: BAM returns a tensor shaped like its input.
    demo_input = torch.randn(64, 512, 7, 7)
    bam_block = BAMBlock(in_channels=512, reduction=16, dilation=2)
    refined = bam_block(demo_input)
    print(refined.shape)
BiFormer: Bi-Level Routing Attention
Paper: BiFormer: Vision Transformer with Bi-Level Routing Attention Link: arXiv:2303.08810
BiFormer introduces a bi-level routing attention mechanism for dynamic sparse attention in vision transformers.
Method
- Region-level routing: A coarse-grained filtering removes irrelevant key-value pairs by constructing a region affinity graph and selecting the top-k connections.
- Token-level attention: Fine-grained token-to-token attention is applied only within the union of remaining candidate regions.
This approach reduces computational cost while maintaining high performance.
Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class TopkRouting(nn.Module):
    """Region-level router of BiFormer's bi-level routing attention.

    Scores every query region against every key region with a scaled dot
    product and keeps the ``topk`` best-scoring key regions per query.

    Args:
        qk_dim: feature size of the query/key descriptors.
        topk: number of key regions kept per query region.
        qk_scale: logit scale; any falsy value (e.g. None or 0) falls back
            to ``qk_dim ** -0.5``.
        param_routing: if True, queries and keys pass through a learnable
            ``Linear(qk_dim, qk_dim)`` before scoring; otherwise they are used as-is.
        diff_routing: if False, queries and keys are detached, so no gradient
            flows back through the routing scores.
    """

    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale if qk_scale else qk_dim ** -0.5
        self.diff_routing = diff_routing
        if param_routing:
            self.emb = nn.Linear(qk_dim, qk_dim)
        else:
            self.emb = nn.Identity()
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query, key):
        """Return ``(routing_weights, routing_indices)`` for the top-k key regions.

        Presumably ``query``/``key`` are (n, p^2, qk_dim) region descriptors
        (TODO confirm with the caller). Scores are ``query @ key^T`` over the
        last two dims; both outputs have ``topk`` as their last dimension, and
        the weights are a softmax over the kept scores.
        """
        if self.diff_routing:
            q, k = query, key
        else:
            q, k = query.detach(), key.detach()
        q_emb = self.emb(q)
        k_emb = self.emb(k)
        scores = torch.matmul(q_emb * self.scale, k_emb.transpose(-2, -1))
        top_scores, top_idx = scores.topk(self.topk, dim=-1)
        return self.routing_act(top_scores), top_idx
class KVGather(nn.Module):
def __init__(self, mul_weight='none'):
super().__init__()
assert mul_weight in ['none', 'soft', 'hard']
self.mul_weight = mul_weight
def forward(self, r_idx, r_weight, kv):
n, p2, w2, c_kv = kv.size()
topk = r_idx.size(-1)
topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1),
dim=2,
index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv))
if self.mul_weight == 'soft':
topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv
elif self.mul_weight == 'hard':
raise NotImplementedError