MobileFormer combines a convolutional neural network (CNN) and a Transformer in a single parallel architecture, aiming for high accuracy at low computational cost. A lightweight bidirectional bridge between a mobile backbone and a compact Transformer lets the two exchange local and global information without the computational burden of a traditional vision Transformer.
Architecture Overview
MobileFormer runs two branches in parallel: one processes the input image with a MobileNet-style CNN to extract spatially localized features, while the other uses a small set of learnable global tokens (six or fewer) to represent global context. Unlike Vision Transformers (ViT), which split an image into hundreds of patch tokens, MobileFormer’s global tokens are randomly initialized, learned during training, and fixed in number, which drastically reduces the cost of attention.
The two branches are connected via two lightweight cross-attention mechanisms: one propagating local features to global tokens (Mobile → Former), and the other integrating global context back into local features (Former → Mobile). This bidirectional flow ensures that spatial details inform global understanding, and global semantics refine local representations—all with minimal parameter and FLOP overhead.
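To give a rough sense of scale, the sketch below counts the query-key score entries that one attention head computes. The sizes are illustrative assumptions rather than figures from the text: a 224×224 image cut into 16×16 patches for the ViT case, a 14×14 CNN feature map for the bridge, and six global tokens.

```python
def attention_pairs(num_queries: int, num_keys: int) -> int:
    """Number of query-key score entries one attention head computes."""
    return num_queries * num_keys


# Assumed example sizes (not from the text): 224x224 input, 16x16 patches,
# hence a 14x14 grid of 196 ViT tokens; a 14x14 CNN feature map; M = 6 tokens.
vit_tokens = (224 // 16) ** 2
feature_positions = 14 * 14
num_global_tokens = 6

vit_self_attention = attention_pairs(vit_tokens, vit_tokens)
former_self_attention = attention_pairs(num_global_tokens, num_global_tokens)
# Each bridge direction pairs every global token with every spatial position.
bridge_one_way = attention_pairs(num_global_tokens, feature_positions)

print(vit_self_attention)     # 38416
print(former_self_attention)  # 36
print(bridge_one_way)         # 1176
```

Under these assumptions, self-attention among the tokens is tiny, and each bridge direction grows only linearly with the number of spatial positions.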
Lightweight Cross-Attention Mechanism
Standard cross-attention applies query, key, and value projections on both sides of the exchange. MobileFormer removes the projections on the CNN side and keeps only those on the Transformer side:
In Mobile → Former, only the query projection is kept, applied to the global tokens; keys and values are taken directly from the CNN’s feature map without transformation.
In Former → Mobile, only the key and value projections are kept, again applied to the global tokens; the CNN’s features act as queries without a linear mapping.
This design reduces computational cost while preserving information flow. The attention computation for Mobile → Former is formulated as:
$$ A_{X \to Z} = \text{Concat}\left[\text{Attn}\left(\tilde{z}_i W_i^Q, \tilde{x}_i, \tilde{x}_i\right)\right]_{i=1}^{h} W^O $$
where $ \tilde{x}_i \in \mathbb{R}^{HW \times d/h} $ are the per-head channel splits of the CNN feature map $ X $, $ \tilde{z}_i \in \mathbb{R}^{M \times d/h} $ are the per-head splits of the global tokens $ Z $, and $ h $ is the number of heads. The query matrix $ W_i^Q $ belongs to the Transformer branch and $ W^O $ combines the heads, while the keys and values are the unprojected local features.
Conversely, for Former → Mobile:
$$ A_{Z \to X} = \text{Concat}\left[\text{Attn}\left(\tilde{x}_i, \tilde{z}_i W_i^K, \tilde{z}_i W_i^V\right)\right]_{i=1}^{h} $$
Here, $ W_i^K $ and $ W_i^V $ are learned projections from global tokens, while local features serve as unprojected queries.
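To make the shapes explicit, the following assumes that Attn denotes standard scaled dot-product attention, applied per head with $ d_k = d/h $. The text does not spell this out, so treat it as an assumption:

$$ \text{Attn}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V $$

$$ \text{Mobile} \to \text{Former}: \quad Q \in \mathbb{R}^{M \times d_k}, \; K, V \in \mathbb{R}^{HW \times d_k} \;\Rightarrow\; Q K^\top \in \mathbb{R}^{M \times HW}, \; \text{Attn}(Q, K, V) \in \mathbb{R}^{M \times d_k} $$

$$ \text{Former} \to \text{Mobile}: \quad Q \in \mathbb{R}^{HW \times d_k}, \; K, V \in \mathbb{R}^{M \times d_k} \;\Rightarrow\; Q K^\top \in \mathbb{R}^{HW \times M}, \; \text{Attn}(Q, K, V) \in \mathbb{R}^{HW \times d_k} $$

Under this assumption, each direction costs on the order of $ HW \cdot M \cdot d $ multiply-adds across all heads, which is linear in the number of spatial positions because $ M $ is a small constant.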
Mobile Sub-Block with Dynamic Activation
The CNN branch employs inverted bottleneck blocks with depthwise separable convolutions. To adaptively modulate activations based on global context, a dynamic ReLU is introduced. Instead of using fixed parameters, the activation's coefficients are generated from the first global token $ z_1 $ by a small two-layer MLP:
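
    # Imports shared by the code in this section.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from einops import rearrange

    class Mobile(nn.Module):
        # DynamicReLU(channels, token_dim, k) is assumed to be defined elsewhere.
        def __init__(self, in_ch, exp_ch, out_ch, token_dim, stride=1, kernel_size=3, k=2):
            super().__init__()
            self.stride = stride
            self.in_ch, self.exp_ch, self.out_ch = in_ch, exp_ch, out_ch
            self.token_dim = token_dim
            if stride == 2:
                # Strided depthwise conv: downsamples and expands
                # (exp_ch must be a multiple of in_ch because groups=in_ch).
                self.downsample = nn.Sequential(
                    nn.Conv2d(in_ch, exp_ch, kernel_size=3, stride=2, padding=1, groups=in_ch),
                    nn.BatchNorm2d(exp_ch),
                    nn.ReLU6(inplace=True)
                )
                self.conv1 = nn.Conv2d(exp_ch, in_ch, 1)  # pointwise: project back to in_ch
                mid_ch = in_ch
            else:
                self.conv1 = nn.Conv2d(in_ch, exp_ch, 1)  # pointwise: expand to exp_ch
                mid_ch = exp_ch
            self.bn1 = nn.BatchNorm2d(mid_ch)
            self.act1 = DynamicReLU(mid_ch, token_dim, k=k)
            # Depthwise conv. Its input has mid_ch channels, so the layer must be
            # built with mid_ch inputs (in_ch in the stride-2 case, not exp_ch).
            self.conv2 = nn.Conv2d(mid_ch, exp_ch, kernel_size,
                                   padding=kernel_size // 2, groups=mid_ch)
            self.bn2 = nn.BatchNorm2d(exp_ch)
            self.act2 = DynamicReLU(exp_ch, token_dim, k=k)
            self.conv3 = nn.Conv2d(exp_ch, out_ch, 1)  # pointwise: project to out_ch
            self.bn3 = nn.BatchNorm2d(out_ch)

        def forward(self, x, global_token):
            # x: (B, in_ch, H, W); global_token: (B, M, token_dim)
            if self.stride == 2:
                x = self.downsample(x)
            x = self.bn1(self.conv1(x))
            x = self.act1(x, global_token[:, 0, :])  # first token modulates the activation
            x = self.bn2(self.conv2(x))
            x = self.act2(x, global_token[:, 0, :])
            return self.bn3(self.conv3(x))
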
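The `Mobile` block calls a `DynamicReLU(channels, token_dim, k)` module that the text does not define. The following is a minimal sketch of what such a module could look like, not the authors' implementation. It follows the general dynamic ReLU formulation, where the output is the maximum over `k` per-channel linear pieces whose slopes and intercepts come from a two-layer MLP on the token. The `reduction` factor, the initial slopes, and the `slope_range` and `intercept_range` values are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn


class DynamicReLU(nn.Module):
    """Sketch of a token-conditioned dynamic ReLU: y = max_j (a_j * x + b_j)."""

    def __init__(self, channels, token_dim, k=2, reduction=4):
        super().__init__()
        self.channels, self.k = channels, k
        # Two-layer MLP mapping the token to 2*k coefficients per channel.
        self.mlp = nn.Sequential(
            nn.Linear(token_dim, token_dim // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(token_dim // reduction, 2 * k * channels),
        )
        # Assumed initial values: slopes (1, 0, ...) and zero intercepts, so
        # zero residuals reduce the activation to max(x, 0), a plain ReLU.
        init_slopes = torch.zeros(k)
        init_slopes[0] = 1.0
        self.register_buffer("init_slopes", init_slopes)
        self.slope_range, self.intercept_range = 1.0, 0.5  # assumed ranges

    def forward(self, x, token):
        # x: (B, C, H, W); token: (B, token_dim)
        b = x.shape[0]
        theta = 2 * torch.sigmoid(self.mlp(token)) - 1  # residuals in (-1, 1)
        theta = theta.view(b, self.channels, 2, self.k)
        slopes = self.init_slopes + self.slope_range * theta[:, :, 0]  # (B, C, k)
        intercepts = self.intercept_range * theta[:, :, 1]             # (B, C, k)
        # Broadcast over H and W, then take the max over the k linear pieces.
        pieces = (x.unsqueeze(-1) * slopes[:, :, None, None, :]
                  + intercepts[:, :, None, None, :])
        return pieces.max(dim=-1).values
```

With this signature, the call `self.act1(x, global_token[:, 0, :])` in `Mobile.forward` passes a `(B, C, H, W)` feature map and a `(B, token_dim)` token and gets back a tensor of the same shape as the feature map.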
Former Sub-Block with Reduced Expansion
The Transformer branch uses a standard multi-head attention (MHA) plus feed-forward network (FFN) structure but reduces the FFN expansion ratio from 4 to 2 to cut down on parameters. Layer normalization is applied after the attention residual and after the FFN residual (post-norm). Because there are at most six tokens (M ≤ 6), the cost of self-attention is negligible compared with the CNN branch.

    class Former(nn.Module):
        def __init__(self, heads, dim, expand_ratio=2):
            super().__init__()
            assert dim % heads == 0, "dim must be divisible by heads"
            self.heads = heads
            self.dim = dim
            self.head_dim = dim // heads
            self.qkv_proj = nn.Linear(dim, dim * 3)  # fused query/key/value projection
            self.out_proj = nn.Linear(dim, dim)
            self.norm1 = nn.LayerNorm(dim)
            self.mlp = nn.Sequential(
                nn.Linear(dim, expand_ratio * dim),
                nn.GELU(),
                nn.Linear(expand_ratio * dim, dim)
            )
            self.norm2 = nn.LayerNorm(dim)

        def forward(self, tokens):
            # tokens: (B, M, dim)
            qkv = self.qkv_proj(tokens).chunk(3, dim=-1)
            q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=self.heads) for t in qkv)
            # Scaled dot-product attention among the M tokens.
            attn = torch.einsum('b h i d, b h j d -> b h i j', q, k) * (self.head_dim ** -0.5)
            attn = F.softmax(attn, dim=-1)
            out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
            out = rearrange(out, 'b h n d -> b n (h d)')
            out = self.out_proj(out)
            out = self.norm1(out + tokens)    # post-norm residual around attention
            mlp_out = self.mlp(out)
            return self.norm2(mlp_out + out)  # post-norm residual around the FFN

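As a quick check on what the smaller expansion ratio saves, this sketch counts the weights and biases of the two-layer FFN used in `Former` for both ratios. The token dimension of 192 is taken from the network description later in this article; the helper name `ffn_params` is hypothetical.

```python
def ffn_params(dim: int, expand_ratio: int) -> int:
    """Weights plus biases of Linear(dim, r*dim) followed by Linear(r*dim, dim)."""
    hidden = expand_ratio * dim
    first_layer = dim * hidden + hidden
    second_layer = hidden * dim + dim
    return first_layer + second_layer


token_dim = 192
print(ffn_params(token_dim, 4))  # 295872
print(ffn_params(token_dim, 2))  # 148032
```

Halving the expansion ratio roughly halves the FFN parameters per block, and the FFN compute per token shrinks by about the same factor.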
Bidirectional Cross-Attention Layers
The cross-attention layers are designed for minimal overhead:
Mobile → Former (local to global):

    class MobileToFormer(nn.Module):
        # Single-head version of the Mobile → Former bridge.
        def __init__(self, token_dim, feature_dim):
            super().__init__()
            self.query_proj = nn.Linear(token_dim, feature_dim)
            self.output_proj = nn.Linear(feature_dim, token_dim)

        def forward(self, local_features, global_tokens):
            # local_features: (B, C, H, W); global_tokens: (B, M, token_dim)
            b, c, h, w = local_features.shape
            local_flat = rearrange(local_features, 'b c h w -> b (h w) c')
            q = self.query_proj(global_tokens)
            # Keys and values are the unprojected local features.
            scores = torch.einsum('b n c, b m c -> b n m', q, local_flat) * (c ** -0.5)
            attn = F.softmax(scores, dim=-1)
            fused = torch.einsum('b n m, b m c -> b n c', attn, local_flat)
            return self.output_proj(fused) + global_tokens

Former → Mobile (global to local):

    class FormerToMobile(nn.Module):
        # Single-head version of the Former → Mobile bridge.
        def __init__(self, token_dim, feature_dim):
            super().__init__()
            self.kv_proj = nn.Linear(token_dim, 2 * feature_dim)

        def forward(self, local_features, global_tokens):
            # local_features: (B, C, H, W); global_tokens: (B, M, token_dim)
            b, c, h, w = local_features.shape
            local_flat = rearrange(local_features, 'b c h w -> b (h w) c')
            k, v = self.kv_proj(global_tokens).chunk(2, dim=-1)
            # Queries are the unprojected local features; scale the scores by
            # 1/sqrt(C), as in MobileToFormer.
            scores = torch.einsum('b n c, b m c -> b n m', local_flat, k) * (c ** -0.5)
            attn = F.softmax(scores, dim=-1)
            aggregated = torch.einsum('b n m, b m c -> b n c', attn, v)
            return rearrange(aggregated, 'b (h w) c -> b c h w', h=h, w=w) + local_features

Full Block and Network Design
Each MobileFormer block combines four components, applied in this order: Mobile→Former, the Former sub-block, the Mobile sub-block, and Former→Mobile. The input to each block is a local feature map and a set of global tokens; the output is an updated feature map and refined tokens.
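
    class MobileFormerBlock(nn.Module):
        def __init__(self, in_ch, exp_ch, out_ch, token_dim, stride=1, heads=8, expand_ratio=2):
            super().__init__()
            self.mobile = Mobile(in_ch, exp_ch, out_ch, token_dim, stride=stride)
            self.former = Former(heads, token_dim, expand_ratio)
            self.m2f = MobileToFormer(token_dim, in_ch)   # reads the block's input features
            self.f2m = FormerToMobile(token_dim, out_ch)  # writes into the block's output features

        def forward(self, x, z):
            # x: (B, in_ch, H, W) local features; z: (B, M, token_dim) global tokens
            z_updated = self.m2f(x, z)         # local → global
            z_out = self.former(z_updated)     # self-attention and FFN over the tokens
            x_hidden = self.mobile(x, z_out)   # CNN block modulated by the first token
            x_out = self.f2m(x_hidden, z_out)  # global → local
            return x_out, z_out
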
The full network begins with a 3×3 convolutional stem, followed by 11 stacked MobileFormer blocks. All blocks maintain 6 global tokens of dimension 192. Classification is performed by concatenating the first global token with the globally average-pooled local features, followed by two fully connected layers with an h-swish activation in between. The model comes in seven variants, ranging from 26M to 508M FLOPs, which adjust width and depth while preserving the architecture.
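The classification head described above can be sketched as follows. This is an illustration under stated assumptions, not the reference implementation: the class name `ClassifierHead`, the `hidden_dim` parameter, and the choice to use the first token (index 0) as the global summary are all assumptions, and `nn.Hardswish` stands in for h-swish.

```python
import torch
import torch.nn as nn


class ClassifierHead(nn.Module):
    """Sketch: pooled local features + one global token -> two FC layers."""

    def __init__(self, feature_ch, token_dim, hidden_dim, num_classes):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.fc = nn.Sequential(
            nn.Linear(feature_ch + token_dim, hidden_dim),
            nn.Hardswish(),                  # h-swish between the two FC layers
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x, z):
        # x: (B, feature_ch, H, W) local features; z: (B, M, token_dim) tokens
        pooled = self.pool(x).flatten(1)     # (B, feature_ch)
        summary_token = z[:, 0, :]           # (B, token_dim), assumed: first token
        return self.fc(torch.cat([pooled, summary_token], dim=1))


# Hypothetical sizes, chosen only to exercise the shapes.
head = ClassifierHead(feature_ch=32, token_dim=192, hidden_dim=64, num_classes=10)
logits = head(torch.randn(2, 32, 7, 7), torch.randn(2, 6, 192))
print(logits.shape)  # torch.Size([2, 10])
```

The head only needs the last block's feature map and tokens, so it can be attached to the outputs `x_out, z_out` of the final `MobileFormerBlock`.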