Getting Started with Torch-Pruning for Structured Model Pruning

Torch-Pruning is a PyTorch library for structured model pruning. It uses DepGraph (Dependency Graph) to automatically find layers whose channels are coupled and to prune them together, so the pruned model keeps consistent tensor shapes and still runs a forward pass.

  1. Installation

Clone the repository and install in development mode:

git clone https://github.com/VainF/Torch-Pruning
cd Torch-Pruning
pip install -r requirements.txt
pip install -e .

Verify installation:

python -c "import torch_pruning"

  2. Understanding DepGraph

Directly pruning individual layers without considering dependencies (e.g., removing output channels from a Conv2d layer) breaks compatibility with subsequent layers like BatchNorm, causing runtime errors.
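
To make the failure concrete, here is a minimal framework-free sketch. The `Layer` class and `check_chain` helper are hypothetical stand-ins that only track channel counts; they are not part of Torch-Pruning or PyTorch.

```python
from dataclasses import dataclass


@dataclass
class Layer:
    name: str
    in_channels: int
    out_channels: int


def check_chain(layers):
    """Return the adjacent layer pairs whose channel counts disagree."""
    mismatches = []
    for prev, nxt in zip(layers, layers[1:]):
        if prev.out_channels != nxt.in_channels:
            mismatches.append((prev.name, nxt.name))
    return mismatches


# conv1 -> bn1 -> conv2, consistent at 64 channels
chain = [Layer("conv1", 3, 64), Layer("bn1", 64, 64), Layer("conv2", 64, 128)]
before = check_chain(chain)

# Naively remove 3 output channels from conv1 only
chain[0].out_channels -= 3
after = check_chain(chain)

print(before)
print(after)
```

After the naive edit, the conv1/bn1 pair no longer agrees on channel count, which is the kind of mismatch that surfaces as a runtime shape error in a real model.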

Torch-Pruning solves this by building a dependency graph that captures all affected layers when a channel is pruned. The correct workflow is:

  1. Build the dependency graph using a sample input.
  2. Define a pruning group based on a target layer and indices.
  3. Validate and apply the pruning group.

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()
DG = tp.DependencyGraph().build_dependency(
    model, example_inputs=torch.randn(1, 3, 224, 224)
)

# Prune channels [2, 6, 9] from the first conv layer
group = DG.get_pruning_group(
    model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9]
)

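# Inspect every layer that will be pruned together with conv1
print(group)

if DG.check_pruning_group(group):
    group.prune()

# Save the modified model structure
torch.save(model, 'pruned_model.pth')
# On PyTorch 2.6+, loading a full model object requires weights_only=False
loaded_model = torch.load('pruned_model.pth', weights_only=False)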

Printing the group with print(group) lists all dependent layers (Conv, BN, ReLU, skip connections, and downstream convolutions) that must be pruned together to maintain consistency.
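
The grouping idea can be sketched as a graph traversal. This is a simplified illustration, not DepGraph's actual implementation: the `COUPLED` map and `collect_group` function are hypothetical, and the real library also tracks whether each layer is pruned along its input or output dimension.

```python
from collections import deque

# Hypothetical coupling map for a tiny residual block: an edge means
# "these two layers share a channel dimension and must be pruned together".
COUPLED = {
    "conv1": ["bn1"],
    "bn1": ["conv1", "relu"],
    "relu": ["bn1", "add"],
    "add": ["relu", "skip_conv", "conv2_in"],
    "skip_conv": ["add"],
    "conv2_in": ["add"],
    "conv3": [],  # not coupled to conv1's output channels
}


def collect_group(start, coupled):
    """Breadth-first search over coupled layers, returned in visit order."""
    seen = {start}
    order = []
    queue = deque([start])
    while queue:
        layer = queue.popleft()
        order.append(layer)
        for neighbor in coupled[layer]:
            if neighbor not in seen:
                seen.add(neighbor)
                queue.append(neighbor)
    return order


group = collect_group("conv1", COUPLED)
print(group)
```

Starting from conv1, the traversal reaches the skip branch through the addition node, which is why residual connections pull extra layers into a group.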

  3. Saving and Loading Pruned Models

Because pruning changes layer shapes, a plain state_dict can no longer be loaded into the original, unpruned architecture. Two approaches are supported:

Option 1: Save Full Model Object

torch.save(model, 'model.pth')  # includes structure + weights
model = torch.load('model.pth', weights_only=False)  # needed for full objects on PyTorch 2.6+

Option 2: Use Torch-Pruning’s State Utilities

# Save
state = tp.state_dict(model)
torch.save(state, 'pruned.pth')

# Load into a fresh model
new_model = resnet18().eval()
tp.load_state_dict(new_model, torch.load('pruned.pth'))

  4. Practical Pruning Examples

4.1 Global vs Uniform Pruning

Uniform pruning applies the same pruning ratio to every layer, regardless of how redundant each layer is (often less effective):

pruner = tp.pruner.MagnitudePruner(
    model, example_inputs,
    importance=tp.importance.TaylorImportance(),
    pruning_ratio=0.5,  # named ch_sparsity in older releases
    iterative_steps=5,
    ignored_layers=[model.fc]
)

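Global pruning ranks all channels across the network and removes the least important ones wherever they occur, so more redundant layers are pruned harder (often more effective):

pruner = tp.pruner.MagnitudePruner(
    model, example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),  # required argument
    pruning_ratio=0.5,
    global_pruning=True,
    ignored_layers=[model.fc]
)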
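
The difference between the two strategies is easiest to see on toy importance scores. The sketch below is framework-free; `uniform_prune`, `global_prune`, and the score values are made up for illustration and do not reproduce Torch-Pruning's internal ranking, which may normalize scores across layers.

```python
def uniform_prune(scores, ratio):
    """Drop the lowest-scoring `ratio` fraction of channels in each layer."""
    pruned = {}
    for layer, s in scores.items():
        k = int(len(s) * ratio)
        order = sorted(range(len(s)), key=lambda i: s[i])
        pruned[layer] = sorted(order[:k])
    return pruned


def global_prune(scores, ratio):
    """Drop the lowest-scoring `ratio` fraction of channels network-wide."""
    flat = [(v, layer, i) for layer, s in scores.items() for i, v in enumerate(s)]
    k = int(len(flat) * ratio)
    pruned = {layer: [] for layer in scores}
    for _, layer, i in sorted(flat)[:k]:
        pruned[layer].append(i)
    return {layer: sorted(idxs) for layer, idxs in pruned.items()}


scores = {
    "layer_a": [0.9, 0.8, 0.7, 0.6],     # every channel matters
    "layer_b": [0.05, 0.04, 0.03, 0.95],  # mostly redundant
}
uniform = uniform_prune(scores, 0.5)
global_ = global_prune(scores, 0.5)
print(uniform)
print(global_)
```

Both runs remove four channels in total, but the global ranking takes three from the redundant layer and only one from the layer whose channels all score highly.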

4.2 Advanced Pruning Options

  • Layer-specific ratios: Use pruning_ratio_dict to assign custom sparsity per layer.
  • Channel rounding: Set round_to=16 for GPU-friendly channel counts.
  • Grouped convolutions: Specify channel_groups={layer: 8} for depthwise or grouped convs.
  • Extra parameters: Pass learnable non-module parameters via unwrapped_parameters.
  • Root modules: Limit pruning to specific layer types using root_module_types=[nn.Conv2d, nn.Linear].
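
As an illustration of what channel rounding means, the sketch below rounds the number of kept channels up to a multiple of `round_to`. The helper `rounded_keep_count` is hypothetical and shows one plausible rounding rule; the library's internal rule may differ.

```python
def rounded_keep_count(current, ratio, round_to=None):
    """Channels kept after pruning `ratio` of `current`,
    rounded up to a multiple of `round_to` when it is given."""
    keep = current - int(current * ratio)
    if round_to:
        keep = -(-keep // round_to) * round_to  # ceiling to a multiple
        keep = min(keep, current)               # never exceed the original width
    return keep


plain = rounded_keep_count(100, 0.5)
aligned = rounded_keep_count(100, 0.5, round_to=16)
small = rounded_keep_count(10, 0.5, round_to=16)
print(plain, aligned, small)
```

With rounding enabled, a layer may end up pruned less than the requested ratio, and a layer narrower than `round_to` may not be pruned at all under this rule.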

4.3 Sparse Training Integration

For methods like BNScalePruner, integrate regularization into training:

for epoch in range(epochs):
    pruner.update_regularizer()
    for data, target in loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        pruner.regularize(model)  # inject sparsity gradients
        optimizer.step()
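
A rough picture of what a sparsity regularizer adds to the gradients: the sketch below applies the subgradient of an L1 penalty to BatchNorm scale factors. This is an assumption about the general BN-scale approach, not a copy of what `pruner.regularize` does internally; `add_l1_subgradient` and the values are hypothetical.

```python
def sign(x):
    """Return -1, 0, or 1 depending on the sign of x."""
    return (x > 0) - (x < 0)


def add_l1_subgradient(gammas, grads, lam):
    """Return grads with the subgradient of lam * sum(|gamma|) added."""
    return [g + lam * sign(w) for w, g in zip(gammas, grads)]


gammas = [0.8, -0.2, 0.0]      # BatchNorm scale factors
task_grads = [0.1, 0.1, 0.1]   # gradients from the task loss
reg_grads = add_l1_subgradient(gammas, task_grads, lam=0.01)
print(reg_grads)
```

The extra term pushes every nonzero scale factor toward zero at a constant rate, so channels the task loss does not defend end up with small scales and become cheap to prune.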

  5. Real-World Applications

5.1 Vision Transformers (via timm)

Torch-Pruning handles attention heads by tracking num_heads for each qkv projection and pruning the projection accordingly:

num_heads = {}
for m in model.modules():
    if hasattr(m, 'num_heads') and hasattr(m, 'qkv'):
        num_heads[m.qkv] = m.num_heads

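pruner = tp.pruner.MetaPruner(
    model, example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),  # required argument
    num_heads=num_heads,
    ignored_layers=[model.head]
)
pruner.step()

# Update head count post-pruning (pruner.num_heads tracks remaining heads per qkv layer)
for m in model.modules():
    if hasattr(m, 'num_heads') and hasattr(m, 'qkv'):
        m.num_heads = pruner.num_heads[m.qkv]
        m.head_dim = m.qkv.out_features // (3 * m.num_heads)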
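
The bookkeeping behind head pruning is simple arithmetic on a fused qkv projection, whose output width is 3 * num_heads * head_dim. The helper below is a hypothetical sketch of that arithmetic, assuming whole heads are removed and head_dim is unchanged.

```python
def attention_shape_after_head_pruning(qkv_out_features, num_heads, heads_removed):
    """Recompute head bookkeeping for a fused qkv projection
    after removing whole attention heads."""
    head_dim = qkv_out_features // (3 * num_heads)
    new_heads = num_heads - heads_removed
    new_qkv_out = 3 * new_heads * head_dim
    return new_heads, head_dim, new_qkv_out


# ViT-Base-like attention: 12 heads of width 64, fused qkv has 3 * 768 outputs
new_heads, head_dim, new_qkv_out = attention_shape_after_head_pruning(2304, 12, 6)
print(new_heads, head_dim, new_qkv_out)
```

If the module's stored head count is not updated to match, the reshape inside the attention forward pass will fail or silently split the projection incorrectly.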

5.2 LLM Pruning (e.g., LLaMA)

For LLMs, the pruner removes attention heads and MLP intermediate dimensions; the model config must then be updated to match:

num_heads = {}
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        num_heads[m.q_proj] = model.config.num_attention_heads
        # ... similarly for k_proj, v_proj

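original_heads = model.config.num_attention_heads  # record before pruning

pruner = tp.pruner.MagnitudePruner(
    model, inputs,
    importance=tp.importance.MagnitudeImportance(p=2),  # required argument
    num_heads=num_heads,
    prune_num_heads=True,
    head_pruning_ratio=0.3,
    ignored_layers=[model.lm_head]
)
pruner.step()

# Update model config to match the pruned head count
model.config.num_attention_heads = int(original_heads * (1 - 0.3))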
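
A ratio-based estimate of the new head count can be cross-checked against the actual projection width after pruning. The sketch below is a hypothetical helper with made-up sizes; it assumes head_dim is unchanged by pruning, which holds only when whole heads are removed.

```python
def heads_from_projection(q_proj_out_features, head_dim):
    """Derive the head count from the pruned projection width."""
    if q_proj_out_features % head_dim != 0:
        raise ValueError("projection width is not a multiple of head_dim")
    return q_proj_out_features // head_dim


original_heads, head_dim = 32, 128

# Ratio-based estimate, as in the config update above
estimated = int(original_heads * (1 - 0.3))

# Suppose the pruned q_proj ends up with 22 * 128 = 2816 output features
actual = heads_from_projection(2816, head_dim)
print(estimated, actual)
```

Reading the count back from the layer shape avoids relying on how the pruner rounds the ratio, and the divisibility check catches a projection that was pruned at a finer granularity than whole heads.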

5.3 YOLO Models

Torch-Pruning supports YOLOv5/v7/v8, with special handling for C2f blocks that contain split operations. Custom adaptation may be needed to make such modules prunable.

  6. Importance Criteria and Performance Impact

Available importance metrics include:

  • MagnitudeImportance
  • TaylorImportance (often best for accuracy retention)
  • GroupNormImportance
  • BNScaleImportance
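
To show how two criteria can disagree, here is a framework-free sketch of a magnitude score (L2 norm per output channel) and a first-order Taylor score (|sum of weight times gradient| per channel). These are simplified textbook forms with made-up numbers, not the exact formulas used by the library's importance classes.

```python
import math


def magnitude_importance(weights):
    """L2 norm of each output channel's weights."""
    return [math.sqrt(sum(w * w for w in channel)) for channel in weights]


def taylor_importance(weights, grads):
    """First-order Taylor score: |sum(w * dL/dw)| per output channel."""
    return [
        abs(sum(w * g for w, g in zip(wc, gc)))
        for wc, gc in zip(weights, grads)
    ]


# Three output channels with two weights each
weights = [[3.0, 4.0], [0.1, 0.1], [1.0, 0.0]]
grads = [[0.0, 0.0], [5.0, 5.0], [0.5, 0.5]]

mag = magnitude_importance(weights)
tay = taylor_importance(weights, grads)
least_by_mag = min(range(len(mag)), key=mag.__getitem__)
least_by_taylor = min(range(len(tay)), key=tay.__getitem__)
print(least_by_mag, least_by_taylor)
```

The channel with the largest weights has zero gradient, so the Taylor score ranks it as the least important even though the magnitude score ranks it highest. Gradient-based criteria need a backward pass on real data before pruning, which magnitude-based criteria do not.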

Empirical studies (e.g., DepGraph, LLM-Pruner) report that pruning up to 50% of channels often incurs minimal accuracy loss, especially when followed by fine-tuning. In some cases, a 6× speedup is achievable without accuracy degradation.

Posted on Sat, 27 Jun 2026 16:05:36 +0000 by chipev