Pipeline Parallelism in Large-Scale AI Model Training

Training large-scale neural networks often exceeds the computational and memory capacity of a single device, necessitating distributed training strategies. Among these, model parallelism (MP) partitions the model itself across multiple devices. Within MP, pipeline parallelism (PP) assigns consecutive groups of layers to different devices and overlaps their computation by splitting each batch into micro-batches.

GPipe: Synchronous Pipeline Execution

GPipe implements pipeline parallelism by splitting both the model and the input batch. The model is divided into consecutive stages, each assigned to a separate device, and each mini-batch is further split into micro-batches so that stages can work concurrently: while stage k processes the forward pass of micro-batch j, stage k+1 can process the forward pass of micro-batch j−1.
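Both kinds of splitting can be sketched in plain Python. This is a minimal illustration that balances stages by layer count; GPipe itself balances partitions by estimated computation cost, and the helper names partition_layers and split_micro_batches are hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def partition_layers(layers: Sequence[T], num_stages: int) -> List[List[T]]:
    """Split layers into num_stages consecutive groups of near-equal size."""
    if not 1 <= num_stages <= len(layers):
        raise ValueError("need 1 <= num_stages <= len(layers)")
    base, extra = divmod(len(layers), num_stages)
    stages, start = [], 0
    for s in range(num_stages):
        # The first `extra` stages receive one additional layer
        end = start + base + (1 if s < extra else 0)
        stages.append(list(layers[start:end]))
        start = end
    return stages


def split_micro_batches(batch: Sequence[T], micro_batch_size: int) -> List[List[T]]:
    """Split a mini-batch into micro-batches; the last one may be smaller."""
    if micro_batch_size < 1:
        raise ValueError("micro_batch_size must be positive")
    return [list(batch[i:i + micro_batch_size])
            for i in range(0, len(batch), micro_batch_size)]


if __name__ == "__main__":
    print(partition_layers([f"layer{i}" for i in range(7)], num_stages=3))
    print(split_micro_batches(list(range(10)), micro_batch_size=4))
```

With seven layers and three stages the groups have sizes 3, 2 and 2, and a batch of 10 samples with micro_batch_size=4 yields micro-batches of 4, 4 and 2 samples.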

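This approach follows an F-then-B (forward-then-backward) schedule: all micro-batches complete their forward passes before any backward pass begins, and the gradients accumulated over the micro-batches are applied in one synchronous update at the end of the mini-batch. Every micro-batch is therefore computed with the same weights, but the schedule leaves idle periods, known as bubbles, while the pipeline fills at the start of each phase and drains at its end.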
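To see where the bubbles come from and why more micro-batches help, the sketch below builds an idealized F-then-B timeline and counts idle slots. It assumes every forward or backward step takes one time unit on every stage; the function names are hypothetical, and processing backward passes in reverse micro-batch order is a choice made here that does not change the count.

```python
from fractions import Fraction
from typing import List, Optional, Tuple

Slot = Optional[Tuple[str, int]]  # ("F" or "B", micro-batch id), or None if idle


def f_then_b_timeline(num_stages: int, num_micro: int) -> List[List[Slot]]:
    """timeline[t][s]: what stage s does at time step t under an F-then-B schedule."""
    p, m = num_stages, num_micro
    fill = m + p - 1  # steps to push all micro-batches through the pipeline once
    timeline: List[List[Slot]] = [[None] * p for _ in range(2 * fill)]
    for s in range(p):
        for j in range(m):
            # Forward of micro-batch j reaches stage s at step s + j
            timeline[s + j][s] = ("F", j)
            # Backward begins at the last stage once every forward is done
            # and flows back toward stage 0
            timeline[fill + (p - 1 - s) + (m - 1 - j)][s] = ("B", j)
    return timeline


def bubble_fraction(num_stages: int, num_micro: int) -> Fraction:
    """Share of stage-time slots that are idle."""
    timeline = f_then_b_timeline(num_stages, num_micro)
    idle = sum(slot is None for row in timeline for slot in row)
    return Fraction(idle, len(timeline) * num_stages)


if __name__ == "__main__":
    for m in (1, 4, 16, 64):
        frac = bubble_fraction(4, m)
        # Agrees with the closed form (p - 1) / (m + p - 1) for p = 4
        assert frac == Fraction(3, m + 3)
        print(f"p=4, m={m}: bubble fraction {float(frac):.3f}")
```

Under these assumptions the idle share is (p - 1)/(m + p - 1) for p stages and m micro-batches, so raising the number of micro-batches amortizes the fill and drain phases.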

To reduce memory pressure, GPipe uses activation checkpointing (re-materialization): each stage stores only the activations at its partition boundary and recomputes the intermediate activations during the backward pass, trading extra computation for significantly lower per-device memory.
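The trade-off can be shown on a toy chain of scalar tanh layers, where a "segment" plays the role of a pipeline stage. This pure-Python sketch uses hypothetical helper names and is not GPipe's implementation; in PyTorch the same technique is available as torch.utils.checkpoint.

```python
import math
from typing import List, Tuple


def layer_forward(w: float, x: float) -> float:
    """Toy layer: x_out = tanh(w * x_in)."""
    return math.tanh(w * x)


def layer_backward(w: float, x_in: float, x_out: float, grad_out: float) -> Tuple[float, float]:
    """Return (dL/dw, dL/dx_in) for x_out = tanh(w * x_in)."""
    local = grad_out * (1.0 - x_out * x_out)  # tanh'(z) = 1 - tanh(z)^2
    return local * x_in, local * w


def grads_store_all(ws: List[float], x0: float) -> Tuple[List[float], int]:
    """Plain backprop keeping every activation; the loss is the final activation."""
    acts = [x0]
    for w in ws:
        acts.append(layer_forward(w, acts[-1]))
    grads, grad_x = [0.0] * len(ws), 1.0
    for i in reversed(range(len(ws))):
        grads[i], grad_x = layer_backward(ws[i], acts[i], acts[i + 1], grad_x)
    return grads, len(acts)


def grads_checkpointed(ws: List[float], x0: float, segment: int) -> Tuple[List[float], int]:
    """Keep only each segment's input during the forward pass; recompute the rest."""
    checkpoints = {}
    x = x0
    for i, w in enumerate(ws):
        if i % segment == 0:
            checkpoints[i] = x  # boundary activation kept for the backward pass
        x = layer_forward(w, x)
    grads, grad_x = [0.0] * len(ws), 1.0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(ws))
        # Re-run this segment's forward pass from its checkpoint
        acts = [checkpoints[start]]
        for i in range(start, end):
            acts.append(layer_forward(ws[i], acts[-1]))
        for i in reversed(range(start, end)):
            grads[i], grad_x = layer_backward(
                ws[i], acts[i - start], acts[i - start + 1], grad_x)
    return grads, len(checkpoints)


if __name__ == "__main__":
    ws = [0.5 + 0.05 * i for i in range(16)]
    full, kept_full = grads_store_all(ws, 0.3)
    ckpt, kept_ckpt = grads_checkpointed(ws, 0.3, segment=4)
    print(full == ckpt, kept_full, kept_ckpt)  # True 17 4
```

The gradients are identical because recomputation repeats exactly the same operations; the cost is one extra forward pass per segment, plus holding one segment's activations at a time during the backward pass.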

PipeDream: Asynchronous 1F1B Scheduling

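PipeDream improves hardware utilization with a 1F1B (one forward, one backward) schedule. Once a short start-up phase has filled the pipeline, each stage alternates between the forward pass of a new micro-batch and the backward pass of an earlier one, so backward work starts as soon as its gradients arrive from the next stage instead of waiting for every forward pass to finish. This interleaving reduces bubbles and keeps devices busy.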
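The per-stage order of operations that 1F1B produces can be generated with a short function. This sketch adds a warm-up and a drain phase, as in flushed 1F1B variants such as PipeDream-Flush and Megatron-LM's non-interleaved schedule; it does not model PipeDream's asynchronous weight updates, and the function names are hypothetical.

```python
from typing import List, Tuple

Op = Tuple[str, int]  # ("F" or "B", micro-batch id)


def one_f_one_b_order(stage: int, num_stages: int, num_micro: int) -> List[Op]:
    """Operation order on one stage (0-indexed) under 1F1B with warm-up and drain."""
    # Earlier stages run ahead so that later stages have work: warm-up forwards
    warmup = min(num_stages - stage - 1, num_micro)
    ops: List[Op] = [("F", j) for j in range(warmup)]
    next_backward = 0
    # Steady state: one forward, then one backward
    for j in range(warmup, num_micro):
        ops.append(("F", j))
        ops.append(("B", next_backward))
        next_backward += 1
    # Drain: backward passes for micro-batches still in flight
    ops.extend(("B", j) for j in range(next_backward, num_micro))
    return ops


def peak_in_flight(ops: List[Op]) -> int:
    """Max number of micro-batches whose activations are held (forward done, backward not)."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


if __name__ == "__main__":
    for s in range(4):
        order = one_f_one_b_order(s, num_stages=4, num_micro=8)
        print(s, peak_in_flight(order), " ".join(f"{k}{j}" for k, j in order))
```

A stage never holds activations for more than p micro-batches (stage 0 peaks at 4 here, the last stage at 1), whereas under F-then-B every stage holds all m micro-batches until the backward phase begins.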

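However, because PipeDream applies weight updates asynchronously, 1F1B introduces weight staleness: different micro-batches may be computed with different versions of the weights, and without care a micro-batch's backward pass could use newer weights than its forward pass. PipeDream prevents the latter with weight stashing, which keeps the weight version used in a micro-batch's forward pass for its backward pass. Staleness can still affect convergence, but the PipeDream evaluation reports that models reach their target accuracy in practice.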
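Weight stashing can be illustrated with a single-weight toy stage. The class, its y = w * x layer, and its update rule are hypothetical; a real stage stashes full parameter tensors.

```python
from typing import Dict, Tuple


class StashingStage:
    """Toy stage computing y = w * x with one scalar weight, illustrating weight stashing."""

    def __init__(self, weight: float, lr: float = 0.1):
        self.weight = weight
        self.version = 0
        self.lr = lr
        # micro-batch id -> (weight version, weight value) used in its forward pass
        self.stash: Dict[int, Tuple[int, float]] = {}

    def forward(self, mb: int, x: float) -> float:
        self.stash[mb] = (self.version, self.weight)
        return self.weight * x

    def backward(self, mb: int, x: float, grad_out: float) -> Tuple[int, float]:
        """Return (weight version used, dL/dx) and apply this micro-batch's update."""
        version, w_used = self.stash.pop(mb)
        grad_x = grad_out * w_used  # stashed weight, not the current one
        grad_w = grad_out * x
        # Asynchronous-style update: later forward passes see the new weight
        self.weight -= self.lr * grad_w
        self.version += 1
        return version, grad_x


if __name__ == "__main__":
    stage = StashingStage(weight=1.0, lr=0.5)
    stage.forward(0, 2.0)
    stage.forward(1, 2.0)               # both forwards use version 0 (w = 1.0)
    print(stage.backward(0, 2.0, 1.0))  # (0, 1.0); the weight becomes 0.0
    print(stage.backward(1, 2.0, 1.0))  # (0, 1.0) thanks to the stash, not 0.0
```

Stashing only keeps the two passes of one micro-batch consistent within a stage; it does not make all micro-batches, or all stages, use the same weight version.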

An extension, 1F1B-RR (Round-Robin), handles stages with data parallelism by assigning micro-batches in a round-robin fashion across replicas within a stage, ensuring that forward and backward computations for a given micro-batch occur on the same worker.
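The routing rule behind 1F1B-RR can be sketched as a deterministic function of the micro-batch id, applied to forward and backward work alike. The function names and the modulo rule shown here are an illustrative assumption of how round-robin assignment can be expressed.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Op = Tuple[str, int]  # ("F" or "B", micro-batch id)


def replica_for(micro_batch: int, num_replicas: int) -> int:
    """Round-robin routing: micro-batch i goes to replica i mod num_replicas."""
    return micro_batch % num_replicas


def route(ops: List[Op], num_replicas: int) -> Dict[int, List[Op]]:
    """Send each forward and backward op to a replica of a data-parallel stage."""
    per_replica: Dict[int, List[Op]] = defaultdict(list)
    for kind, mb in ops:
        # Both passes of a micro-batch use the same rule, so they share a worker
        per_replica[replica_for(mb, num_replicas)].append((kind, mb))
    return dict(per_replica)


if __name__ == "__main__":
    ops = [("F", 0), ("F", 1), ("F", 2), ("B", 0), ("F", 3), ("B", 1), ("B", 2), ("B", 3)]
    for replica, work in sorted(route(ops, num_replicas=2).items()):
        print(replica, work)
```

Because routing depends only on the micro-batch id, the replica that ran a micro-batch's forward pass, and therefore holds its activations, also receives its backward pass.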

Implementation Approaches

Manual CUDA Stream-Based Pipeline

A basic implementation in PyTorch manually splits inputs and pipelines computation across GPUs:

import torch
import torch.nn as nn

class PipelineResNet50(nn.Module):
    """Two-stage pipeline: stage0 on cuda:0, stage1 and classifier on cuda:1."""
    def __init__(self, micro_batch_size=20):
        super().__init__()
        self.micro_batch_size = micro_batch_size
        # Layer definitions elided; each stage holds part of the ResNet-50 layers
        self.stage0 = nn.Sequential(...).to('cuda:0')
        self.stage1 = nn.Sequential(...).to('cuda:1')
        self.classifier = nn.Linear(...).to('cuda:1')

    def forward(self, x):
        chunks = list(x.split(self.micro_batch_size, dim=0))
        if not chunks:
            return torch.empty(0, device='cuda:1')

        # Warm-up: first chunk through stage0, then copy it to cuda:1
        prev_out = self.stage0(chunks[0]).to('cuda:1')
        outputs = []

        # Pipeline loop
        for chunk in chunks[1:]:
            # A: stage1 runs on the previous chunk (cuda:1)
            hidden = self.stage1(prev_out)
            outputs.append(self.classifier(hidden.flatten(1)))
            # B: stage0 runs on the next chunk (cuda:0); because CUDA kernels
            # are launched asynchronously, A and B can execute concurrently
            prev_out = self.stage0(chunk).to('cuda:1')

        # Drain: last chunk through stage1
        hidden = self.stage1(prev_out)
        outputs.append(self.classifier(hidden.flatten(1)))
        return torch.cat(outputs, dim=0)

This leverages PyTorch’s asynchronous CUDA execution but uses only default streams, limiting overlap between computation and inter-GPU transfers.

RPC-Based Distributed Pipeline

PyTorch’s RPC framework (torch.distributed.rpc) extends the pipeline across processes or machines. Each model shard is hosted on a remote worker, and intermediate results are passed between workers as RRefs (remote references), so a stage can fetch the previous stage's output directly:

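import threading
import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef

class Shard(nn.Module):
    def __init__(self, layers, device):
        super().__init__()
        self.device = device
        self.net = layers.to(device)
        # RPC may invoke forward from several threads concurrently
        self._lock = threading.Lock()

    def forward(self, x_rref):
        # Fetch the input from wherever it lives and move it to this stage's device
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.net(x)
        # Return a CPU tensor so it can be serialized over RPC
        return out.cpu()

    def parameter_rrefs(self):
        return [RRef(p) for p in self.parameters()]

class DistributedPipeline(nn.Module):
    def __init__(self, stage0, stage1, workers, micro_batch_size):
        super().__init__()
        self.micro_batch_size = micro_batch_size
        self.shard0 = rpc.remote(workers[0], Shard, args=(stage0, "cuda:0"))
        self.shard1 = rpc.remote(workers[1], Shard, args=(stage1, "cuda:1"))

    def forward(self, x):
        futures = []
        for micro_batch in x.split(self.micro_batch_size, dim=0):
            rref = RRef(micro_batch)
            # Stage 0 output stays on worker 0; only an RRef comes back
            out0 = self.shard0.remote().forward(rref)
            # Stage 1 pulls out0 itself via to_here(); a Future returns immediately
            out1 = self.shard1.rpc_async().forward(out0)
            futures.append(out1)
        # Block until every micro-batch is done; results keep micro-batch order
        return torch.cat(torch.futures.wait_all(futures))

    def parameter_rrefs(self):
        # RRefs to all remote parameters, as needed by DistributedOptimizer
        return (self.shard0.remote().parameter_rrefs().to_here()
                + self.shard1.remote().parameter_rrefs().to_here())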

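For training, the forward and backward passes run inside a distributed autograd context (torch.distributed.autograd), which records the autograd graph across workers, and DistributedOptimizer applies each parameter update on the worker that owns the parameters.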
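The control flow of the RPC forward pass, where each stage-0 call returns a handle immediately and the stage-1 call waits on it, can be mimicked in a single process with the standard library's concurrent.futures. This is an analogy only: the two single-thread executors stand in for remote workers and Future.result() stands in for RRef.to_here(); no RPC, RRefs, or autograd are involved, and pipelined_map is a hypothetical name.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, TypeVar

T = TypeVar("T")


def pipelined_map(stage0: Callable, stage1: Callable, micro_batches: List[T]) -> List:
    """Run two stages on two single-thread 'workers', chaining them through futures."""
    with ThreadPoolExecutor(max_workers=1) as worker0, \
            ThreadPoolExecutor(max_workers=1) as worker1:
        futures = []
        for mb in micro_batches:
            # Like shard0.remote().forward(...): returns a handle immediately
            f0 = worker0.submit(stage0, mb)
            # Like shard1.rpc_async().forward(out0): worker1 waits on f0 while
            # worker0 is already free to start the next micro-batch
            f1 = worker1.submit(lambda fut=f0: stage1(fut.result()))
            futures.append(f1)
        # Like torch.futures.wait_all: gather results in micro-batch order
        return [f.result() for f in futures]


if __name__ == "__main__":
    print(pipelined_map(lambda x: x + 1, lambda x: x * 2, [1, 2, 3]))  # [4, 6, 8]
```

Issuing every micro-batch before waiting on any result is what lets the two workers overlap; waiting on each result inside the loop would serialize the stages.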

Using PyTorch’s Built-in Pipe

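PyTorch has also shipped a high-level Pipe API (torch.distributed.pipeline.sync.Pipe) that automates much of this logic. It was deprecated and has been removed from recent releases in favor of torch.distributed.pipelining, so the example below requires an older PyTorch version: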

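import os

import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
from torch.distributed.pipeline.sync import Pipe  # not available in recent releases

# Pipe requires the RPC framework, even in a single process
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
rpc.init_rpc("worker", rank=0, world_size=1)

# Sequential model whose layers are already placed on different devices
model = nn.Sequential(
    nn.Linear(16, 8).to('cuda:0'),
    nn.Linear(8, 4).to('cuda:1')
)

# Wrap with Pipe: each input batch is split into 8 micro-batches
pipeline_model = Pipe(model, chunks=8)

# Run inference; Pipe returns an RRef to the output
input_tensor = torch.randn(16, 16).to('cuda:0')
output = pipeline_model(input_tensor).to_here()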

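Pipe splits each input batch into the requested number of micro-batches (the chunks argument), runs them through the partitions on a GPipe-style schedule, and copies activations between devices; its checkpoint argument controls activation checkpointing. Placing each layer on its device remains the user's job, and RPC must be initialized because Pipe returns its output as an RRef, even though the pipeline itself runs within a single process.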

In practice, pipeline parallelism is rarely used in isolation. Large-model training systems typically combine it with tensor parallelism, which splits individual layers across devices (usually within a node), and data parallelism, which replicates the whole pipeline across device groups, to scale both throughput and model size.
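One way to see how the three forms combine is to lay the global ranks out on a grid. The sketch assumes a common convention in which tensor-parallel ranks are adjacent (so they tend to share a node) and data-parallel replicas are outermost; actual frameworks choose their own orderings, and the function names are hypothetical.

```python
from typing import Dict, List, Tuple


def rank_to_coords(rank: int, tp: int, pp: int, dp: int) -> Tuple[int, int, int]:
    """Map a global rank to (dp, pp, tp) indices; the tensor-parallel index varies fastest."""
    if not 0 <= rank < tp * pp * dp:
        raise ValueError("rank out of range")
    return rank // (tp * pp), (rank // tp) % pp, rank % tp


def pipeline_groups(tp: int, pp: int, dp: int) -> List[List[int]]:
    """Ranks forming each pipeline: same dp and tp index, listed from stage 0 upward."""
    groups: Dict[Tuple[int, int], List[int]] = {}
    for rank in range(tp * pp * dp):
        d, _, t = rank_to_coords(rank, tp, pp, dp)
        groups.setdefault((d, t), []).append(rank)
    return list(groups.values())


if __name__ == "__main__":
    # 8 GPUs: 2-way tensor, 2-way pipeline, 2-way data parallelism
    print(rank_to_coords(5, tp=2, pp=2, dp=2))  # (1, 0, 1)
    print(pipeline_groups(tp=2, pp=2, dp=2))    # [[0, 2], [1, 3], [4, 6], [5, 7]]
```

In this layout, consecutive ranks in a pipeline group (such as 0 and 2) exchange activations between stages, tensor-parallel peers (such as 0 and 1) communicate inside every layer, and data-parallel peers (such as 0 and 4) all-reduce gradients.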

Tags: pipeline-parallelism model-parallelism distributed-training pytorch rpc

Posted on Thu, 11 Jun 2026 16:37:39 +0000 by smonkcaptain