Various Attention Mechanisms for YOLO Series: SE, A2-Nets, BAM, and BiFormer

Attention mechanisms have significantly improved the performance of deep learning models in computer vision tasks. This article provides an overview of several popular attention modules that can be easily integrated into object detection models like YOLOv5, YOLOv7, YOLOv8, YOLOv9, and YOLOv10.

SE: Squeeze-and-Excitation Attention

Paper: Squeeze-and-Excitation Networks Link: arXiv:1709.01507

The Squeeze-and-Excitation (SE) attention mechanism enhances network performance by learning a weight vector that re-weights feature channels.

Overview

The SE block consists of two main operations:

  1. Squeeze: Global average pooling compresses each channel's feature map into a single scalar, giving a global descriptor of that channel.
  2. Excitation: Two fully connected layers (a bottleneck that reduces the channel count by a ratio r, with ReLU in between) followed by a sigmoid produce a weight vector with values between 0 and 1 that indicates each channel's importance.

These weights are then multiplied with the original feature map to recalibrate channel responses.
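Written out in the SE paper's notation, with an input feature map U of C channels u_c of size H × W, reduction ratio r, ReLU δ, and sigmoid σ, the three operations are:

```latex
\begin{aligned}
z_c &= \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j) && \text{(squeeze)} \\
s &= \sigma\big(W_2\, \delta(W_1 z)\big), \quad W_1 \in \mathbb{R}^{\frac{C}{r} \times C},\ W_2 \in \mathbb{R}^{C \times \frac{C}{r}} && \text{(excitation)} \\
\tilde{x}_c &= s_c \cdot u_c && \text{(scale)}
\end{aligned}
```

The SEAttention code in this section follows this structure, with the two fully connected layers created with bias=False.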

Applications and Limitations

SE attention is widely used in image classification, object detection, and segmentation. It is computationally lightweight, adding only two small fully connected layers per block, and effective for channel-wise feature recalibration. However, it models only channel attention, not spatial attention: every position within a channel receives the same weight, which may limit performance in tasks that require spatial awareness.
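To make "lightweight" concrete, the sketch below counts the parameters of the two SE fully connected layers and compares them with a single 3x3 convolution on the same number of channels. The sizes (512 channels, reduction 16) are illustrative assumptions; the count for SE is 2·C·C/r because the layers have no bias.

```python
import torch
from torch import nn

# Illustrative sizes (assumptions): 512 channels, reduction ratio 16
channels, reduction = 512, 16

# The two fully connected layers of an SE block (no bias, as in the SEAttention code)
se_fc = nn.Sequential(
    nn.Linear(channels, channels // reduction, bias=False),
    nn.ReLU(inplace=True),
    nn.Linear(channels // reduction, channels, bias=False),
)
# A plain 3x3 convolution with the same input/output channels, for comparison
conv3x3 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

def count_params(module: nn.Module) -> int:
    """Total number of learnable scalars in a module."""
    return sum(p.numel() for p in module.parameters())

se_params = count_params(se_fc)      # 2 * C * C / r
conv_params = count_params(conv3x3)  # C * C * 3 * 3
print(se_params, conv_params, f"{se_params / conv_params:.2%}")
# prints: 32768 2359296 1.39%
```

The SE layers also run on a pooled (b, C) vector rather than on every spatial position, so their compute cost is small relative to the convolutions around them.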

Code

import torch
from torch import nn

class SEAttention(nn.Module):
    def __init__(self, channel=512, reduction=16):
        super(SEAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

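    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: global average pooling -> one descriptor per channel, shape (b, c)
        y = self.avg_pool(x).view(b, c)
        # Excitation: bottleneck MLP + sigmoid -> channel weights in (0, 1), shape (b, c, 1, 1)
        y = self.fc(y).view(b, c, 1, 1)
        # Scale: re-weight each channel of the input
        return x * y.expand_as(x)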

if __name__ == '__main__':
    input_tensor = torch.randn(64, 512, 20, 20)
    se_attention = SEAttention(channel=512, reduction=8)
    output_tensor = se_attention(input_tensor)
    print(output_tensor.shape)
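As a sketch of where such a block typically goes, the snippet below places SE on the residual branch of a simple residual block, before the identity addition, which is how the SE paper inserts it into ResNet. The class names SE and SEBottleneck and the block layout (two 3x3 conv-BN-SiLU layers) are illustrative assumptions, not taken from any YOLO codebase. The 1x1 convolutions applied to the pooled tensor are equivalent to the nn.Linear layers in SEAttention and avoid the reshapes.

```python
import torch
from torch import nn

class SE(nn.Module):
    """Compact SE block: 1x1 convs on the pooled map act as the two FC layers."""
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.fc(self.pool(x))  # (b, c, 1, 1) weights broadcast over h, w

class SEBottleneck(nn.Module):
    """Hypothetical residual block: SE re-weights the branch before the skip addition."""
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.SiLU(),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.SiLU(),
        )
        self.se = SE(channels, reduction)

    def forward(self, x):
        return x + self.se(self.branch(x))

block = SEBottleneck(64)
y = block(torch.randn(2, 64, 32, 32))
print(y.shape)  # torch.Size([2, 64, 32, 32])
```

Applying SE before the addition means the identity path is never scaled, so the block can still pass its input through unchanged.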

A2-Nets: Double Attention Networks

Paper: A2-Nets: Double Attention Networks Link: arXiv:1810.11579

A2-Nets capture and distribute global features using a double attention mechanism.

Core Idea

The network first gathers critical features from the entire spatial extent into a compact set using second-order attention pooling. It then adaptively distributes these features to each location using another attention mechanism.
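The two steps can be checked with a small functional sketch that uses the same tensor layout as the DoubleAttention code in this section (A: c_m feature maps, B: c_n gathering maps, V: c_n distribution maps, all flattened over n = h*w positions). The helper double_attention is hypothetical. Because each gathering map sums to 1 over positions, and the distribution weights at each position sum to 1 over the c_n descriptors, a feature map that is identical at every position passes through unchanged, which the last lines verify.

```python
import torch
import torch.nn.functional as F

def double_attention(A, B, V):
    """Hypothetical functional form of the A2 block (no 1x1 convs).
    A: (b, c_m, n) features; B, V: (b, c_n, n) attention logits; n = h * w positions."""
    gather_maps = F.softmax(B, dim=-1)             # each of the c_n maps sums to 1 over positions
    G = torch.bmm(A, gather_maps.transpose(1, 2))  # (b, c_m, c_n): c_n global descriptors
    distribute = F.softmax(V, dim=1)               # at each position, weights over descriptors sum to 1
    return torch.bmm(G, distribute)                # (b, c_m, n): per-position mix of descriptors

torch.manual_seed(0)
b, c_m, c_n, n = 2, 8, 4, 25
B = torch.randn(b, c_n, n)
V = torch.randn(b, c_n, n)

# A feature map that is identical at every position
A_const = torch.randn(b, c_m, 1).repeat(1, 1, n)
Z = double_attention(A_const, B, V)
print(Z.shape)                                # torch.Size([2, 8, 25])
print(torch.allclose(Z, A_const, atol=1e-5))  # True
```

Gathering is a set of weighted averages over space and distribution is a convex mix of those averages, so the only intermediate tensor is the small c_m × c_n descriptor matrix.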

Code

import torch
from torch import nn
from torch.nn import functional as F

class DoubleAttention(nn.Module):
    def __init__(self, in_channels, c_m=128, c_n=128, reconstruct=True):
        super(DoubleAttention, self).__init__()
        self.in_channels = in_channels
        self.reconstruct = reconstruct
        self.c_m = c_m
        self.c_n = c_n
        self.convA = nn.Conv2d(in_channels, c_m, 1)
        self.convB = nn.Conv2d(in_channels, c_n, 1)
        self.convV = nn.Conv2d(in_channels, c_n, 1)
        if self.reconstruct:
            self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        assert c == self.in_channels

        A = self.convA(x)  # b, c_m, h, w
        B = self.convB(x)  # b, c_n, h, w
        V = self.convV(x)  # b, c_n, h, w

        tmpA = A.view(b, self.c_m, -1)  # b, c_m, h*w
        attention_maps = F.softmax(B.view(b, self.c_n, -1), dim=-1)  # b, c_n, h*w
        attention_vectors = F.softmax(V.view(b, self.c_n, -1), dim=1)  # b, c_n, h*w

        global_descriptors = torch.bmm(tmpA, attention_maps.permute(0, 2, 1))  # b, c_m, c_n
        tmpZ = torch.bmm(global_descriptors, attention_vectors)  # b, c_m, h*w
        tmpZ = tmpZ.view(b, self.c_m, h, w)

        if self.reconstruct:
            tmpZ = self.conv_reconstruct(tmpZ)

        return tmpZ

if __name__ == '__main__':
    input_tensor = torch.randn(64, 512, 20, 20)
    double_attention = DoubleAttention(512)
    output_tensor = double_attention(input_tensor)
    print(output_tensor.shape)

BAM: Bottleneck Attention Module

Paper: BAM: Bottleneck Attention Module Link: arXiv:1807.06514

BAM combines channel and spatial attention to improve feature extraction.

Channel and Spatial Attention

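  • Channel attention: Global average pooling is followed by a small multi-layer perceptron (fully connected layers with batch normalization and ReLU) that outputs one attention value per channel.
  • Spatial attention: A 1x1 convolution reduces the channel count, dilated 3x3 convolutions enlarge the receptive field, and a final 1x1 convolution outputs a single-channel spatial attention map.

The two attention outputs are added, passed through a sigmoid, and applied to the input via a residual connection, so every feature is scaled by a factor between 1 and 2.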
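In the BAM paper's notation, with input feature map F, channel branch M_c, spatial branch M_s, sigmoid σ, and element-wise product ⊗ (both branch outputs are broadcast to C × H × W before being added), the combination is:

```latex
\begin{aligned}
M(F) &= \sigma\big(M_c(F) + M_s(F)\big), \quad M_c(F) \in \mathbb{R}^{C \times 1 \times 1},\ M_s(F) \in \mathbb{R}^{1 \times H \times W} \\
F' &= F + F \otimes M(F) = \big(1 + M(F)\big) \otimes F
\end{aligned}
```

The factored form on the right corresponds to the line out = (1 + weight) * x in the BAMBlock code.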

Code

import torch
from torch import nn

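def autopad(kernel_size, padding=None, dilation=1):
    # 'Same' padding for a stride-1 convolution, accounting for dilation
    if dilation > 1:
        # Effective size of a dilated kernel: d * (k - 1) + 1, e.g. k=3, d=2 -> 5
        kernel_size = dilation * (kernel_size - 1) + 1 if isinstance(kernel_size, int) else [dilation * (x - 1) + 1 for x in kernel_size]
    if padding is None:
        # Half the effective kernel size, e.g. 5 // 2 = 2, keeps H and W unchanged
        padding = kernel_size // 2 if isinstance(kernel_size, int) else [x // 2 for x in kernel_size]
    return padding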

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction=16, num_layers=3):
        super(ChannelAttention, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        gate_channels = [in_channels]
        gate_channels += [in_channels // reduction] * num_layers
        gate_channels += [in_channels]
        self.ca = nn.Sequential()
        self.ca.add_module('flatten', Flatten())
        for i in range(len(gate_channels) - 2):
            self.ca.add_module(f'fc_{i}', nn.Linear(gate_channels[i], gate_channels[i+1]))
            self.ca.add_module(f'bn_{i}', nn.BatchNorm1d(gate_channels[i+1]))
            self.ca.add_module(f'relu_{i}', nn.ReLU())
        self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, x):
        res = self.avgpool(x)
        res = self.ca(res)
        res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
        return res

class SpatialAttention(nn.Module):
    def __init__(self, in_channels, reduction=16, num_layers=3, dilation=2):
        super(SpatialAttention, self).__init__()
        self.sa = nn.Sequential()
        self.sa.add_module('conv_reduce', nn.Conv2d(kernel_size=1, in_channels=in_channels, out_channels=in_channels // reduction))
        self.sa.add_module('bn_reduce', nn.BatchNorm2d(in_channels // reduction))
        self.sa.add_module('relu_reduce', nn.ReLU())
        for i in range(num_layers):
            self.sa.add_module(f'conv_{i}', nn.Conv2d(kernel_size=3, in_channels=in_channels // reduction, out_channels=in_channels // reduction, padding=autopad(3, None, dilation), dilation=dilation))
            self.sa.add_module(f'bn_{i}', nn.BatchNorm2d(in_channels // reduction))
            self.sa.add_module(f'relu_{i}', nn.ReLU())
        self.sa.add_module('last_conv', nn.Conv2d(in_channels // reduction, 1, kernel_size=1))

    def forward(self, x):
        res = self.sa(x)
        res = res.expand_as(x)
        return res

class BAMBlock(nn.Module):
    def __init__(self, in_channels=512, reduction=16, dilation=2):
        super(BAMBlock, self).__init__()
        self.ca = ChannelAttention(in_channels=in_channels, reduction=reduction)
        self.sa = SpatialAttention(in_channels=in_channels, reduction=reduction, dilation=dilation)
        self.sigmoid = nn.Sigmoid()

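    def forward(self, x):
        sa_out = self.sa(x)  # spatial attention map, expanded to x.shape
        ca_out = self.ca(x)  # channel attention vector, expanded to x.shape
        weight = self.sigmoid(sa_out + ca_out)  # combined attention M(F), values in (0, 1)
        out = (1 + weight) * x  # residual form: F + F * M(F)
        return out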

if __name__ == '__main__':
    input_tensor = torch.randn(64, 512, 7, 7)
    bam = BAMBlock(in_channels=512, reduction=16, dilation=2)
    output_tensor = bam(input_tensor)
    print(output_tensor.shape)
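A quick check of the combination step, with random tensors standing in for the two learned branch outputs (all shapes are illustrative assumptions). It shows that broadcasting already aligns the (b, c, 1, 1) channel output with the (b, 1, h, w) spatial output, and that the residual form scales every feature by a gain between 1 and 2.

```python
import torch

torch.manual_seed(0)
b, c, h, w = 2, 16, 7, 7
x = torch.randn(b, c, h, w)

# Random stand-ins for the learned branch outputs (before the sigmoid)
ca_out = torch.randn(b, c, 1, 1)  # channel branch: one logit per channel
sa_out = torch.randn(b, 1, h, w)  # spatial branch: one logit per position

# Broadcasting (b, c, 1, 1) + (b, 1, h, w) -> (b, c, h, w); expand_as is not strictly needed
weight = torch.sigmoid(ca_out + sa_out)
out = (1 + weight) * x
print(weight.shape)  # torch.Size([2, 16, 7, 7])

# (1 + M) * F is the residual form F + F * M(F)
print(torch.allclose(out, x + x * weight))  # True

# Every element of x is multiplied by a gain strictly between 1 and 2
gain = 1 + weight
print(bool(((gain > 1) & (gain < 2)).all()))  # True
```

Because the gain never drops below 1 in this formulation, BAM can emphasize features but never attenuates them, so the original signal always reaches later layers.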

BiFormer: Bi-Level Routing Attention

Paper: BiFormer: Vision Transformer with Bi-Level Routing Attention Link: arXiv:2303.08810

BiFormer introduces a bi-level routing attention mechanism for dynamic sparse attention in vision transformers.

Method

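  • Region-level routing: The feature map is divided into non-overlapping regions. A region affinity graph is built from region-averaged queries and keys, and only the top-k most relevant regions are kept for each region, filtering out irrelevant key-value pairs at a coarse level.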
  • Token-level attention: Fine-grained token-to-token attention is applied only within the union of remaining candidate regions.

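Because each query attends only to the tokens of its k routed regions rather than to every token in the image, this approach reduces computation and memory compared with dense global attention, while the content-aware routing still allows long-range interactions.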
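A back-of-the-envelope count makes the saving concrete. The sizes are hypothetical (a 56 x 56 token map, 7 x 7 regions, top-k = 4), and the count covers only how many keys each query is compared against, ignoring heads, channels, and constant factors.

```python
# Hypothetical sizes: a 56 x 56 token map, split into S x S = 7 x 7 regions, top-k = 4
H = W = 56
S = 7
topk = 4

num_tokens = H * W                         # N = 3136 tokens
tokens_per_region = num_tokens // (S * S)  # 64 tokens per region
keys_dense = num_tokens                    # dense attention: each query scores all N keys
keys_routed = topk * tokens_per_region     # bi-level routing: only tokens of the k routed regions
routing_pairs = (S * S) ** 2               # region-to-region affinities, scored once per feature map

print(num_tokens, keys_routed, f"{keys_routed / keys_dense:.1%}", routing_pairs)
# prints: 3136 256 8.2% 2401
```

Under these assumptions each query scores about 8% of the keys that dense attention would, and the routing step adds only a few thousand region-level scores.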

Code

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopkRouting(nn.Module):
    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        self.routing_act = nn.Softmax(dim=-1)

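    def forward(self, query, key):
        # query, key: (n, p2, c) region-level queries and keys
        if not self.diff_routing:
            # Block gradients through the routing step unless differentiable routing is requested
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)  # (n, p2, p2) region affinity
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)  # top-k regions per query region
        r_weight = self.routing_act(topk_attn_logit)  # (n, p2, topk) routing weights
        return r_weight, topk_index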

class KVGather(nn.Module):
    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

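    def forward(self, r_idx, r_weight, kv):
        # kv: (n, p2, w2, c_kv) = batch, regions, tokens per region, channels
        # r_idx, r_weight: (n, p2, topk) routing indices and weights from TopkRouting
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # For every query region, gather the key/value tokens of its top-k routed regions
        topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1),
                               dim=2,
                               index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv))
        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv
        elif self.mul_weight == 'hard':
            raise NotImplementedError
        return topk_kv  # (n, p2, topk, w2, c_kv)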
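The code above stops at KVGather; the attention module that ties routing and gathering together is not shown. As a self-contained sketch of the whole idea, the hypothetical function below runs both levels on a (b, h, w, c) map. It is deliberately simplified (single head, the input used directly as queries, keys, and values, no learned projections, and none of the extra components of the full BiFormer block), so treat it as an illustration rather than the reference implementation.

```python
import torch
import torch.nn.functional as F

def bi_level_routing_attention(x, S=7, topk=4):
    """Simplified single-head bi-level routing attention (illustrative sketch).
    x: (b, h, w, c), used directly as q, k, v; h and w must be divisible by S."""
    b, h, w, c = x.shape
    rh, rw = h // S, w // S  # region height and width
    # Split into S*S non-overlapping regions: (b, S*S, rh*rw, c)
    regions = x.view(b, S, rh, S, rw, c).permute(0, 1, 3, 2, 4, 5).reshape(b, S * S, rh * rw, c)
    q = k = v = regions

    # 1) Region-level routing: affinity of region-averaged queries and keys, keep top-k
    q_r, k_r = q.mean(dim=2), k.mean(dim=2)  # (b, S*S, c)
    idx = torch.topk(q_r @ k_r.transpose(-2, -1), k=topk, dim=-1).indices  # (b, S*S, topk)

    # 2) Gather keys/values of the routed regions (the job of KVGather): (b, S*S, topk, rh*rw, c)
    gather_idx = idx[..., None, None].expand(-1, -1, -1, rh * rw, c)
    k_g = torch.gather(k[:, None].expand(-1, S * S, -1, -1, -1), 2, gather_idx)
    v_g = torch.gather(v[:, None].expand(-1, S * S, -1, -1, -1), 2, gather_idx)
    k_g = k_g.reshape(b, S * S, topk * rh * rw, c)
    v_g = v_g.reshape(b, S * S, topk * rh * rw, c)

    # 3) Token-level attention restricted to the gathered tokens
    attn = F.softmax(q @ k_g.transpose(-2, -1) * c ** -0.5, dim=-1)  # (b, S*S, rh*rw, topk*rh*rw)
    out = attn @ v_g  # (b, S*S, rh*rw, c)

    # Undo the region split: back to (b, h, w, c)
    return out.view(b, S, S, rh, rw, c).permute(0, 1, 3, 2, 4, 5).reshape(b, h, w, c)


torch.manual_seed(0)
x = torch.randn(2, 28, 28, 16)
y = bi_level_routing_attention(x, S=7, topk=4)
print(y.shape)  # torch.Size([2, 28, 28, 16])

# Sanity check: routing every region to all S*S regions recovers dense self-attention
y_all = bi_level_routing_attention(x, S=7, topk=49)
xf = x.reshape(2, 28 * 28, 16)
dense = (F.softmax(xf @ xf.transpose(1, 2) * 16 ** -0.5, dim=-1) @ xf).reshape(2, 28, 28, 16)
print(torch.allclose(y_all, dense, atol=1e-4))  # True
```

The sanity check works because softmax attention does not depend on the order of the keys: with topk equal to the number of regions, each query sees every token, just in routed order.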


Posted on Fri, 22 May 2026 19:06:16 +0000 by sheephat