MobileFormer: Efficient Hybrid Architecture for Local-Global Feature Fusion

MobileFormer is a hybrid architecture that combines the strengths of convolutional neural networks (CNNs) and Transformers while keeping computational cost low. A lightweight bidirectional bridge between a mobile CNN backbone and a compact Transformer lets local and global information flow in both directions, without the computational burden of a traditional vision Transformer.

Architecture Overview

MobileFormer runs two branches in parallel. One processes the input image with a MobileNet-style CNN to extract spatially local features. The other uses a small set of learnable global tokens (six or fewer) to represent global context. Unlike Vision Transformers (ViT), which divide images into hundreds of patches, MobileFormer's global tokens are randomly initialized learnable parameters whose number stays fixed, which drastically reduces attention computation.
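To make the difference concrete, the short sketch below counts the query-key pairs a single attention layer has to score. The sizes are assumptions chosen for illustration (a ViT-style split of a 224×224 image into 16×16 patches, and a 14×14 feature map for the cross-attention case), not figures from the post.

```python
def self_attention_pairs(num_tokens: int) -> int:
    """Query-key pairs scored by one self-attention layer over num_tokens tokens."""
    return num_tokens * num_tokens


def cross_attention_pairs(num_tokens: int, height: int, width: int) -> int:
    """Query-key pairs when num_tokens global tokens attend to an H x W feature map."""
    return num_tokens * height * width


vit_patches = (224 // 16) ** 2   # 196 patches (assumed 224x224 input, 16x16 patches)
global_tokens = 6                # MobileFormer token count

print(self_attention_pairs(vit_patches))              # 38416
print(self_attention_pairs(global_tokens))            # 36
print(cross_attention_pairs(global_tokens, 14, 14))   # 1176
```

Self-attention cost grows quadratically with the number of tokens, whereas the token-to-feature-map interaction grows only linearly with H×W because the token count stays fixed.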

The two branches are connected via two lightweight cross-attention mechanisms: one propagating local features to global tokens (Mobile → Former), and the other integrating global context back into local features (Former → Mobile). This bidirectional flow ensures that spatial details inform global understanding, and global semantics refine local representations—all with minimal parameter and FLOP overhead.

Lightweight Cross-Attention Mechanism

In standard cross-attention, learned projections produce the queries from one input and the keys and values from the other. MobileFormer drops the projections on the CNN side and keeps only those on the token side:

In Mobile → Former, only the query projection is kept, on the Transformer (token) side; keys and values are taken directly from the CNN feature map without transformation.

In Former → Mobile, only the key and value projections are kept, again on the token side; the CNN features act as queries without a linear mapping.
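A rough weight count shows what dropping the feature-map projections saves. This sketch assumes tokens of width d, a feature map with c channels, attention computed at width c, and biases ignored; d = 192 matches the token dimension given later in the post, while c = 96 is an arbitrary illustrative choice. The function names are hypothetical.

```python
def mobile_to_former_weights(d: int, c: int, lightweight: bool = True) -> int:
    """Weights in Mobile -> Former cross-attention (biases ignored)."""
    weights = d * c + c * d          # W^Q (tokens -> c) and W^O (c -> tokens)
    if not lightweight:
        weights += 2 * c * c         # W^K and W^V on the feature map
    return weights


def former_to_mobile_weights(d: int, c: int, lightweight: bool = True) -> int:
    """Weights in Former -> Mobile cross-attention (biases ignored)."""
    weights = 2 * d * c              # W^K and W^V (tokens -> c)
    if not lightweight:
        weights += 2 * c * c         # W^Q and W^O on the feature map
    return weights


d, c = 192, 96
for name, fn in [("Mobile->Former", mobile_to_former_weights),
                 ("Former->Mobile", former_to_mobile_weights)]:
    print(name, fn(d, c, lightweight=False), "->", fn(d, c))
# Mobile->Former 55296 -> 36864
# Former->Mobile 55296 -> 36864
```

In both directions the two matrices that disappear are the c × c maps that would otherwise be applied to the feature map.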

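Removing the projections that would act on every one of the HW feature-map positions avoids their cost (on the order of HW·d² multiply-adds each) while information still flows in both directions. The attention computation for Mobile → Former is formulated as: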

$$ A_{X \to Z} = \text{Concat}\left[\text{Attn}\left(\tilde{z}_i W_i^Q, \tilde{x}_i, \tilde{x}_i\right)\right]_{i=1}^{h} W^O $$

where $ \tilde{x}_i \in \mathbb{R}^{HW \times d/h} $ is the $i$-th channel split (head) of the flattened CNN feature map $ X $, and $ \tilde{z}_i \in \mathbb{R}^{M \times d/h} $ is the $i$-th split of the global tokens $ Z $. Only the tokens are projected: the query matrix $ W_i^Q $ and the output projection $ W^O $ belong to the Transformer side, while keys and values are the unprojected local features.
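The equation can be checked with a small NumPy sketch. It assumes the feature-map channel count equals the token width d (as the split dimensions above imply), uses the usual 1/√(d/h) scaling inside Attn, and fills the weights with random illustrative values; the function and variable names are hypothetical.

```python
import numpy as np


def softmax(a: np.ndarray, axis: int = -1) -> np.ndarray:
    a = a - a.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def mobile_to_former(x, z, w_q, w_o):
    """A_{X->Z}. x: (HW, d) flattened feature map, z: (M, d) global tokens,
    w_q: (h, d/h, d/h) per-head query projections, w_o: (d, d) output projection."""
    h, dh, _ = w_q.shape
    heads = []
    for i in range(h):
        x_i = x[:, i * dh:(i + 1) * dh]            # keys and values: unprojected features
        q_i = z[:, i * dh:(i + 1) * dh] @ w_q[i]   # only the tokens are projected
        attn = softmax(q_i @ x_i.T / np.sqrt(dh))  # (M, HW): each token attends over all positions
        heads.append(attn @ x_i)                   # (M, d/h)
    return np.concatenate(heads, axis=-1) @ w_o    # (M, d)


rng = np.random.default_rng(0)
HW, M, d, h = 16, 6, 32, 4
x, z = rng.standard_normal((HW, d)), rng.standard_normal((M, d))
w_q = rng.standard_normal((h, d // h, d // h))
w_o = rng.standard_normal((d, d))
print(mobile_to_former(x, z, w_q, w_o).shape)  # (6, 32)
```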

Conversely, for Former → Mobile:

$$ A_{Z \to X} = \text{Concat}\left[\text{Attn}\left(\tilde{x}_i, \tilde{z}_i W_i^K, \tilde{z}_i W_i^V\right)\right]_{i=1}^{h} $$

Here, $ W_i^K $ and $ W_i^V $ are learned projections applied to the global tokens, while the local features serve directly as unprojected queries. No output projection is used in this direction.
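The reverse direction can be sketched the same way. As before, this NumPy version assumes the feature map and the tokens share width d, uses 1/√(d/h) scaling, and uses random illustrative weights; the names are hypothetical.

```python
import numpy as np


def softmax(a: np.ndarray, axis: int = -1) -> np.ndarray:
    a = a - a.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def former_to_mobile(x, z, w_k, w_v):
    """A_{Z->X}. x: (HW, d) flattened feature map used directly as queries,
    z: (M, d) tokens, w_k / w_v: (h, d/h, d/h) per-head key and value projections."""
    h, dh, _ = w_k.shape
    heads = []
    for i in range(h):
        x_i = x[:, i * dh:(i + 1) * dh]            # queries: unprojected features
        z_i = z[:, i * dh:(i + 1) * dh]
        k_i, v_i = z_i @ w_k[i], z_i @ w_v[i]      # only the tokens are projected
        attn = softmax(x_i @ k_i.T / np.sqrt(dh))  # (HW, M): each position attends over the tokens
        heads.append(attn @ v_i)                   # (HW, d/h)
    return np.concatenate(heads, axis=-1)          # (HW, d); no output projection


rng = np.random.default_rng(0)
HW, M, d, h = 16, 6, 32, 4
x, z = rng.standard_normal((HW, d)), rng.standard_normal((M, d))
w_k = rng.standard_normal((h, d // h, d // h))
w_v = rng.standard_normal((h, d // h, d // h))
print(former_to_mobile(x, z, w_k, w_v).shape)  # (16, 32)
```

With a single token the softmax is trivially 1, so every position would receive that token's projected values; the attention only becomes selective once there are several tokens to choose from.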

Mobile Sub-Block with Dynamic Activation

The CNN branch uses inverted bottleneck blocks built from depthwise separable convolutions. To adapt its activations to global context, it uses a dynamic ReLU: instead of fixed parameters, the coefficients of the piecewise-linear activation (slopes and intercepts) are generated by a small two-layer MLP conditioned on the first global token $ z_1 $:

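import torch
import torch.nn as nn


class DynamicReLU(nn.Module):
    # Minimal dynamic ReLU sketch (assumed DY-ReLU-style parameterization; the post does not define it):
    # y_c = max_k(a_c^k * x_c + b_c^k), where the k slopes and k intercepts per channel
    # are predicted from a global token by a small two-layer MLP.
    def __init__(self, channels, token_dim, k=2, reduction=4):
        super().__init__()
        self.k = k
        self.mlp = nn.Sequential(
            nn.Linear(token_dim, token_dim // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(token_dim // reduction, 2 * k * channels),
        )
        # Coefficients are residuals around a plain ReLU (slopes 1, 0, ...; intercepts 0)
        init_a = torch.zeros(k)
        init_a[0] = 1.0
        self.register_buffer("init_a", init_a)
        self.lambda_a, self.lambda_b = 1.0, 0.5   # assumed residual ranges

    def forward(self, x, token):                  # x: (B, C, H, W), token: (B, token_dim)
        b, c = x.shape[:2]
        theta = (2 * torch.sigmoid(self.mlp(token)) - 1).view(b, c, 2 * self.k)   # in [-1, 1]
        a = self.init_a + self.lambda_a * theta[..., :self.k]       # (B, C, k) slopes
        bias = self.lambda_b * theta[..., self.k:]                  # (B, C, k) intercepts
        out = x.unsqueeze(-1) * a[:, :, None, None, :] + bias[:, :, None, None, :]
        return out.max(dim=-1).values                               # max over the k linear pieces


class Mobile(nn.Module):
    # Inverted bottleneck: 1x1 expand -> depthwise kxk -> 1x1 project, with dynamic ReLUs
    # conditioned on the first global token. For stride 2, an extra depthwise 3x3 stride-2
    # conv is applied first (simplified here so that channel counts stay consistent).
    def __init__(self, in_ch, exp_ch, out_ch, token_dim, stride=1, kernel_size=3, k=2):
        super().__init__()
        if stride == 2:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1, groups=in_ch, bias=False),
                nn.BatchNorm2d(in_ch),
                nn.ReLU6(inplace=True),
            )
        else:
            self.downsample = nn.Identity()

        self.conv1 = nn.Conv2d(in_ch, exp_ch, 1, bias=False)           # pointwise expansion
        self.bn1 = nn.BatchNorm2d(exp_ch)
        self.act1 = DynamicReLU(exp_ch, token_dim, k=k)

        self.conv2 = nn.Conv2d(exp_ch, exp_ch, kernel_size, padding=kernel_size // 2,
                               groups=exp_ch, bias=False)                # depthwise
        self.bn2 = nn.BatchNorm2d(exp_ch)
        self.act2 = DynamicReLU(exp_ch, token_dim, k=k)

        self.conv3 = nn.Conv2d(exp_ch, out_ch, 1, bias=False)           # pointwise projection
        self.bn3 = nn.BatchNorm2d(out_ch)

    def forward(self, x, global_tokens):          # global_tokens: (B, M, token_dim)
        z1 = global_tokens[:, 0, :]               # first token modulates both activations
        x = self.downsample(x)
        x = self.act1(self.bn1(self.conv1(x)), z1)
        x = self.act2(self.bn2(self.conv2(x)), z1)
        return self.bn3(self.conv3(x))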
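To see what the generated coefficients control, the NumPy sketch below evaluates the piecewise-linear form that dynamic ReLU is usually defined by, y = max_k(a_k·x + b_k), with k = 2 for a single channel. The coefficient values are hand-picked for illustration; in the block they would be predicted from z_1. The function name is hypothetical.

```python
import numpy as np


def dynamic_relu(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Piecewise-linear activation y = max_k(a[k] * x + b[k]) for one channel."""
    return np.max(a[:, None] * x[None, :] + b[:, None], axis=0)


x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# Slopes (1, 0) with zero intercepts reproduce a standard ReLU
print(dynamic_relu(x, np.array([1.0, 0.0]), np.zeros(2)).tolist())  # [0.0, 0.0, 0.0, 1.0, 2.0]
# A small slope on the second piece gives a leaky-ReLU-like shape
print(dynamic_relu(x, np.array([1.0, 0.1]), np.zeros(2)).tolist())  # [-0.2, -0.1, 0.0, 1.0, 2.0]
```

Because the slopes and intercepts are predicted per channel and per image, the same block can apply a differently shaped activation depending on the global context.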

Former Sub-Block with Reduced Expansion

The Transformer branch uses a standard multi-head attention (MHA) + feed-forward network (FFN) structure, but reduces the FFN expansion ratio from 4 to 2 to cut parameters. Layer normalization is applied after the attention and FFN residual additions (post-norm). Because the number of tokens is small (M ≤ 6), the M × M attention matrix is tiny and its cost is negligible compared with the CNN branch.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Former(nn.Module):
    # Post-norm Transformer layer over the M global tokens: MHA + FFN (expansion ratio 2).
    def __init__(self, heads, dim, expand_ratio=2):
        super().__init__()
        assert dim % heads == 0, "token dim must be divisible by the number of heads"
        self.heads = heads
        self.head_dim = dim // heads

        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, expand_ratio * dim),
            nn.GELU(),
            nn.Linear(expand_ratio * dim, dim),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, tokens):                    # tokens: (B, M, dim)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=self.heads)
                   for t in self.qkv_proj(tokens).chunk(3, dim=-1))

        # Scaled dot-product attention over an M x M matrix (tiny for M <= 6)
        attn = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.head_dim ** -0.5
        attn = F.softmax(attn, dim=-1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.norm1(self.out_proj(out) + tokens)   # residual, then LayerNorm (post-norm)

        return self.norm2(self.mlp(out) + out)          # FFN residual, then LayerNorm
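A quick parameter count shows what the smaller expansion ratio saves in each Former layer. The sketch assumes the FFN is exactly the two Linear layers in the code above (biases included) and uses the post's token dimension of 192; the function name is hypothetical.

```python
def ffn_params(dim: int, expand_ratio: int) -> int:
    """Weights plus biases of Linear(dim, r*dim) followed by Linear(r*dim, dim)."""
    hidden = expand_ratio * dim
    return (dim * hidden + hidden) + (hidden * dim + dim)


dim = 192  # token dimension used in the post
print(ffn_params(dim, 4))  # 295872
print(ffn_params(dim, 2))  # 148032
```

Halving the ratio roughly halves the FFN parameters. Since the FFN only runs on M ≤ 6 tokens, its compute is small either way, so the saving shows up mainly in parameter count.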

Bidirectional Cross-Attention Layers

The cross-attention layers are designed for minimal overhead. The implementations below use a single attention head (h = 1) for brevity:

Mobile → Former (local to global):

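import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class MobileToFormer(nn.Module):
    # Single-head (h = 1) version of A_{X->Z}: only the tokens get a query projection (W^Q)
    # and an output projection (W^O); the feature map is used directly as keys and values.
    def __init__(self, token_dim, feature_dim):
        super().__init__()
        self.query_proj = nn.Linear(token_dim, feature_dim)    # W^Q
        self.output_proj = nn.Linear(feature_dim, token_dim)   # W^O

    def forward(self, local_features, global_tokens):
        c = local_features.shape[1]
        local_flat = rearrange(local_features, 'b c h w -> b (h w) c')     # (B, HW, C)

        q = self.query_proj(global_tokens)                                  # (B, M, C)
        scores = torch.einsum('b n c, b m c -> b n m', q, local_flat) * c ** -0.5
        attn = F.softmax(scores, dim=-1)                                    # over HW positions

        fused = torch.einsum('b n m, b m c -> b n c', attn, local_flat)     # (B, M, C)
        return self.output_proj(fused) + global_tokens                      # residual on tokens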

Former → Mobile (global to local):

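import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class FormerToMobile(nn.Module):
    # Single-head (h = 1) version of A_{Z->X}: the tokens get key/value projections (W^K, W^V);
    # the feature map is used directly as queries, and there is no output projection.
    def __init__(self, token_dim, feature_dim):
        super().__init__()
        self.kv_proj = nn.Linear(token_dim, 2 * feature_dim)   # W^K and W^V, fused

    def forward(self, local_features, global_tokens):
        b, c, h, w = local_features.shape
        local_flat = rearrange(local_features, 'b c h w -> b (h w) c')     # (B, HW, C)

        k, v = self.kv_proj(global_tokens).chunk(2, dim=-1)                # each (B, M, C)

        # Scaled like Mobile -> Former so the softmax is not saturated for large C
        scores = torch.einsum('b n c, b m c -> b n m', local_flat, k) * c ** -0.5
        attn = F.softmax(scores, dim=-1)                                    # over the M tokens

        aggregated = torch.einsum('b n m, b m c -> b n c', attn, v)        # (B, HW, C)
        return rearrange(aggregated, 'b (h w) c -> b c h w', h=h, w=w) + local_features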
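A rough multiply-add count shows why these bridges are cheap. The sketch counts only the two attention matrix products (scores and aggregation) of one single-head bridge and compares them with a single C → C 1×1 convolution on the same feature map. The 14×14×96 feature size is an illustrative assumption, projections, softmax, and biases are ignored, and the function names are hypothetical.

```python
def bridge_macs(num_tokens: int, height: int, width: int, channels: int) -> int:
    """Multiply-adds for the scores (M x HW x C) plus the aggregation (M x HW x C)."""
    return 2 * num_tokens * height * width * channels


def pointwise_conv_macs(height: int, width: int, channels: int) -> int:
    """Multiply-adds for a C -> C 1x1 convolution on an H x W map."""
    return height * width * channels * channels


bridge = bridge_macs(6, 14, 14, 96)
conv = pointwise_conv_macs(14, 14, 96)
print(bridge, conv, bridge / conv)   # 225792 1806336 0.125
```

The count is the same in both directions, and the ratio reduces to 2M/C, so with six tokens a bridge costs less than one pointwise convolution whenever C is larger than 12.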

Full Block and Network Design

Each MobileFormer block combines the four components, which run in the order Mobile → Former, Former sub-block, Mobile sub-block, Former → Mobile. The block takes a local feature map and a set of global tokens as input and returns an updated feature map and refined tokens.

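import torch.nn as nn


class MobileFormerBlock(nn.Module):
    # One block: Mobile->Former bridge, Former, Mobile, Former->Mobile bridge.
    # Relies on the Mobile, Former, MobileToFormer and FormerToMobile classes above.
    def __init__(self, in_ch, exp_ch, out_ch, token_dim, stride=1, heads=8, expand_ratio=2):
        super().__init__()
        self.mobile = Mobile(in_ch, exp_ch, out_ch, token_dim, stride=stride)
        self.former = Former(heads, token_dim, expand_ratio)
        self.m2f = MobileToFormer(token_dim, in_ch)    # reads the block's input map (in_ch)
        self.f2m = FormerToMobile(token_dim, out_ch)   # writes into the Mobile output (out_ch)

    def forward(self, x, z):
        z_updated = self.m2f(x, z)          # 1. local -> global
        z_out = self.former(z_updated)      # 2. token self-attention + FFN
        x_hidden = self.mobile(x, z_out)    # 3. CNN sub-block, activations driven by z_out[:, 0]
        x_out = self.f2m(x_hidden, z_out)   # 4. global -> local
        return x_out, z_out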

The full network begins with a 3×3 convolutional stem, followed by 11 stacked MobileFormer blocks. All blocks use 6 global tokens of dimension 192. Classification is performed by concatenating the final global tokens with the globally averaged local features, followed by two fully connected layers with h-swish activation. The model scales across seven variants, ranging from 26M to 508M FLOPs, which adjust width and depth while keeping the same block design.
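The classification head can be sketched in NumPy as follows. Several details are assumptions rather than statements from the post: all M final tokens are flattened before concatenation, h-swish is placed between the two fully connected layers, the hidden width and class count are arbitrary, and the weights are random. h-swish is taken as x·ReLU6(x + 3)/6, its usual MobileNetV3 definition. The function names are hypothetical.

```python
import numpy as np


def h_swish(x: np.ndarray) -> np.ndarray:
    """Hard swish: x * ReLU6(x + 3) / 6."""
    return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0


def classify(features, tokens, w1, b1, w2, b2):
    """features: (C, H, W) final local features, tokens: (M, d) final global tokens."""
    pooled = features.mean(axis=(1, 2))                     # global average pooling -> (C,)
    joint = np.concatenate([pooled, tokens.reshape(-1)])    # (C + M*d,)
    hidden = h_swish(joint @ w1 + b1)                       # first FC layer + h-swish
    return hidden @ w2 + b2                                 # second FC layer -> class logits


rng = np.random.default_rng(0)
C, H, W, M, d = 64, 7, 7, 6, 192        # hypothetical feature sizes (d from the post)
hidden_dim, num_classes = 128, 10       # hypothetical head sizes
features = rng.standard_normal((C, H, W))
tokens = rng.standard_normal((M, d))
w1 = rng.standard_normal((C + M * d, hidden_dim)) * 0.01
b1 = np.zeros(hidden_dim)
w2 = rng.standard_normal((hidden_dim, num_classes)) * 0.01
b2 = np.zeros(num_classes)
print(classify(features, tokens, w1, b1, w2, b2).shape)  # (10,)
```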

Tags: MobileFormer Transformer CNN hybrid-network lightweight-attention

Posted on Wed, 20 May 2026 20:19:31 +0000 by n00854180t