Gradient Reparameterization-Based RepOptimizer: Core Principles and Implementation Details

Neural network architecture design encodes domain prior knowledge into model structures. For example, residual connections, which model a feature transformation as \(y = f(x) + x\), deliver better performance than plain \(y = f(x)\) mappings; ResNet implements them via shortcut paths. While architecture design has continuously evolved to integrate the latest empirical findings, training pipelines almost universally rely on generic, model-agnostic optimizers such as SGD and AdamW that incorporate no model-specific prior knowledge.

The gradient reparameterization technique modifies parameter gradients using model-specific hyperparameters before weight update steps, integrating structural prior knowledge directly into the optimization process without adding extra trainable parameters or runtime statistics during training. This approach powers a new category of reparameterized optimizers called RepOptimizers, which enable plain model architectures to match or exceed the performance of heavily engineered complex structures when trained with appropriate prior encoding.

Compared to structural reparameterization approaches such as RepVGG, which require auxiliary multi-branch structures during training and a post-training conversion for inference, RepOptimizers add no extra branches to the forward or backward pass and therefore no extra activation memory during training; the only additions are a Grad Mult tensor per convolution kernel and an elementwise gradient multiplication. This efficiency makes RepOptimizer training well suited to resource-constrained environments and rapid iteration workflows. Because RepOptimizer-trained plain models need no post-training structural fusion, they also avoid the quantization degradation that fusion causes and keep far more stable accuracy under INT8 post-training quantization (PTQ) pipelines.

RepOptimizer design starts with defining the structural prior to encode. For the RepOpt-VGG plain architecture, the target prior is that multi-branch linear addition with constant per-branch scaling factors improves model performance, the same prior that underpins the ResNet and RepVGG designs.

This prior is formalized as Constant-Scale Linear Addition (CSLA) blocks: each branch contains exactly one trainable linear operator and an optional constant scaling factor, with no nonlinear components such as batch normalization or dropout inside the branches. Training a CSLA block with standard SGD is mathematically equivalent to training a single plain operator whose gradient is multiplied by a constant factor derived from the branch scaling coefficients; this is a concrete instance of gradient reparameterization. The gradient multiplier is referred to as Grad Mult.

The equivalence is proven by induction on the training iteration, here for a two-branch CSLA block trained with plain SGD at learning rate \(\lambda\):

  1. Initialization condition: The weight of the plain operator is initialized as the weighted sum of the CSLA branch weights, \(W'^{(0)} = \alpha_A W^{(A)(0)} + \alpha_B W^{(B)(0)}\), where \(\alpha_A, \alpha_B\) are the constant scaling factors of the two CSLA branches and \(W^{(A)}, W^{(B)}\) are the weights of their linear operators. Because both operators are linear and receive the same input, the CSLA block and the plain operator produce identical outputs for any input.
  2. Iteration update: Assume \(W'^{(i)} = \alpha_A W^{(A)(i)} + \alpha_B W^{(B)(i)}\) holds at iteration \(i\). Since the block output depends on the branch weights only through this sum, the chain rule gives \(\frac{\partial L}{\partial W^{(A)}} = \alpha_A \frac{\partial L}{\partial W'}\) and \(\frac{\partial L}{\partial W^{(B)}} = \alpha_B \frac{\partial L}{\partial W'}\). After one SGD step on each branch, \(\alpha_A W^{(A)(i+1)} + \alpha_B W^{(B)(i+1)} = W'^{(i)} - \lambda (\alpha_A^2 + \alpha_B^2) \frac{\partial L}{\partial W'}\). Updating \(W'\) with its gradient scaled by \(\alpha_A^2 + \alpha_B^2\) therefore produces the same weight, so the equivalence holds at iteration \(i+1\).
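The induction can be checked numerically. The sketch below is a toy, framework-free example under stated assumptions: scalar weights, one fixed input, a squared-error loss, and arbitrary constants `alpha_a` and `alpha_b` (all hypothetical). It trains the two-branch form and the single operator side by side and compares the resulting effective weights.

```python
# Toy check of the CSLA / Grad Mult equivalence under plain SGD (hypothetical values).

def loss_grad(w_eff, x, y):
    """Gradient of L = 0.5 * (w_eff * x - y) ** 2 with respect to w_eff."""
    return (w_eff * x - y) * x

def train_csla(w_a, w_b, alpha_a, alpha_b, x, y, lr, steps):
    """Two-branch CSLA block: output = alpha_a * w_a * x + alpha_b * w_b * x."""
    for _ in range(steps):
        g = loss_grad(alpha_a * w_a + alpha_b * w_b, x, y)   # dL/dW'
        # Chain rule: dL/dw_a = alpha_a * dL/dW', dL/dw_b = alpha_b * dL/dW'
        w_a, w_b = w_a - lr * alpha_a * g, w_b - lr * alpha_b * g
    return alpha_a * w_a + alpha_b * w_b                      # equivalent single weight

def train_plain(w, grad_mult, x, y, lr, steps):
    """Single operator whose gradient is multiplied by grad_mult before each update."""
    for _ in range(steps):
        w = w - lr * grad_mult * loss_grad(w, x, y)
    return w

alpha_a, alpha_b = 0.7, 1.3
w_a0, w_b0 = 0.2, -0.5
x, y, lr, steps = 1.5, 2.0, 0.05, 20

w_csla = train_csla(w_a0, w_b0, alpha_a, alpha_b, x, y, lr, steps)
w_plain = train_plain(alpha_a * w_a0 + alpha_b * w_b0,   # initialization condition
                      alpha_a ** 2 + alpha_b ** 2,         # Grad Mult
                      x, y, lr, steps)
print(abs(w_csla - w_plain) < 1e-9)  # True
```

The two runs agree up to floating-point rounding, while the plain run never instantiates the second branch.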

For CSLA blocks with mixed kernel sizes (a 3x3 convolution, a 1x1 convolution and an identity path, as in RepVGG-style blocks), Grad Mult is a 4D tensor with the same shape as the plain 3x3 convolution kernel. Given channel-wise scaling factors \(s\) (3x3 branch) and \(t\) (1x1 branch), the Grad Mult tensor for SGD is constructed by:

  • Setting every position of the tensor to \(s^2\) (per output channel)
  • Adding \(t^2\) to the center position of each 3x3 kernel, where the 1x1 branch aligns
  • Adding 1.0 to the center position of each diagonal kernel (output channel \(i\), input channel \(i\)) for blocks with an identity path, which is where the identity mapping aligns
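A minimal pure-Python sketch of the three construction rules, with the kernel stored as nested lists indexed `[out][in][row][col]`. The function name `build_grad_mult` and its `power` parameter are illustrative assumptions; `power=2` corresponds to the squared factors listed above.

```python
def build_grad_mult(s, t, in_ch, has_identity, power=2):
    """Grad Mult for a plain 3x3 kernel stored as nested lists [out][in][row][col].

    s, t: per-output-channel scales of the 3x3 and 1x1 branches.
    power=2 corresponds to the squared factors used with SGD.
    """
    out_ch = len(s)
    if has_identity and in_ch != out_ch:
        raise ValueError("an identity branch requires in_ch == out_ch")
    # Rule 1: s^power at every position of every kernel of output channel o
    mult = [[[[s[o] ** power for _ in range(3)] for _ in range(3)]
             for _ in range(in_ch)] for o in range(out_ch)]
    for o in range(out_ch):
        for i in range(in_ch):
            mult[o][i][1][1] += t[o] ** power   # Rule 2: 1x1 branch lands on the center tap
        if has_identity:
            mult[o][o][1][1] += 1.0             # Rule 3: identity lands on diagonal centers only
    return mult

gm = build_grad_mult(s=[0.5, 2.0], t=[1.0, 0.1], in_ch=2, has_identity=True)
print(gm[0][0][1][1], gm[0][1][1][1], gm[0][0][0][0])  # 2.25 1.25 0.25
```

A corner entry keeps only \(s^2\), an off-diagonal center gains \(t^2\), and only a diagonal center also gains 1.0.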

CSLA blocks are only conceptual constructs for defining RepOptimizer behavior, and are never actually instantiated during target model training.

Hyper-Search for Scaling Factors

The constant branch scaling factors are obtained via a Hyper-Search (HS) procedure:

  1. Construct an auxiliary HS model by replacing the constant scaling factors in the conceptual CSLA block with trainable parameters
  2. Train the auxiliary model on a small proxy dataset (e.g., CIFAR-100)
  3. Extract the final values of the trainable scaling factors as the constant scaling factors for Grad Mult calculation

The HS procedure is model-specific but dataset-agnostic: scaling factors searched on small proxy datasets transfer directly to large target datasets such as ImageNet without performance degradation.

RepOptimizer Training Workflow

  1. Extract scaling factors from the trained HS auxiliary model
  2. Construct the Grad Mult tensor for each parameter of the target plain model
  3. Initialize each target convolution kernel as the scaled sum of the initial CSLA branch weights, using the same scaling factors that define the Grad Mult tensor
  4. During each training iteration, multiply the parameter gradients by the corresponding Grad Mult tensor before executing the optimizer update step
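The four steps can be traced end to end on a hypothetical one-dimensional, single-channel analog: a 3-tap branch scaled by `s` and a 1-tap branch scaled by `t` stand in for the 3x3 and 1x1 convolutions. The sketch assumes plain SGD and made-up values, and the helper names (`conv1d_same`, `kernel_grad`, `merge`) are illustrative only. It trains the two-branch form and the plain 3-tap kernel side by side and checks that they stay identical.

```python
# Hypothetical 1D, single-channel analog of the RepOptimizer workflow (toy values).

def conv1d_same(x, k):
    """Cross-correlate x with a 3-tap kernel k using zero padding (output length == len(x))."""
    xp = [0.0] + list(x) + [0.0]
    return [sum(k[j] * xp[i + j] for j in range(3)) for i in range(len(x))]

def kernel_grad(x, k, target):
    """dL/dk for L = 0.5 * sum((conv1d_same(x, k) - target) ** 2)."""
    xp = [0.0] + list(x) + [0.0]
    err = [o - y for o, y in zip(conv1d_same(x, k), target)]
    return [sum(err[i] * xp[i + j] for i in range(len(x))) for j in range(3)]

def merge(w3, w1, s, t):
    """Scaled sum of the branches; the 1-tap weight is zero-padded onto the center tap."""
    return [s * w3[0], s * w3[1] + t * w1, s * w3[2]]

# Step 1 (assumed already done): constant scales taken from a hyper-search.
s, t = 0.8, 1.5
lr, steps = 0.01, 30
x = [1.0, -2.0, 0.5, 3.0, -1.0]
target = [0.5, 1.0, -1.0, 2.0, 0.0]
w3, w1 = [0.1, -0.2, 0.3], 0.4          # initial branch weights of the conceptual CSLA block

# Step 2: Grad Mult, s^2 on every tap plus t^2 on the center tap.
grad_mult = [s ** 2, s ** 2 + t ** 2, s ** 2]
# Step 3: initialize the plain kernel from the branch weights.
plain = merge(w3, w1, s, t)

for _ in range(steps):
    # Reference run: train the two-branch CSLA form directly with plain SGD.
    g = kernel_grad(x, merge(w3, w1, s, t), target)
    w3 = [w - lr * s * gj for w, gj in zip(w3, g)]   # dL/dw3 = s * dL/dk
    w1 = w1 - lr * t * g[1]                           # dL/dw1 = t * dL/dk[center]
    # Step 4: plain kernel, gradient multiplied by Grad Mult before the update.
    gp = kernel_grad(x, plain, target)
    plain = [p - lr * m * gj for p, m, gj in zip(plain, grad_mult, gp)]

eff = merge(w3, w1, s, t)
print(max(abs(a - b) for a, b in zip(eff, plain)) < 1e-9)  # True
```

Only the center tap receives the larger multiplier \(s^2 + t^2\), which is how the 1x1 branch's prior is carried into the plain kernel.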

Experimental Findings

Scaling factors in the HS model are initialized as \(\sqrt{2/l}\), where \(l\) is the block depth, to encourage deeper layers to behave as identity mappings during early training and to stabilize convergence.
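A small sketch of this initialization rule, assuming \(l\) counts blocks from 1 (the indexing is not stated above). The helper name `hs_scale_init` is hypothetical; its value could be passed as `scale_init_val` to the `CSLAConvBlock` shown in the implementation section.

```python
import math

def hs_scale_init(depth):
    """sqrt(2 / l) for a block at depth l, assumed here to count from 1."""
    if depth < 1:
        raise ValueError("depth is assumed to be 1-based")
    return math.sqrt(2.0 / depth)

print([round(hs_scale_init(l), 3) for l in (1, 2, 8, 32)])  # [1.414, 1.0, 0.5, 0.25]
```

Since the identity scale in that block starts at 1.0, the convolution branches of deep blocks start small relative to the identity path.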

RepOpt-VGG, a plain VGG-style architecture trained with RepOptimizer, delivers equivalent top-1 accuracy to RepVGG on ImageNet classification, with 1.8x higher training throughput and 30% lower memory footprint due to the elimination of multi-branch training structures. The reduced memory footprint allows larger batch sizes, further improving both training speed and final model accuracy.

INT8 PTQ of RepVGG leads to over 20% accuracy drop on ImageNet due to skewed parameter distributions introduced by post-training branch fusion, while RepOpt-VGG suffers only 2.5% accuracy drop under the same quantization pipeline, as no structural conversion is required for inference.
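Why skewed weights hurt INT8 PTQ can be illustrated with a generic quantizer. This sketch assumes symmetric per-tensor INT8 quantization with a max-abs scale (the quantization scheme is not specified above) and uses synthetic values, not RepVGG weights: a single large outlier stretches the quantization step, so the many small weights lose precision.

```python
# Generic illustration, not a reproduction of the RepVGG experiment.
# Assumption: symmetric per-tensor INT8 quantization with a max-abs scale.

def quantize_dequantize(values, num_bits=8):
    """Round values onto a symmetric integer grid and map them back to floats."""
    qmax = 2 ** (num_bits - 1) - 1                  # 127 for INT8
    scale = max(abs(v) for v in values) / qmax      # one scale for the whole tensor
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]

def mean_abs_error(values):
    deq = quantize_dequantize(values)
    return sum(abs(v - d) for v, d in zip(values, deq)) / len(values)

# Synthetic weights spread within [-0.1, 0.1]
balanced = [((i * 37) % 101 - 50) / 500 for i in range(1000)]
# Same weights, but the last one replaced by a single large outlier
skewed = balanced[:-1] + [5.0]

print(mean_abs_error(balanced) < mean_abs_error(skewed))  # True
```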

Scaling factors searched on CIFAR-100 deliver the same RepOpt-VGG performance on ImageNet as factors searched directly on ImageNet, confirming the dataset-agnostic property of RepOptimizer hyperparameters. Conversely, lowering the downsampling rate of the HS model to raise its CIFAR-100 validation accuracy degrades RepOpt-VGG performance on ImageNet, confirming that the searched scaling factors are model-specific.

Implementation Details

The implementation follows three operational modes: hyper-search (HS), CSLA validation, and target model training.

CSLA Convolution Block

Used for HS and CSLA validation, this block implements the multi-branch linear addition structure. The 3x3 and 1x1 branch scaling factors are trainable in HS mode and frozen to constant values in CSLA mode. The identity branch is only included when the input and output channel counts match and the stride is 1.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim import SGD

class ScaleLayer(nn.Module):
    def __init__(self, num_features, scale_init=1.0, use_bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_features) * scale_init)
        self.bias = nn.Parameter(torch.zeros(num_features)) if use_bias else None
    
    def forward(self, x):
        out = x * self.weight.view(1, -1, 1, 1)
        if self.bias is not None:
            out += self.bias.view(1, -1, 1, 1)
        return out

class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class CSLAConvBlock(nn.Module):
    # scale_init_val must be a number (e.g. sqrt(2 / l)); a None default would break ScaleLayer
    def __init__(self, in_ch, out_ch, stride=1, use_se=False, freeze_scales=False, scale_init_val=1.0):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.stride = stride
        self.relu = nn.ReLU(inplace=True)
        
        # 3x3 convolution branch
        self.conv3x3 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.scale3x3 = ScaleLayer(out_ch, scale_init=scale_init_val, use_bias=False)
        
        # 1x1 convolution branch
        self.conv1x1 = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0, bias=False)
        self.scale1x1 = ScaleLayer(out_ch, scale_init=scale_init_val, use_bias=False)
        
        # Identity branch (if applicable)
        if in_ch == out_ch and stride == 1:
            self.scale_identity = ScaleLayer(out_ch, scale_init=1.0, use_bias=False)
        
        self.bn = nn.BatchNorm2d(out_ch)
        self.se = SEBlock(out_ch) if use_se else nn.Identity()
        
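        # Freeze the conv-branch scales for CSLA mode. The identity scale is left trainable:
        # it acts as a per-channel weight on the center of the diagonal kernels, which matches
        # the +1.0 diagonal term in Grad Mult.
        if freeze_scales:
            self.scale3x3.weight.requires_grad = False
            self.scale1x1.weight.requires_grad = False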
    
    def forward(self, x):
        out = self.scale3x3(self.conv3x3(x)) + self.scale1x1(self.conv1x1(x))
        if hasattr(self, 'scale_identity'):
            out += self.scale_identity(x)
        out = self.se(self.relu(self.bn(out)))
        return out

Target Plain Convolution Block

Used for target model training, this block implements the plain 3x3 convolution structure with no auxiliary branches.

class PlainConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, use_se=False):
        super().__init__()
        self.conv3x3 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.se = SEBlock(out_ch) if use_se else nn.Identity()
    
    def forward(self, x):
        return self.se(self.relu(self.bn(self.conv3x3(x))))

Scaling Factor Extraction

Extract trained scaling factors from the HS auxiliary model checkpoint to construct Grad Mult tensors.

def fetch_trained_scaling_factors(hs_model):
    block_list = [m for m in hs_model.modules() if isinstance(m, CSLAConvBlock)]
    scale_list = []
    for block in block_list:
        s3x3 = block.scale3x3.weight.detach().cpu()
        s1x1 = block.scale1x1.weight.detach().cpu()
        if hasattr(block, 'scale_identity'):
            sid = block.scale_identity.weight.detach().cpu()
            scale_list.append((sid, s1x1, s3x3))
        else:
            scale_list.append((s1x1, s3x3))
    return scale_list

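def load_scales_from_checkpoint(hs_model_cls, ckpt_path, num_blocks, width_mult):
    # hs_model_cls is assumed to accept these keyword arguments and build the HS network
    hs_model = hs_model_cls(num_blocks=num_blocks, width_multiplier=width_mult, num_classes=100, mode='hs')
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt.get('model', ckpt.get('state_dict', ckpt))
    # Strip the 'module.' prefix left by DataParallel / DistributedDataParallel
    state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in state_dict.items()}
    # Remove classification head weights
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}
    missing, unexpected = hs_model.load_state_dict(state_dict, strict=False)
    # strict=False silently skips mismatched keys; fail loudly if any scale was not loaded
    assert not any('scale' in k for k in missing), f'scales missing from checkpoint: {missing}'
    return fetch_trained_scaling_factors(hs_model)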

RepOptimizer Handler and Custom SGD

The handler implements weight initialization and Grad Mult tensor generation, while the custom SGD applies Grad Mult scaling during the update step.

class RepOptVGGHandler:
    def __init__(self, target_model, scale_list, update_rule='sgd', reinit_weights=True):
        self.target_blocks = [m for m in target_model.modules() if isinstance(m, PlainConvBlock)]
        self.conv_params = [b.conv3x3.weight for b in self.target_blocks]
        self.scale_list = scale_list
        self.update_rule = update_rule
        self.reinit_weights = reinit_weights
        self.power = 2 if update_rule == 'sgd' else 1
    
    def initialize_target_weights(self):
        if not self.reinit_weights:
            return
        assert len(self.scale_list) == len(self.conv_params), 'HS and target block counts differ'
        with torch.no_grad():
            for scales, conv3x3 in zip(self.scale_list, self.conv_params):
                out_ch, in_ch, _, _ = conv3x3.shape
                dev = conv3x3.device
                s3x3, s1x1 = scales[-1].to(dev), scales[-2].to(dev)
                # 3x3 branch component: current (default-initialized) kernel scaled by s3x3
                init_tensor = conv3x3 * s3x3.view(-1, 1, 1, 1)
                # 1x1 branch component: a freshly initialized 1x1 kernel zero-padded to 3x3
                conv1x1_init = nn.Conv2d(in_ch, out_ch, 1, bias=False).weight.to(dev)
                init_tensor += F.pad(conv1x1_init, [1, 1, 1, 1]) * s1x1.view(-1, 1, 1, 1)
                # Identity branch component (present only when in_ch == out_ch)
                if len(scales) == 3:
                    id_tensor = torch.eye(out_ch, device=dev).view(out_ch, out_ch, 1, 1)
                    init_tensor += F.pad(id_tensor * scales[0].to(dev).view(-1, 1, 1, 1), [1, 1, 1, 1])
                conv3x3.copy_(init_tensor)

    def generate_grad_multipliers(self):
        grad_mult_dict = {}
        for scales, conv_param in zip(self.scale_list, self.conv_params):
            # Build each mask directly on the parameter's device to avoid CPU/GPU mismatches
            dev = conv_param.device
            s3x3, s1x1 = scales[-1].to(dev), scales[-2].to(dev)
            mask = torch.ones_like(conv_param) * (s3x3 ** self.power).view(-1, 1, 1, 1)
            # Add 1x1 branch contribution to kernel center
            mask[:, :, 1:2, 1:2] += (s1x1 ** self.power).view(-1, 1, 1, 1)
            # Add identity branch contribution to the center of each diagonal kernel
            if len(scales) == 3:
                ch_indices = torch.arange(conv_param.shape[0], device=dev)
                mask[ch_indices, ch_indices, 1, 1] += 1.0
            grad_mult_dict[conv_param] = mask
        return grad_mult_dict

class RepOptSGD(SGD):
    def __init__(self, grad_mult_dict, params, lr, momentum=0.9, weight_decay=4e-5, nesterov=True):
        super().__init__(params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
        self.grad_mult_dict = grad_mult_dict
    
    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            wd = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue
                # Apply Grad Mult first, out of place, so p.grad itself is never modified
                mult = self.grad_mult_dict.get(p)
                grad = p.grad * mult.to(p.grad.device) if mult is not None else p.grad.clone()
                # Weight decay is added after Grad Mult, so it is not scaled by it
                if wd != 0:
                    grad.add_(p, alpha=wd)
                if momentum != 0:
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        buf = state['momentum_buffer'] = torch.clone(grad).detach()
                    else:
                        buf = state['momentum_buffer']
                        buf.mul_(momentum).add_(grad, alpha=1 - dampening)
                    if nesterov:
                        grad = grad + buf * momentum
                    else:
                        grad = buf
                p.add_(grad, alpha=-lr)
        return loss
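The order of operations in `step` matters. The toy scalar check below (hypothetical constants, with PyTorch-style SGD using Nesterov momentum and zero dampening re-implemented in plain Python) suggests why: multiplying the raw gradient by Grad Mult first and adding weight decay afterwards keeps the plain operator in lockstep with a two-branch CSLA block trained by the same SGD with momentum and weight decay.

```python
# Toy scalar check (hypothetical constants) of the update order used in RepOptSGD.step:
# Grad Mult first, then weight decay, then Nesterov momentum (dampening = 0).

def sgd_update(w, g, buf, lr, wd, momentum, nesterov=True):
    """One PyTorch-style SGD step on a scalar weight; returns (new_w, new_buf)."""
    g = g + wd * w                                   # weight decay
    buf = g if buf is None else momentum * buf + g   # momentum buffer
    d = g + momentum * buf if nesterov else buf
    return w - lr * d, buf

def loss_grad(w_eff, x=1.2, y=0.7):
    """dL/dW' for L = 0.5 * (W' * x - y) ** 2."""
    return (w_eff * x - y) * x

a, b = 0.6, 1.4                 # constant CSLA branch scales
lr, wd, mom = 0.05, 1e-2, 0.9
w_a, w_b, buf_a, buf_b = 0.3, -0.1, None, None
w_p, buf_p = a * w_a + b * w_b, None              # plain weight, same initialization rule

for _ in range(50):
    # Two-branch reference: each branch receives its chain-rule gradient.
    g = loss_grad(a * w_a + b * w_b)
    w_a, buf_a = sgd_update(w_a, a * g, buf_a, lr, wd, mom)
    w_b, buf_b = sgd_update(w_b, b * g, buf_b, lr, wd, mom)
    # Plain operator: raw gradient times Grad Mult; weight decay is added inside sgd_update.
    w_p, buf_p = sgd_update(w_p, (a * a + b * b) * loss_grad(w_p), buf_p, lr, wd, mom)

print(abs((a * w_a + b * w_b) - w_p) < 1e-9)  # True
```

Because momentum and weight decay are linear in the gradient and the weights, the induction argument carries over. If the multiplication were applied after weight decay, the decay term would also be scaled by Grad Mult and the two runs would drift apart.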

Optimizer Construction Workflow

def build_repopt_sgd(target_model, scale_list, lr, momentum=0.9, weight_decay=4e-5):
    # Move target_model to its training device first so Grad Mult tensors are created there
    handler = RepOptVGGHandler(target_model, scale_list, reinit_weights=True)
    handler.initialize_target_weights()
    # Apply weight decay only to convolution weights (4D); batch norm, biases and linear layers get none
    decay, no_decay = [], []
    for p in target_model.parameters():
        if not p.requires_grad:
            continue
        if p.dim() == 4:
            decay.append(p)
        else:
            no_decay.append(p)
    param_groups = [g for g in ({'params': decay, 'weight_decay': weight_decay},
                                {'params': no_decay, 'weight_decay': 0.0}) if g['params']]
    grad_mults = handler.generate_grad_multipliers()
    return RepOptSGD(grad_mults, param_groups, lr=lr, momentum=momentum, weight_decay=weight_decay)

Tags: Deep Learning Optimizer ICLR 2023 Model Reparameterization Computer Vision

Posted on Sat, 30 May 2026 21:45:51 +0000 by sonehs