Training large-scale neural networks often exceeds the computational and memory capacity of a single device, necessitating distributed training strategies. Among these, model parallelism (MP) plays a crucial role by partitioning the model itself across multiple devices. Within MP, pipeline parallelism (PP) stands out as an effective technique that distributes different layers of a model across devices and overlaps computation through micro-batching.
GPipe: Synchronous Pipeline Execution
GPipe implements pipeline parallelism by splitting both the model and the input batch into segments. The model is divided into consecutive stages, each assigned to a separate device, and each input batch is further split into micro-batches. This enables pipelined execution: while one device runs the forward pass of micro-batch j, the next device can run the forward pass of micro-batch j−1, and the backward passes are pipelined the same way from the last stage back to the first.
This approach follows an F-then-B (Forward-then-Backward) schedule: all micro-batches complete their forward passes before any backward pass begins, and the gradients of all micro-batches are accumulated and applied in a single synchronous update per mini-batch. This keeps gradients consistent, but it leaves devices idle while the pipeline fills and drains; these idle periods are known as bubbles.
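To see where the bubbles come from, the following sketch simulates an idealized F-then-B timeline in plain Python. It assumes every forward or backward step takes one time unit on every stage and that communication is free; gpipe_schedule and bubble_fraction are illustrative names, not part of any library. With K stages and M micro-batches, the idle fraction works out to (K − 1)/(M + K − 1), so using many more micro-batches than stages shrinks the bubble.

```python
def gpipe_schedule(num_stages, num_micro):
    """Idealized F-then-B timeline: {stage: [(time_step, op, micro_batch), ...]}.

    Assumes each forward or backward step takes one time unit on every stage
    and that communication is free.
    """
    schedule = {s: [] for s in range(num_stages)}
    # Forward phase: stage s runs micro-batch j at time s + j (fill, then steady).
    for s in range(num_stages):
        for j in range(num_micro):
            schedule[s].append((s + j, "F", j))
    # Backward phase begins once the last stage has finished every forward pass
    # and flows from the last stage back to stage 0 (steady, then drain).
    start = num_stages + num_micro - 1
    for s in range(num_stages):
        for j in range(num_micro):
            schedule[s].append((start + (num_stages - 1 - s) + j, "B", j))
    return schedule


def bubble_fraction(num_stages, num_micro):
    """Fraction of device time slots left idle by the schedule."""
    schedule = gpipe_schedule(num_stages, num_micro)
    makespan = 1 + max(t for ops in schedule.values() for t, _, _ in ops)
    busy = sum(len(ops) for ops in schedule.values())
    return 1 - busy / (num_stages * makespan)


for m in (4, 8, 32):
    print(f"K=4, M={m}: bubble fraction {bubble_fraction(4, m):.3f}")
```

With four stages this prints 0.429, 0.273 and 0.086 for M = 4, 8 and 32, matching (K − 1)/(M + K − 1).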
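To reduce memory pressure, GPipe supports activation checkpointing (which the GPipe paper calls re-materialization): during the forward pass each stage stores only the activations at its partition boundary, and the intermediate activations inside the stage are recomputed during the backward pass. This significantly reduces per-device activation memory at the cost of extra computation.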
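Recomputation can be illustrated without a framework. The toy below treats a network as a hypothetical chain of scalar layers with hand-written derivatives (LAYERS, grad_full and grad_checkpointed are illustrative names). The checkpointed version keeps only each segment's input during the forward pass and re-runs that segment's forward computation during backward, trading extra compute for fewer stored activations. In PyTorch the same idea is exposed through torch.utils.checkpoint.

```python
import math

# Hypothetical toy network: a chain of scalar layers, each given as
# (function, derivative) so the chain rule can be applied by hand.
LAYERS = [
    (math.tanh, lambda x: 1 - math.tanh(x) ** 2),
    (math.sin, math.cos),
    (lambda x: 3 * x, lambda x: 3.0),
    (math.tanh, lambda x: 1 - math.tanh(x) ** 2),
]


def grad_full(x):
    """Store every layer input in the forward pass, then backpropagate."""
    stored = []
    for f, _ in LAYERS:
        stored.append(x)
        x = f(x)
    grad = 1.0
    for (_, df), inp in zip(reversed(LAYERS), reversed(stored)):
        grad *= df(inp)
    return grad, len(stored)


def grad_checkpointed(x, segment):
    """Keep only each segment's input; recompute the rest during backward."""
    boundaries = []
    for start in range(0, len(LAYERS), segment):
        boundaries.append(x)  # the only activation kept for this segment
        for f, _ in LAYERS[start:start + segment]:
            x = f(x)
    grad = 1.0
    for idx in reversed(range(len(boundaries))):
        seg_layers = LAYERS[idx * segment:(idx + 1) * segment]
        # Re-run this segment's forward pass to rebuild its layer inputs.
        inputs, h = [], boundaries[idx]
        for f, _ in seg_layers:
            inputs.append(h)
            h = f(h)
        for (_, df), inp in zip(reversed(seg_layers), reversed(inputs)):
            grad *= df(inp)
    return grad, len(boundaries)


print(grad_full(0.3))
print(grad_checkpointed(0.3, segment=2))
```

Both calls return the same gradient, but with segment=2 the checkpointed forward pass keeps 2 activations instead of 4.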
PipeDream: Asynchronous 1F1B Scheduling
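PipeDream improves hardware utilization with a 1F1B (One Forward, One Backward) schedule. After a start-up phase that admits enough micro-batches to fill the pipeline, each stage alternates between the forward pass of one micro-batch and the backward pass of the oldest outstanding micro-batch. This interleaving reduces bubbles, keeps devices busy, and bounds the number of micro-batches whose activations each stage must hold at once.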
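The sketch below lists the order of operations each stage executes under a 1F1B schedule. It assumes a fixed set of M micro-batches that is fully drained at the end (the original PipeDream instead keeps the steady state running across mini-batches), and the function names are illustrative. Stage s of K first runs K − s − 1 warm-up forward passes, then alternates, then drains its backward passes, so it never holds more than K − s micro-batches' activations, versus all M under F-then-B.

```python
def one_f_one_b(stage, num_stages, num_micro):
    """Order of ("F"|"B", micro_batch) ops on one stage under a 1F1B schedule."""
    # Warm-up: earlier stages admit extra forwards so the pipeline fills.
    warmup = min(num_stages - stage - 1, num_micro)
    ops = [("F", j) for j in range(warmup)]
    next_f, next_b = warmup, 0
    # Steady state: one forward, then the backward of the oldest outstanding micro-batch.
    while next_f < num_micro:
        ops.append(("F", next_f))
        next_f += 1
        ops.append(("B", next_b))
        next_b += 1
    # Cool-down: drain the remaining backward passes.
    while next_b < num_micro:
        ops.append(("B", next_b))
        next_b += 1
    return ops


def max_in_flight(ops):
    """Peak number of micro-batches whose activations the stage must hold."""
    live = peak = 0
    for op, _ in ops:
        live += 1 if op == "F" else -1
        peak = max(peak, live)
    return peak


for s in range(4):
    ops = one_f_one_b(s, num_stages=4, num_micro=8)
    print(s, " ".join(f"{op}{mb}" for op, mb in ops), "| peak:", max_in_flight(ops))
```

Stage 0 prints F0 F1 F2 F3 B0 F4 B1 ... and the peaks for stages 0 to 3 are 4, 3, 2 and 1.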
However, PipeDream's asynchronous 1F1B introduces weight staleness: because weights are updated without waiting for the pipeline to drain, different micro-batches may be processed with slightly different versions of the model weights. While this can affect convergence, empirical results show acceptable stability in practice.
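PipeDream limits the damage with weight stashing: each stage saves the weight version a micro-batch used in its forward pass and reuses it for that micro-batch's backward pass, while updates still go to the latest weights. The toy stage below sketches only that bookkeeping; StashingStage is a hypothetical name, and a single scalar weight with plain SGD is assumed.

```python
class StashingStage:
    """Toy single-weight stage (y = w * x) that stashes weights per micro-batch."""

    def __init__(self, weight, lr=0.1):
        self.weight = weight  # latest weight version
        self.version = 0
        self.lr = lr
        self.stash = {}  # micro_batch -> (version, weight, input) from its forward

    def forward(self, mb, x):
        # Forward uses the newest weights and records exactly what it used.
        self.stash[mb] = (self.version, self.weight, x)
        return self.weight * x

    def backward(self, mb, grad_out):
        # Backward reuses the stashed weights, even if newer ones exist by now.
        version, w, x = self.stash.pop(mb)
        grad_in = grad_out * w  # gradient sent to the previous stage
        grad_w = grad_out * x   # gradient for this stage's weight
        # Update the latest weights immediately (asynchronous, no pipeline flush).
        self.weight -= self.lr * grad_w
        self.version += 1
        return version, grad_in


stage = StashingStage(weight=2.0)
stage.forward(0, 1.0)
stage.forward(1, 1.0)          # in flight before micro-batch 0's update lands
print(stage.backward(0, 1.0))  # (0, 2.0); the live weight becomes 1.9
print(stage.backward(1, 1.0))  # (0, 2.0): uses the stashed 2.0, not 1.9
```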
An extension, 1F1B-RR (Round-Robin), handles stages with data parallelism by assigning micro-batches in a round-robin fashion across replicas within a stage, ensuring that forward and backward computations for a given micro-batch occur on the same worker.
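Because a round-robin assignment depends only on the micro-batch ID, a forward pass and its later backward pass are routed to the same replica without extra coordination. A minimal sketch, assuming replicas are numbered 0 to R − 1 and a rule such as micro-batch ID modulo R; the helper names are hypothetical.

```python
def replica_for(micro_batch, num_replicas):
    # Deterministic round-robin: micro-batch j always maps to replica j mod R.
    return micro_batch % num_replicas


def route(ops, num_replicas):
    """Group a stage's (op, micro_batch) work items by the replica that runs them."""
    per_replica = {r: [] for r in range(num_replicas)}
    for op, mb in ops:
        per_replica[replica_for(mb, num_replicas)].append((op, mb))
    return per_replica


work = [("F", j) for j in range(6)] + [("B", j) for j in range(6)]
for replica, items in route(work, num_replicas=2).items():
    print(replica, items)
```

Replica 0 receives F and B for micro-batches 0, 2 and 4; replica 1 receives both passes for 1, 3 and 5.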
Implementation Approaches
Manual Pipeline with Asynchronous CUDA Execution
A basic implementation in PyTorch manually splits inputs and pipelines computation across GPUs:
import torch
import torch.nn as nn


class PipelineResNet50(nn.Module):
    def __init__(self, micro_batch_size=20):
        super().__init__()
        self.micro_batch_size = micro_batch_size
        # Layer definitions are elided ("...") here; fill in each stage's layers.
        self.stage0 = nn.Sequential(...).to('cuda:0')
        self.stage1 = nn.Sequential(...).to('cuda:1')
        self.classifier = nn.Linear(...).to('cuda:1')

    def forward(self, x):
        chunks = list(x.split(self.micro_batch_size, dim=0))
        if not chunks:
            return torch.empty(0)
        # Warm-up: first chunk through stage0, then copy its output to cuda:1
        prev_out = self.stage0(chunks[0]).to('cuda:1')
        outputs = []
        # Pipeline loop: CUDA kernels launch asynchronously, so stage1 on cuda:1
        # can run while stage0 on cuda:0 processes the next chunk
        for chunk in chunks[1:]:
            hidden = self.stage1(prev_out)
            outputs.append(self.classifier(hidden.flatten(1)))
            prev_out = self.stage0(chunk).to('cuda:1')
        # Drain: finish the last chunk on cuda:1
        hidden = self.stage1(prev_out)
        outputs.append(self.classifier(hidden.flatten(1)))
        return torch.cat(outputs, dim=0)
This leverages PyTorch’s asynchronous CUDA execution but uses only default streams, limiting overlap between computation and inter-GPU transfers.
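The fill, overlap and drain pattern of the loop above can be reproduced on a CPU, which makes it visible without GPUs. In this sketch two Python threads stand in for cuda:0 and cuda:1 and a bounded queue stands in for the device-to-device copy; run_two_stage_pipeline and slow are hypothetical helpers, and time.sleep is assumed as a stand-in for kernel execution (it releases the GIL, so the threads genuinely overlap).

```python
import queue
import threading
import time


def run_two_stage_pipeline(chunks, stage0, stage1):
    """Run stage0 and stage1 concurrently, handing results over through a queue."""
    handoff = queue.Queue(maxsize=2)  # bounded buffer between the two "devices"
    done = object()                   # sentinel marking the end of the stream
    outputs = []

    def device0():
        for chunk in chunks:
            handoff.put(stage0(chunk))  # plays the role of stage0(chunk).to('cuda:1')
        handoff.put(done)

    def device1():
        while (item := handoff.get()) is not done:
            outputs.append(stage1(item))

    workers = [threading.Thread(target=device0), threading.Thread(target=device1)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return outputs


def slow(fn, seconds=0.05):
    """Wrap fn so each call sleeps first, standing in for GPU work."""
    def wrapped(x):
        time.sleep(seconds)
        return fn(x)
    return wrapped


t0 = time.perf_counter()
result = run_two_stage_pipeline(range(8), slow(lambda x: x + 1), slow(lambda x: 10 * x))
print(result, f"{time.perf_counter() - t0:.2f}s")
```

With 8 chunks and 0.05 s per stage call, the run takes roughly (8 + 1) × 0.05 ≈ 0.45 s rather than the 0.8 s a strictly sequential loop would need.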
RPC-Based Distributed Pipeline
Using PyTorch’s RPC framework enables true multi-process or multi-machine pipelines. Each model segment is hosted on a remote worker, and communication is managed via RRef (Remote References):
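import threading

import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch.distributed.rpc import RRef


class Shard(nn.Module):
    def __init__(self, layers, device):
        super().__init__()
        self.device = device
        self.net = layers.to(device)
        # RPC requests can arrive concurrently on the worker's thread pool
        self._lock = threading.Lock()

    def forward(self, x_rref):
        # Fetch the micro-batch from its owner and move it to this shard's device
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.net(x)
        # Return a CPU tensor so it can be sent back over RPC
        return out.cpu()


class DistributedPipeline(nn.Module):
    def __init__(self, workers, stage0, stage1, micro_batch_size):
        super().__init__()
        self.micro_batch_size = micro_batch_size
        # Construct each Shard on its remote worker; keep RRefs to them locally
        self.shard0 = rpc.remote(workers[0], Shard, args=(stage0, "cuda:0"))
        self.shard1 = rpc.remote(workers[1], Shard, args=(stage1, "cuda:1"))

    def forward(self, x):
        futures = []
        for micro_batch in x.split(self.micro_batch_size, dim=0):
            # shard0 returns an RRef immediately; shard1 consumes it asynchronously,
            # so successive micro-batches overlap across the two workers
            out0 = self.shard0.remote().forward(RRef(micro_batch))
            futures.append(self.shard1.rpc_async().forward(out0))
        return torch.cat(torch.futures.wait_all(futures))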
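For training, each iteration runs its forward and backward passes inside a dist_autograd context, where dist_autograd.backward propagates gradients across the RPC boundaries, and a DistributedOptimizer, given RRefs to every shard's parameters, applies the updates on the workers that own those parameters.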
Using PyTorch’s Built-in Pipe
PyTorch provides a high-level Pipe API that automates much of the pipeline logic:
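import os

import torch
import torch.nn as nn
from torch.distributed.pipeline.sync import Pipe

# Pipe requires the RPC framework to be initialized, even in a single process
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
torch.distributed.rpc.init_rpc("worker", rank=0, world_size=1)

# Sequential model whose layers already live on different devices
model = nn.Sequential(
    nn.Linear(16, 8).to('cuda:0'),
    nn.Linear(8, 4).to('cuda:1')
)

# Wrap with Pipe: each input batch is split into 8 micro-batches
pipeline_model = Pipe(model, chunks=8)

# Forward pass; Pipe returns an RRef, so fetch the tensor with to_here()
input_tensor = torch.randn(16, 16).to('cuda:0')
output = pipeline_model(input_tensor).to_here()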
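The Pipe class, based on torchgpipe, applies GPipe-style synchronous scheduling: it splits each input into the number of micro-batches given by chunks, moves activations between the devices holding consecutive partitions, and can checkpoint activations. Its forward pass returns an RRef (hence the to_here() call) so that pipelining across hosts could be supported later; Pipe itself only supports pipelines within a single node. The torch.distributed.pipeline.sync package has since been deprecated, and recent PyTorch releases provide torch.distributed.pipelining instead.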
In practice, pipeline parallelism is rarely used in isolation. Large-model training systems typically combine it with tensor parallelism, which splits individual layers across devices (usually within a node), and data parallelism, which replicates the resulting pipeline across device groups, to maximize throughput and scalability.
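One common way to lay out this combination, often called 3D parallelism, is to factor the number of GPUs as DP × PP × TP and give tensor-parallel partners adjacent ranks, since they communicate on every layer and benefit most from fast intra-node links. The sketch below assumes that ordering (tensor index innermost, then pipeline stage, then data-parallel replica); the helper names and the 16-GPU configuration are illustrative, and real frameworks may order the dimensions differently.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (dp, pp, tp) indices, tensor-parallel index innermost."""
    if not 0 <= rank < tp * pp * dp:
        raise ValueError("rank out of range")
    return rank // (tp * pp), (rank // tp) % pp, rank % tp


def pipeline_group(rank, tp, pp, dp):
    """Ranks forming the pipeline that `rank` belongs to (same dp and tp index)."""
    dp_idx, _, tp_idx = rank_to_coords(rank, tp, pp, dp)
    return [dp_idx * tp * pp + stage * tp + tp_idx for stage in range(pp)]


def data_group(rank, tp, pp, dp):
    """Ranks holding replicas of the same shard (same pp and tp index)."""
    _, pp_idx, tp_idx = rank_to_coords(rank, tp, pp, dp)
    return [d * tp * pp + pp_idx * tp + tp_idx for d in range(dp)]


def tensor_group(rank, tp):
    """Ranks that jointly hold the shards of the same layers as `rank`."""
    base = rank - rank % tp
    return list(range(base, base + tp))


# Hypothetical 16-GPU job: 2-way tensor, 4-way pipeline, 2-way data parallelism.
TP, PP, DP = 2, 4, 2
print(rank_to_coords(9, TP, PP, DP))  # (1, 0, 1)
print(pipeline_group(9, TP, PP, DP))  # [9, 11, 13, 15]
print(data_group(9, TP, PP, DP))      # [1, 9]
print(tensor_group(9, TP))            # [8, 9]
```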