Neural network architecture design encodes domain prior knowledge into model structures. For example, residual connections, which model a feature transformation as $y = f(x) + x$, deliver better performance than plain $y = f(x)$ mappings; ResNet implements them via shortcut paths. While architectural design has continuously evolved to integrate the latest empirical findings, training pipelines almost universally rely on generic, model-agnostic optimizers such as SGD and AdamW that incorporate no model-specific prior knowledge.
The gradient reparameterization technique modifies parameter gradients using model-specific hyperparameters before weight update steps, integrating structural prior knowledge directly into the optimization process without adding extra trainable parameters or runtime statistics during training. This approach powers a new category of reparameterized optimizers called RepOptimizers, which enable plain model architectures to match or exceed the performance of heavily engineered complex structures when trained with appropriate prior encoding.
Compared to structural reparameterization approaches such as RepVGG, which require auxiliary multi-branch structures during training and a post-training conversion for inference, RepOptimizers add no extra forward/backward computation or activation memory during training. The only overhead is one stored Grad Mult tensor per reparameterized weight and an elementwise gradient multiplication at each step. This efficiency makes RepOptimizer training suitable for resource-constrained environments and rapid iteration workflows. Because RepOptimizer-trained plain models need no post-training structural fusion, they also avoid most of the quantization degradation that fusion causes, retaining far more accuracy under INT8 post-training quantization (PTQ) pipelines.
RepOptimizer design starts with defining the structural prior to encode. For the RepOpt-VGG plain architecture, the target prior is that multi-branch linear addition with constant per-branch scaling factors improves model performance, which is the same prior that underpins ResNet and RepVGG designs.
This prior is formalized as Constant-Scale Linear Addition (CSLA) blocks, in which each branch contains exactly one trainable linear operator and an optional constant scaling factor, with no batch normalization, dropout, or other non-linear or data-dependent operations inside the branches. Training a CSLA block with standard SGD is mathematically equivalent to training a single plain operator whose gradient is scaled by a constant factor derived from the CSLA branch scaling coefficients, provided the plain operator is initialized equivalently. This is a concrete instance of gradient reparameterization. The scaling factor applied to the gradients is called Grad Mult.
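The equivalence is proven by induction over training iterations. Consider a two-branch CSLA block that maps an input $X$ to $\alpha_A X W^{(A)} + \alpha_B X W^{(B)}$, and a plain operator that maps it to $X W'$:
- Initialization condition: The plain weight is initialized as the weighted sum of the CSLA branch weights, $W'^{(0)} = \alpha_A W^{(A)(0)} + \alpha_B W^{(B)(0)}$, where $\alpha_A, \alpha_B$ are the constant scaling factors of the two CSLA branches and $W^{(A)}, W^{(B)}$ are the weights of their linear operators. Since $X W'^{(0)} = \alpha_A X W^{(A)(0)} + \alpha_B X W^{(B)(0)}$, the initial outputs of the CSLA block and the plain operator are identical for any input.
- Iteration update: Assume $W'^{(i)} = \alpha_A W^{(A)(i)} + \alpha_B W^{(B)(i)}$ holds at iteration $i$, so both models produce the same output and the same loss $L$. By the chain rule, $\frac{\partial L}{\partial W^{(A)}} = \alpha_A \frac{\partial L}{\partial W'}$ and $\frac{\partial L}{\partial W^{(B)}} = \alpha_B \frac{\partial L}{\partial W'}$. With learning rate $\lambda$, SGD updates each branch as $W^{(A)(i+1)} = W^{(A)(i)} - \lambda \alpha_A \frac{\partial L}{\partial W'}$ (and likewise for $B$), hence $\alpha_A W^{(A)(i+1)} + \alpha_B W^{(B)(i+1)} = W'^{(i)} - \lambda (\alpha_A^2 + \alpha_B^2) \frac{\partial L}{\partial W'}$. Updating $W'$ with its gradient scaled by $\alpha_A^2 + \alpha_B^2$ therefore preserves the equivalence at iteration $i+1$.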
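The induction can be checked numerically. This sketch assumes a toy setting that is not taken from the text: two fully connected branches with scalar constant scales $\alpha_A = 0.7$ and $\alpha_B = 1.3$, a squared-error loss on random data, and vanilla SGD without momentum or weight decay. If the argument holds, the plain weight, updated with its gradient multiplied by $\alpha_A^2 + \alpha_B^2$, stays equal to the weighted sum of the branch weights after every step.

```python
import torch

torch.manual_seed(0)
alpha_a, alpha_b, lr = 0.7, 1.3, 0.05  # assumed toy branch scales and learning rate

# CSLA form: two trainable linear operators, each with a constant scale
w_a = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
w_b = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
# Plain form: one operator initialized as the weighted sum of the branch weights
w_plain = (alpha_a * w_a + alpha_b * w_b).detach().requires_grad_(True)
grad_mult = alpha_a ** 2 + alpha_b ** 2

for _ in range(5):
    x = torch.randn(8, 4, dtype=torch.float64)
    target = torch.randn(8, 3, dtype=torch.float64)
    loss_csla = ((alpha_a * (x @ w_a) + alpha_b * (x @ w_b) - target) ** 2).mean()
    loss_plain = ((x @ w_plain - target) ** 2).mean()
    loss_csla.backward()
    loss_plain.backward()
    with torch.no_grad():
        w_a -= lr * w_a.grad  # vanilla SGD on each branch
        w_b -= lr * w_b.grad
        w_plain -= lr * grad_mult * w_plain.grad  # SGD with the Grad Mult applied
    for w in (w_a, w_b, w_plain):
        w.grad = None

with torch.no_grad():
    equivalent = torch.allclose(w_plain, alpha_a * w_a + alpha_b * w_b)
print(equivalent)  # True
```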
For CSLA blocks with mixed kernel sizes (a 3x3 convolution, a 1x1 convolution, and an identity path, as used in RepVGG-style blocks), Grad Mult is a 4D tensor with the same shape as the plain 3x3 convolution kernel, $(C_{out}, C_{in}, 3, 3)$. Given channel-wise scaling factors $s$ (3x3 branch) and $t$ (1x1 branch), the Grad Mult tensor for SGD is constructed by:
- Setting every position in the kernels of output channel $c$ to $s_c^2$
- Adding $t_c^2$ to the center position of each kernel, where the 1x1 branch output is aligned
- For blocks with an identity path, adding 1.0 to the center position of the diagonal kernels (input channel equal to output channel); the identity path behaves as a trainable diagonal 1x1 operator with a constant scale of 1, so it contributes $1^2 = 1$
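As a worked example of these three rules, the hypothetical helper build_grad_mult (a standalone illustration, not part of the implementation described later) builds the SGD mask for a kernel with two input and two output channels and an identity path, assuming $s = (0.5, 2)$ and $t = (0.1, 1)$.

```python
import torch


def build_grad_mult(s, t, out_ch, in_ch, has_identity):
    """Hypothetical helper: SGD Grad Mult for a (out_ch, in_ch, 3, 3) kernel.

    s, t: per-output-channel constant scales of the 3x3 and 1x1 branches.
    """
    mask = torch.ones(out_ch, in_ch, 3, 3) * (s ** 2).view(-1, 1, 1, 1)  # all positions: s_c^2
    mask[:, :, 1, 1] += (t ** 2).view(-1, 1)  # kernel center: + t_c^2
    if has_identity:
        assert out_ch == in_ch, "an identity path needs matching channel counts"
        idx = torch.arange(out_ch)
        mask[idx, idx, 1, 1] += 1.0  # diagonal kernels, center: + 1
    return mask


s = torch.tensor([0.5, 2.0])
t = torch.tensor([0.1, 1.0])
m = build_grad_mult(s, t, out_ch=2, in_ch=2, has_identity=True)
print(m[1, 1])  # 4.0 everywhere except 6.0 at the center
```

For output channel index 1, every non-center entry is $s_1^2 = 4$, the off-diagonal center entry is $s_1^2 + t_1^2 = 5$, and the diagonal center entry gains the identity term to reach 6.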
CSLA blocks are only conceptual constructs for defining RepOptimizer behavior, and are never actually instantiated during target model training.
Hyper-Search for Scaling Factors
The constant branch scaling factors are obtained via a Hyper-Search (HS) procedure:
- Construct an auxiliary HS model by replacing the constant scaling factors in the conceptual CSLA block with trainable parameters
- Train the auxiliary model on a small proxy dataset (e.g., CIFAR-100)
- Extract the final values of the trainable scaling factors as the constant scaling factors for Grad Mult calculation
The HS procedure is model-specific but dataset-agnostic: scaling factors searched on small proxy datasets transfer directly to large target datasets such as ImageNet without performance degradation.
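A minimal sketch of the search mechanism, under heavy simplifying assumptions: the hypothetical ToyHSBlock uses two fully connected branches instead of convolutions, random tensors stand in for the proxy dataset, and the schedule is arbitrary. The point it illustrates is that the branch scales are ordinary trainable parameters during the search and are detached afterwards so they can act as constants for Grad Mult.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyHSBlock(nn.Module):
    """Hypothetical HS block: two linear branches with trainable per-feature scales."""

    def __init__(self, dim, scale_init=1.0):
        super().__init__()
        self.branch_a = nn.Linear(dim, dim, bias=False)
        self.branch_b = nn.Linear(dim, dim, bias=False)
        # Trainable during the search; frozen to constants afterwards
        self.scale_a = nn.Parameter(torch.full((dim,), scale_init))
        self.scale_b = nn.Parameter(torch.full((dim,), scale_init))

    def forward(self, x):
        return self.scale_a * self.branch_a(x) + self.scale_b * self.branch_b(x)


torch.manual_seed(0)
dim = 8
model = nn.Sequential(ToyHSBlock(dim), nn.ReLU(), nn.Linear(dim, 3))
opt = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

# Random tensors stand in for the small proxy dataset (assumption)
x = torch.randn(256, dim)
y = torch.randint(0, 3, (256,))
for _ in range(20):
    opt.zero_grad()
    F.cross_entropy(model(x), y).backward()
    opt.step()

# Extract the searched scales as constants and form the SGD Grad Mult for a plain
# nn.Linear weight of shape (out_features, in_features): one value per output row
hs_block = model[0]
searched = (hs_block.scale_a.detach().clone(), hs_block.scale_b.detach().clone())
grad_mult = (searched[0] ** 2 + searched[1] ** 2).view(-1, 1)
```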
RepOptimizer Training Workflow
- Extract scaling factors from the trained HS auxiliary model
- Construct the Grad Mult tensor for each parameter of the target plain model
- Initialize the target model weights as the weighted sum of initial CSLA branch weights aligned with the Grad Mult scaling factors
- During each training iteration, multiply the parameter gradients by the corresponding Grad Mult tensor before executing the optimizer update step
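The final workflow step can also be realized without a custom optimizer. The sketch below is an alternative technique, not the RepOptSGD implementation in this section: it relies on PyTorch's Tensor.register_hook, whose hook may return a replacement gradient, to multiply a kernel's gradient by its Grad Mult before torch.optim.SGD sees it. The Grad Mult here is a random placeholder, and the result is compared with a hand-computed SGD step that uses the unscaled gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False)
# Placeholder Grad Mult with the kernel's shape (random, for illustration only)
grad_mult = torch.rand(conv.weight.shape) + 0.5

# The hook's return value replaces the gradient before it reaches conv.weight.grad
conv.weight.register_hook(lambda g: g * grad_mult)

lr = 0.1
opt = torch.optim.SGD(conv.parameters(), lr=lr)
x = torch.randn(2, 4, 8, 8)
w_before = conv.weight.detach().clone()

conv(x).pow(2).mean().backward()
opt.step()

# Reference: unscaled gradient at the original weights, computed without the hook
w_ref = w_before.clone().requires_grad_(True)
(raw_grad,) = torch.autograd.grad(F.conv2d(x, w_ref, padding=1).pow(2).mean(), w_ref)
expected = w_before - lr * grad_mult * raw_grad
matches = torch.allclose(conv.weight.detach(), expected, atol=1e-6)
print(matches)  # True
```

Since torch.optim.SGD adds weight decay and momentum to the already-scaled gradient, this ordering matches a custom step that multiplies first. The scaled value is what lands in .grad, so anything else that reads .grad (gradient clipping, logging) also sees the scaled gradient.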
Experimental Findings
Scaling factors in the HS model are initialized as $\sqrt{\frac{2}{l}}$, where $l$ is the block depth, to encourage deeper layers to behave as identity mappings during early training and stabilize convergence.
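Evaluated at a few depths, the initialization formula gives the values below. This sketch assumes $l$ counts blocks from 1; the hypothetical helper hs_scale_init would supply the scale_init_val argument when constructing the HS blocks.

```python
import math


def hs_scale_init(depth):
    """Hypothetical helper: initial HS branch scale sqrt(2 / l), with l >= 1 assumed."""
    return math.sqrt(2.0 / depth)


inits = [round(hs_scale_init(l), 4) for l in (1, 2, 8, 32)]
print(inits)  # [1.4142, 1.0, 0.5, 0.25]
```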
RepOpt-VGG, a plain VGG-style architecture trained with RepOptimizer, delivers equivalent top-1 accuracy to RepVGG on ImageNet classification, with 1.8x higher training throughput and 30% lower memory footprint due to the elimination of multi-branch training structures. The reduced memory footprint allows larger batch sizes, further improving both training speed and final model accuracy.
INT8 PTQ of RepVGG leads to over 20% accuracy drop on ImageNet due to skewed parameter distributions introduced by post-training branch fusion, while RepOpt-VGG suffers only 2.5% accuracy drop under the same quantization pipeline, as no structural conversion is required for inference.
Scaling factors searched on CIFAR-100 deliver identical RepOpt-VGG performance on ImageNet as factors searched directly on ImageNet, confirming the dataset-agnostic property of RepOptimizer hyperparameters. Lowering the downsampling rate of the HS model to improve CIFAR-100 validation accuracy leads to degraded RepOpt-VGG performance on ImageNet, confirming the model-specific property of the searched scaling factors.
Implementation Details
The implementation supports three operational modes: hyper-search (HS), CSLA validation, and target model training.
CSLA Convolution Block
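Used for HS and CSLA validation, this block implements the multi-branch linear addition structure. The 3x3 and 1x1 branch scaling factors are trainable in HS mode and frozen to constant values in CSLA mode (freeze_scales=True). The identity branch is only included when the input and output channel counts match and the stride is 1; its per-channel scale stays trainable in both modes because it acts as the identity branch's trainable diagonal operator. BatchNorm, ReLU, and the optional SE block are applied after the branch sum, outside the branches.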
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD


class ScaleLayer(nn.Module):
    """Channel-wise scaling (with optional bias) applied to NCHW feature maps."""

    def __init__(self, num_features, scale_init=1.0, use_bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_features) * scale_init)
        self.bias = nn.Parameter(torch.zeros(num_features)) if use_bias else None

    def forward(self, x):
        out = x * self.weight.view(1, -1, 1, 1)
        if self.bias is not None:
            out = out + self.bias.view(1, -1, 1, 1)
        return out


class SEBlock(nn.Module):
    """Squeeze-and-excitation channel attention."""

    def __init__(self, in_channels, reduction=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class CSLAConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, use_se=False, freeze_scales=False, scale_init_val=1.0):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.stride = stride
        self.relu = nn.ReLU(inplace=True)
        # 3x3 convolution branch with a channel-wise scale
        self.conv3x3 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.scale3x3 = ScaleLayer(out_ch, scale_init=scale_init_val, use_bias=False)
        # 1x1 convolution branch with a channel-wise scale
        self.conv1x1 = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0, bias=False)
        self.scale1x1 = ScaleLayer(out_ch, scale_init=scale_init_val, use_bias=False)
        # Identity branch (only when shapes allow); its scale acts as a trainable diagonal weight
        if in_ch == out_ch and stride == 1:
            self.scale_identity = ScaleLayer(out_ch, scale_init=1.0, use_bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.se = SEBlock(out_ch) if use_se else nn.Identity()
        # CSLA mode: the 3x3 and 1x1 branch scales become constants
        if freeze_scales:
            self.scale3x3.weight.requires_grad_(False)
            self.scale1x1.weight.requires_grad_(False)

    def forward(self, x):
        out = self.scale3x3(self.conv3x3(x)) + self.scale1x1(self.conv1x1(x))
        if hasattr(self, 'scale_identity'):
            out = out + self.scale_identity(x)
        # Non-linear ops are applied after the branch sum, never inside a branch
        return self.se(self.relu(self.bn(out)))
```
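Because every branch is linear and BatchNorm comes after the sum, the three scaled branches collapse into a single 3x3 kernel, which is what the handler's weight initialization constructs. The sketch below checks this forward equivalence with plain tensors rather than the classes above, assuming stride 1, random kernels, random per-channel scales, and float64 to keep rounding error negligible.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c = 4
dt = torch.float64
x = torch.randn(2, c, 8, 8, dtype=dt)
k3 = torch.randn(c, c, 3, 3, dtype=dt)  # 3x3 branch kernel
k1 = torch.randn(c, c, 1, 1, dtype=dt)  # 1x1 branch kernel
s3, s1, s_id = (torch.rand(c, dtype=dt) for _ in range(3))  # assumed per-channel scales

# Multi-branch linear part of the CSLA block (everything before BN)
multi = (F.conv2d(x, k3, padding=1) * s3.view(1, -1, 1, 1)
         + F.conv2d(x, k1) * s1.view(1, -1, 1, 1)
         + x * s_id.view(1, -1, 1, 1))

# One fused 3x3 kernel: scaled 3x3 kernel, scaled 1x1 kernel padded to the center,
# and the scaled identity placed on the diagonal of the center
eye = torch.eye(c, dtype=dt).view(c, c, 1, 1)
fused = (k3 * s3.view(-1, 1, 1, 1)
         + F.pad(k1 * s1.view(-1, 1, 1, 1), [1, 1, 1, 1])
         + F.pad(eye * s_id.view(-1, 1, 1, 1), [1, 1, 1, 1]))
single = F.conv2d(x, fused, padding=1)
print(torch.allclose(multi, single))  # True
```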
Target Plain Convolution Block
Used for target model training, this block implements the plain 3x3 convolution structure with no auxiliary branches.
```python
class PlainConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, use_se=False):
        super().__init__()
        self.conv3x3 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.se = SEBlock(out_ch) if use_se else nn.Identity()

    def forward(self, x):
        # Same post-sum ops as CSLAConvBlock, applied to a single 3x3 conv
        return self.se(self.relu(self.bn(self.conv3x3(x))))
```
Scaling Factor Extraction
Extract trained scaling factors from the HS auxiliary model checkpoint to construct Grad Mult tensors.
```python
def fetch_trained_scaling_factors(hs_model):
    """Return one tuple per CSLA block, in module order:
    (identity, 1x1, 3x3) scales if the block has an identity branch, else (1x1, 3x3)."""
    block_list = [m for m in hs_model.modules() if isinstance(m, CSLAConvBlock)]
    scale_list = []
    for block in block_list:
        s3x3 = block.scale3x3.weight.detach().cpu()
        s1x1 = block.scale1x1.weight.detach().cpu()
        if hasattr(block, 'scale_identity'):
            sid = block.scale_identity.weight.detach().cpu()
            scale_list.append((sid, s1x1, s3x3))
        else:
            scale_list.append((s1x1, s3x3))
    return scale_list


def load_scales_from_checkpoint(hs_model_cls, ckpt_path, num_blocks, width_mult):
    hs_model = hs_model_cls(num_blocks=num_blocks, width_multiplier=width_mult, num_classes=100, mode='hs')
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt.get('model', ckpt.get('state_dict', ckpt))
    # Drop classification head weights; only the block scales are needed
    for k in list(state_dict.keys()):
        if k.startswith('fc.'):
            del state_dict[k]
    hs_model.load_state_dict(state_dict, strict=False)
    return fetch_trained_scaling_factors(hs_model)
```
RepOptimizer Handler and Custom SGD
The handler implements weight initialization and Grad Mult tensor generation, while the custom SGD applies Grad Mult scaling during the update step.
```python
class RepOptVGGHandler:
    def __init__(self, target_model, scale_list, update_rule='sgd', reinit_weights=True):
        # Plain blocks must appear in the same order as the CSLA blocks of the HS model
        self.target_blocks = [m for m in target_model.modules() if isinstance(m, PlainConvBlock)]
        self.conv_params = [b.conv3x3.weight for b in self.target_blocks]
        assert len(self.conv_params) == len(scale_list), 'HS and target block counts differ'
        self.scale_list = scale_list
        self.update_rule = update_rule
        self.reinit_weights = reinit_weights
        # SGD uses squared scales; other update rules use them linearly
        self.power = 2 if update_rule == 'sgd' else 1

    @torch.no_grad()
    def initialize_target_weights(self):
        if not self.reinit_weights:
            return
        for scales, conv3x3 in zip(self.scale_list, self.conv_params):
            out_ch, in_ch, _, _ = conv3x3.shape
            dev, dt = conv3x3.device, conv3x3.dtype
            scales = [s.to(device=dev, dtype=dt) for s in scales]
            # Scaled 3x3 branch
            init_tensor = conv3x3 * scales[-1].view(-1, 1, 1, 1)
            # Scaled, freshly initialized 1x1 branch, padded to the kernel center
            conv1x1_init = nn.Conv2d(in_ch, out_ch, 1, bias=False).weight.to(device=dev, dtype=dt)
            init_tensor += F.pad(conv1x1_init, [1, 1, 1, 1]) * scales[-2].view(-1, 1, 1, 1)
            # Scaled identity branch on the diagonal of the kernel center
            if len(scales) == 3:
                id_tensor = torch.eye(out_ch, device=dev, dtype=dt).view(out_ch, out_ch, 1, 1)
                init_tensor += F.pad(id_tensor * scales[0].view(-1, 1, 1, 1), [1, 1, 1, 1])
            conv3x3.copy_(init_tensor)

    @torch.no_grad()
    def generate_grad_multipliers(self):
        grad_mult_dict = {}
        for scales, conv_param in zip(self.scale_list, self.conv_params):
            # Build each mask on the same device and dtype as its weight
            scales = [s.to(device=conv_param.device, dtype=conv_param.dtype) for s in scales]
            mask = torch.ones_like(conv_param) * (scales[-1] ** self.power).view(-1, 1, 1, 1)
            # Add 1x1 branch contribution to the kernel center
            mask[:, :, 1:2, 1:2] += (scales[-2] ** self.power).view(-1, 1, 1, 1)
            # Add identity branch contribution to the diagonal of the kernel center
            if len(scales) == 3:
                ch_indices = torch.arange(conv_param.shape[0], device=conv_param.device)
                mask[ch_indices, ch_indices, 1:2, 1:2] += 1.0
            grad_mult_dict[conv_param] = mask
        return grad_mult_dict


class RepOptSGD(SGD):
    """SGD that multiplies selected gradients by their Grad Mult tensor before the update."""

    def __init__(self, grad_mult_dict, params, lr, momentum=0.9, weight_decay=4e-5, nesterov=True):
        super().__init__(params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
        self.grad_mult_dict = grad_mult_dict

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            wd = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                # Apply Grad Mult out of place so that p.grad itself is not modified
                grad = p.grad * self.grad_mult_dict[p] if p in self.grad_mult_dict else p.grad
                # Weight decay is added after the Grad Mult scaling
                if wd != 0:
                    grad = grad.add(p, alpha=wd)
                if momentum != 0:
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        buf = state['momentum_buffer'] = torch.clone(grad).detach()
                    else:
                        buf = state['momentum_buffer']
                        buf.mul_(momentum).add_(grad, alpha=1 - dampening)
                    grad = grad.add(buf, alpha=momentum) if nesterov else buf
                p.add_(grad, alpha=-lr)
        return loss
```
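An end-to-end check of this update logic, again with plain tensors and assumed random values: a CSLA block whose 3x3 and 1x1 scales s and t are constants and whose identity branch is a trainable per-channel diagonal (the role scale_identity plays) is trained with torch.optim.SGD, while a single 3x3 kernel is trained with a hand-written update that, like RepOptSGD.step, multiplies the gradient by Grad Mult before adding weight decay and momentum. BatchNorm is omitted because it sits after the branch sum and sees identical inputs in both models.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c, lr, mom, wd = 4, 0.05, 0.9, 1e-3  # assumed channel count and hyperparameters
dt = torch.float64
s = torch.rand(c, dtype=dt) + 0.5  # constant 3x3-branch scales (assumed)
t = torch.rand(c, dtype=dt) + 0.5  # constant 1x1-branch scales (assumed)

# CSLA trainable operators: 3x3 kernel, 1x1 kernel, per-channel identity diagonal
k3 = torch.randn(c, c, 3, 3, dtype=dt, requires_grad=True)
k1 = torch.randn(c, c, 1, 1, dtype=dt, requires_grad=True)
v = torch.ones(c, dtype=dt, requires_grad=True)


def fuse(k3, k1, v):
    """Single 3x3 kernel equivalent to the three branches (v has constant scale 1)."""
    diag = torch.diag_embed(v).reshape(c, c, 1, 1)
    return (k3 * s.view(-1, 1, 1, 1)
            + F.pad(k1 * t.view(-1, 1, 1, 1), [1, 1, 1, 1])
            + F.pad(diag, [1, 1, 1, 1]))


# Plain kernel initialized as the weighted sum of the branch weights
w = fuse(k3, k1, v).detach().requires_grad_(True)
# SGD Grad Mult: s^2 everywhere, + t^2 at the center, + 1 on the diagonal center
mult = torch.ones(c, c, 3, 3, dtype=dt) * (s ** 2).view(-1, 1, 1, 1)
mult[:, :, 1, 1] += (t ** 2).view(-1, 1)
mult[torch.arange(c), torch.arange(c), 1, 1] += 1.0

opt_csla = torch.optim.SGD([k3, k1, v], lr=lr, momentum=mom, weight_decay=wd)
buf = torch.zeros_like(w)  # momentum buffer for the hand-written plain update

for _ in range(10):
    x = torch.randn(2, c, 6, 6, dtype=dt)
    # CSLA model: three branches summed, trained with stock SGD
    y_csla = (F.conv2d(x, k3, padding=1) * s.view(1, -1, 1, 1)
              + F.conv2d(x, k1) * t.view(1, -1, 1, 1)
              + x * v.view(1, -1, 1, 1))
    opt_csla.zero_grad()
    y_csla.pow(2).mean().backward()
    opt_csla.step()

    # Plain model: scale the gradient by Grad Mult, then add weight decay and momentum
    w.grad = None
    F.conv2d(x, w, padding=1).pow(2).mean().backward()
    with torch.no_grad():
        d = mult * w.grad + wd * w
        buf.mul_(mom).add_(d)
        w -= lr * buf

with torch.no_grad():
    equivalent = torch.allclose(w, fuse(k3, k1, v))
print(equivalent)  # True
```

Momentum and weight decay are linear in the weights and gradients, so they preserve the equivalence; this only works if the gradient is multiplied by Grad Mult before weight decay is added, since the decay term on the fused kernel must stay unscaled.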
Optimizer Construction Workflow
```python
def build_repopt_sgd(target_model, scale_list, lr, momentum=0.9, weight_decay=4e-5):
    # Call after moving target_model to its final device: Grad Mult tensors are
    # created on each weight's current device
    handler = RepOptVGGHandler(target_model, scale_list, reinit_weights=True)
    handler.initialize_target_weights()
    # Apply weight decay only to 4D convolution kernels; BN, bias and SE weights get none
    decay, no_decay = [], []
    for p in target_model.parameters():
        if p.dim() == 4:
            decay.append(p)
        else:
            no_decay.append(p)
    param_groups = [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
    grad_mults = handler.generate_grad_multipliers()
    return RepOptSGD(grad_mults, param_groups, lr=lr, momentum=momentum, weight_decay=weight_decay)
```