Torch-Pruning is a PyTorch library designed for structured model pruning. It leverages DepGraph (Dependency Graph) to automatically identify and manage inter-layer dependencies, ensuring that pruning operations preserve model integrity during forward passes.
- Installation
Clone the repository and install in development mode:
git clone https://github.com/VainF/Torch-Pruning
cd Torch-Pruning
pip install -r requirements.txt
pip install -e .
Verify installation:
python -c "import torch_pruning"
- Understanding DepGraph
Directly pruning individual layers without considering dependencies (e.g., removing output channels from a Conv2d layer) breaks compatibility with subsequent layers like BatchNorm, causing runtime errors.
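The failure is easy to reproduce in plain PyTorch. The sketch below is a toy Conv2d → BatchNorm2d pair, not Torch-Pruning code: it slices two output channels out of the convolution by hand and leaves the BatchNorm untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
bn = nn.BatchNorm2d(8)

# Naive pruning: remove output channels 2 and 6 from the conv only
keep = [i for i in range(conv.out_channels) if i not in (2, 6)]
pruned_conv = nn.Conv2d(3, len(keep), kernel_size=3, padding=1)
with torch.no_grad():
    pruned_conv.weight.copy_(conv.weight[keep])
    pruned_conv.bias.copy_(conv.bias[keep])

x = torch.randn(1, 3, 16, 16)
out = pruned_conv(x)  # now has 6 channels
try:
    bn(out)  # BatchNorm2d still holds 8-channel parameters and statistics
    error = None
except RuntimeError as err:
    error = str(err)
print(out.shape, error)
```

The BatchNorm call fails because its parameters no longer match the incoming channel count; a dependency graph exists to shrink such coupled layers in the same step.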
Torch-Pruning solves this by building a dependency graph that captures all affected layers when a channel is pruned. The correct workflow is:
- Build the dependency graph using a sample input.
- Define a pruning group based on a target layer and indices.
- Validate and apply the pruning group.
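import torch
from torchvision.models import resnet18
import torch_pruning as tp
# torchvision >= 0.13 uses `weights=`; older releases use `pretrained=True`
model = resnet18(weights="IMAGENET1K_V1").eval()
# 1. Build the dependency graph by tracing the model with an example input
DG = tp.DependencyGraph().build_dependency(
    model, example_inputs=torch.randn(1, 3, 224, 224)
)
# 2. Collect every layer coupled to output channels [2, 6, 9] of the first conv layer
group = DG.get_pruning_group(
    model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9]
)
print(group.details())  # lists each dependent layer and the indices pruned in it
# 3. Check the group (e.g., that no layer drops to 0 channels), then prune it as a whole
if DG.check_pruning_group(group):
    group.prune()
# Save the full module object, since the architecture has changed
torch.save(model, 'pruned_model.pth')
# PyTorch >= 2.6 defaults to weights_only=True, which cannot unpickle a full model
loaded_model = torch.load('pruned_model.pth', weights_only=False)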
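Printing the pruning group (e.g., with print(group.details())) lists every dependent layer, including the convolution itself, its BatchNorm, the ReLU, skip connections, and downstream convolutions, all of which must be pruned together to keep tensor shapes consistent.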
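What a pruning group has to do can be sketched by hand for a minimal Conv → BN → Conv chain. The helper prune_chain is hypothetical plain PyTorch (bias-free convolutions assumed), not Torch-Pruning's implementation: it slices the same indices out of all three layers, and the final check confirms the pruned chain matches the original once those channels are silenced.

```python
import torch
import torch.nn as nn

def prune_chain(conv1, bn, conv2, idxs):
    """Hypothetical helper: drop channels `idxs` from conv1's outputs, bn, and conv2's inputs."""
    drop = set(idxs)
    keep = [i for i in range(conv1.out_channels) if i not in drop]
    new_conv1 = nn.Conv2d(conv1.in_channels, len(keep), conv1.kernel_size,
                          padding=conv1.padding, bias=False)
    new_bn = nn.BatchNorm2d(len(keep), eps=bn.eps)
    new_conv2 = nn.Conv2d(len(keep), conv2.out_channels, conv2.kernel_size,
                          padding=conv2.padding, bias=False)
    with torch.no_grad():
        new_conv1.weight.copy_(conv1.weight[keep])         # slice output channels
        new_bn.weight.copy_(bn.weight[keep])                # slice BN affine parameters
        new_bn.bias.copy_(bn.bias[keep])
        new_bn.running_mean.copy_(bn.running_mean[keep])    # and running statistics
        new_bn.running_var.copy_(bn.running_var[keep])
        new_conv2.weight.copy_(conv2.weight[:, keep])       # slice input channels
    return new_conv1, new_bn, new_conv2

torch.manual_seed(0)
conv1 = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
conv2 = nn.Conv2d(8, 4, 3, padding=1, bias=False)
idxs = [2, 6]
with torch.no_grad():
    bn.weight[idxs] = 0.0  # silence the channels we are about to remove
    bn.bias[idxs] = 0.0

original = nn.Sequential(conv1, bn, conv2).eval()
pruned = nn.Sequential(*prune_chain(conv1, bn, conv2, idxs)).eval()
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(original(x), pruned(x), atol=1e-5))
```

In a real ResNet the group also spans residual additions and every layer reading the same channels, which is why tracing the graph automatically matters.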
- Saving and Loading Pruned Models
Because pruning changes layer shapes, a plain state_dict cannot be loaded back into the original (unpruned) model definition. Two approaches are supported:
Option 1: Save Full Model Object
torch.save(model, 'model.pth')  # includes structure + weights
model = torch.load('model.pth', weights_only=False)  # required for full objects on PyTorch >= 2.6
Option 2: Use Torch-Pruning’s State Utilities
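# Save: weights together with the modified layer attributes
state = tp.state_dict(model)
torch.save(state, 'pruned.pth')
# Load into a fresh, unpruned model; its layers are reshaped to match the saved state
new_model = resnet18().eval()
tp.load_state_dict(new_model, state_dict=torch.load('pruned.pth', map_location='cpu'))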
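The shape problem behind Option 1 and Option 2 can be reproduced with two plain Conv2d layers standing in for an original and a pruned model. This is a toy setup, not Torch-Pruning code:

```python
import io
import torch
import torch.nn as nn

full = nn.Conv2d(3, 8, kernel_size=3)    # the original architecture
pruned = nn.Conv2d(3, 5, kernel_size=3)  # stand-in for a copy with 3 output channels pruned

# A pruned state_dict does not fit the original definition
try:
    full.load_state_dict(pruned.state_dict())
    mismatch = None
except RuntimeError as err:
    mismatch = str(err)  # reports a size mismatch for weight and bias

# Saving the whole module object keeps the new shapes
buf = io.BytesIO()
torch.save(pruned, buf)
buf.seek(0)
restored = torch.load(buf, weights_only=False)  # explicit flag needed on PyTorch >= 2.6
print(mismatch is not None, restored.out_channels)
```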
- Practical Pruning Examples
4.1 Global vs Uniform Pruning
Uniform pruning removes the same fraction of channels from every layer. It is simple, but it ignores that some layers are more redundant than others:
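example_inputs = torch.randn(1, 3, 224, 224)
pruner = tp.pruner.MagnitudePruner(
    model, example_inputs,
    importance=tp.importance.TaylorImportance(),  # needs loss.backward() before each pruner.step()
    pruning_ratio=0.5,  # called ch_sparsity in older releases
    iterative_steps=5,  # reach the target ratio over 5 pruner.step() calls
    ignored_layers=[model.fc]  # keep the classifier's output size intact
)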
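Global pruning ranks channels across the whole network and removes the least important ones wherever they are, so per-layer ratios adapt to each layer's redundancy. At the same overall ratio it is often more effective: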
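pruner = tp.pruner.MagnitudePruner(
    model, example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    pruning_ratio=0.5,
    global_pruning=True,  # rank channels network-wide instead of per layer
    ignored_layers=[model.fc]
)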
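The difference between the two strategies can be sketched in plain Python. The scores are made-up per-channel importances for two hypothetical layers; Torch-Pruning itself scores whole dependency groups, so this only illustrates the selection rule.

```python
# Made-up importance scores for two hypothetical layers (higher = more important)
scores = {
    "conv1": [0.9, 0.1, 0.2, 0.05],
    "conv2": [0.7, 0.6, 0.65, 0.75],
}

def uniform_prune(scores, ratio):
    """Remove the same fraction of channels, the least important ones, from every layer."""
    pruned = {}
    for layer, s in scores.items():
        n_prune = int(len(s) * ratio)
        order = sorted(range(len(s)), key=lambda i: s[i])
        pruned[layer] = sorted(order[:n_prune])
    return pruned

def global_prune(scores, ratio):
    """Rank all channels network-wide and remove the least important ones overall."""
    flat = sorted((v, layer, i) for layer, s in scores.items() for i, v in enumerate(s))
    n_prune = int(len(flat) * ratio)
    pruned = {layer: [] for layer in scores}
    for _, layer, i in flat[:n_prune]:
        pruned[layer].append(i)
    return {layer: sorted(idxs) for layer, idxs in pruned.items()}

print(uniform_prune(scores, 0.5))  # {'conv1': [1, 3], 'conv2': [1, 2]}
print(global_prune(scores, 0.5))   # {'conv1': [1, 2, 3], 'conv2': [1]}
```

With the same 50% budget, the global rule takes three channels from the weaker conv1 and only one from conv2.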
4.2 Advanced Pruning Options
- Layer-specific ratios: Use pruning_ratio_dict to assign custom sparsity per layer.
- Channel rounding: Set round_to=16 for GPU-friendly channel counts.
- Grouped convolutions: Specify channel_groups={layer: 8} for depthwise or grouped convs.
- Extra parameters: Pass learnable non-module parameters via unwrapped_parameters.
- Root modules: Limit pruning to specific layer types using root_module_types=[nn.Conv2d, nn.Linear].
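The effect of channel rounding on layer widths can be illustrated with a small hypothetical helper. It rounds the surviving channel count up to a multiple of round_to; Torch-Pruning's exact rounding rule may differ, so treat this as arithmetic only.

```python
import math
from typing import Optional

def kept_channels(total: int, ratio: float, round_to: Optional[int] = None) -> int:
    """Hypothetical helper: channels left after pruning `ratio` of `total`."""
    keep = total - int(total * ratio)
    if round_to:
        # Round the surviving width up to the next multiple of round_to, capped at total
        keep = min(total, math.ceil(keep / round_to) * round_to)
    return keep

print(kept_channels(64, 0.3))               # 45
print(kept_channels(64, 0.3, round_to=16))  # 48
```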
4.3 Sparse Training Integration
For methods like BNScalePruner, integrate regularization into training:
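# pruner: a tp.pruner.BNScalePruner constructed beforehand
for epoch in range(epochs):
    model.train()
    pruner.update_regularizer()  # initialize the regularizer for this epoch
    for data, target in loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        pruner.regularize(model)  # inject sparsity gradients after backward(), before step()
        optimizer.step()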
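Conceptually, BN-scale sparse training follows the Network Slimming idea: an L1 penalty on the BatchNorm scale factors (gamma) pushes unimportant channels toward zero so they can be pruned later. The plain-PyTorch sketch below uses a hypothetical helper, add_bn_l1_subgradient, not Torch-Pruning's internal code, and adds the L1 subgradient at the point where pruner.regularize(model) sits in the loop.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def add_bn_l1_subgradient(model: nn.Module, reg: float) -> None:
    """Hypothetical helper: add reg * sign(gamma) to every BatchNorm scale gradient."""
    for m in model.modules():
        if isinstance(m, BN_TYPES) and m.weight is not None and m.weight.grad is not None:
            m.weight.grad.add_(reg * torch.sign(m.weight.detach()))

model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
with torch.no_grad():
    model[1].weight.copy_(torch.tensor([0.5, -0.2, 0.0, 1.0]))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad()
loss = model(torch.randn(2, 3, 8, 8)).sum() * 0.0  # zero task loss isolates the penalty
loss.backward()
add_bn_l1_subgradient(model, reg=0.1)  # same place as pruner.regularize(model)
optimizer.step()
print(model[1].weight.data)  # nonzero scales moved 0.01 toward zero
```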
- Real-World Applications
5.1 Vision Transformers (via timm)
Handles attention heads by tracking num_heads and adjusting qkv projections:
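import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True).eval()
example_inputs = torch.randn(1, 3, 224, 224)
num_heads = {}
for m in model.modules():
    if hasattr(m, 'num_heads') and hasattr(m, 'qkv'):
        num_heads[m.qkv] = m.num_heads  # heads carried by each fused qkv projection
pruner = tp.pruner.MetaPruner(
    model, example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    num_heads=num_heads,
    ignored_layers=[model.head]  # keep the classifier output size
)
pruner.step()
# Update head count and head dimension post-pruning
for m in model.modules():
    if hasattr(m, 'num_heads') and hasattr(m, 'qkv'):
        m.num_heads = pruner.num_heads[m.qkv]
        m.head_dim = m.qkv.out_features // (3 * m.num_heads)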
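The head update relies on the fused qkv layout, whose output width is 3 × num_heads × head_dim. A small arithmetic sketch with a hypothetical helper shows how the attributes follow from the pruned width; the 2304-wide qkv matches ViT-B/16 (768-dim embeddings, 12 heads).

```python
def attention_shapes(qkv_out_features: int, num_heads: int) -> dict:
    """Hypothetical helper: derive head_dim from a qkv width of 3 * num_heads * head_dim."""
    head_dim, rem = divmod(qkv_out_features, 3 * num_heads)
    if rem:
        raise ValueError("qkv width must split evenly into q, k, v and num_heads")
    return {"num_heads": num_heads, "head_dim": head_dim}

print(attention_shapes(2304, 12))  # before pruning: head_dim 64
print(attention_shapes(1152, 12))  # head dims halved, all 12 heads kept: head_dim 32
print(attention_shapes(1536, 8))   # 4 heads removed, head_dim unchanged: 64
```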
5.2 LLM Pruning (e.g., LLaMA)
Prunes attention heads and MLP intermediate dimensions while updating config attributes:
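num_heads = {}
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        num_heads[m.q_proj] = model.config.num_attention_heads
        # ... similarly for k_proj, v_proj
pruner = tp.pruner.MagnitudePruner(
    model, inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    output_transform=lambda out: out.logits,  # HF models return an output object, not a tensor
    num_heads=num_heads,
    prune_num_heads=True,
    prune_head_dims=False,  # keep head_dim fixed, e.g. for rotary position embeddings
    head_pruning_ratio=0.3,
    ignored_layers=[model.lm_head]
)
pruner.step()
# Update model config from the pruned layer shapes
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        model.config.num_attention_heads = m.q_proj.out_features // m.head_dim
    elif name.endswith("mlp"):
        model.config.intermediate_size = m.gate_proj.out_features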
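A quick arithmetic check helps when updating the config by hand. The helpers below are hypothetical and assume head_dim stays fixed. For grouped-query attention models, where num_key_value_heads is smaller than num_attention_heads, they also check that the pruned query-head count still divides evenly among the KV heads, since each KV head serves a fixed number of query heads.

```python
def pruned_head_count(num_heads: int, head_pruning_ratio: float) -> int:
    """Hypothetical helper: heads left when a fraction of heads is removed (floored)."""
    return int(num_heads * (1 - head_pruning_ratio))

def kv_groups(num_heads: int, num_kv_heads: int) -> int:
    """Hypothetical helper: query heads per KV head; requires exact division."""
    if num_heads % num_kv_heads:
        raise ValueError(f"{num_heads} query heads cannot be grouped over {num_kv_heads} KV heads")
    return num_heads // num_kv_heads

heads = pruned_head_count(32, 0.3)
print(heads, heads * 128)  # 22 heads -> q_proj width 2816 with head_dim 128
print(kv_groups(32, 8))    # 4 query heads per KV head before pruning
```

With 8 KV heads, 22 query heads would not divide evenly, so such a model would need a ratio that leaves a multiple of 8 query heads (e.g., 24).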
5.3 YOLO Models
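Supports YOLOv5/v7/v8. YOLOv8's C2f blocks split one convolution's output into two branches with a split operation, and such modules may need custom adaptation before they become prunable, for example rewriting the split as two separate convolutions.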
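The idea behind such a rewrite can be checked in plain PyTorch: a convolution whose output is chunked into two halves computes the same tensors as two smaller convolutions holding the corresponding weight slices, and the two-conv form gives each branch its own prunable layer. This toy check is not code from Torch-Pruning or YOLO and omits the BatchNorm and activation that YOLO's Conv blocks include.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 16, 8, 8)

# Original pattern: one conv whose output is split in half along channels
fused = nn.Conv2d(16, 32, kernel_size=1, bias=False)
a, b = fused(x).chunk(2, dim=1)

# Rewritten pattern: two convs, each owning one half of the fused weights
cv_a = nn.Conv2d(16, 16, kernel_size=1, bias=False)
cv_b = nn.Conv2d(16, 16, kernel_size=1, bias=False)
with torch.no_grad():
    cv_a.weight.copy_(fused.weight[:16])
    cv_b.weight.copy_(fused.weight[16:])

print(torch.allclose(a, cv_a(x), atol=1e-5), torch.allclose(b, cv_b(x), atol=1e-5))
```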
- Importance Criteria and Performance Impact
Available importance metrics include:
- MagnitudeImportance
- TaylorImportance (often best for accuracy retention)
- GroupNormImportance
- BNScaleImportance
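The magnitude and Taylor criteria can be sketched for a single convolution in plain PyTorch. These are illustrative per-layer formulas (filter L2 norm, and the first-order Taylor term |w · ∂L/∂w| summed per filter) with hypothetical helper names, not Torch-Pruning's implementations, which score entire dependency groups.

```python
import torch
import torch.nn as nn

def magnitude_scores(conv: nn.Conv2d, p: int = 2) -> torch.Tensor:
    """Hypothetical helper: Lp norm of each output filter, one score per channel."""
    return torch.linalg.vector_norm(conv.weight.detach().flatten(1), ord=p, dim=1)

def taylor_scores(conv: nn.Conv2d) -> torch.Tensor:
    """Hypothetical helper: |w * dL/dw| summed per filter; needs a prior backward()."""
    return (conv.weight.detach() * conv.weight.grad).abs().flatten(1).sum(dim=1)

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, bias=False)
loss = conv(torch.randn(4, 3, 16, 16)).pow(2).mean()
loss.backward()
with torch.no_grad():
    conv.weight[5].zero_()  # mimic a dead filter

mag, tay = magnitude_scores(conv), taylor_scores(conv)
print(mag.argmin().item(), tay.argmin().item())  # the dead filter ranks lowest under both
```

Magnitude scoring needs only the weights, while the Taylor score also needs gradients, which is why TaylorImportance requires a backward pass before each pruning step.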
Empirical studies (e.g., DepGraph, LLM-Pruner) report that pruning up to 50% of channels often incurs only minimal accuracy loss, especially when the pruned model is fine-tuned afterwards. In some cases, a 6× speedup has been achieved without performance degradation.