Distributed Training Infrastructure
Learn about frameworks and approaches for distributed training, including DeepSpeed and FSDP, along with monitoring techniques.
Overview
In our previous lesson, we explored parameter-efficient fine-tuning techniques that enable working with large language models on limited hardware. However, as model sizes continue to grow beyond even what PEFT methods can handle on a single device, distributed training becomes essential. This lesson explores how to scale training across multiple GPUs, multiple machines, and even multiple data centers.
Distributed training is what makes training models with hundreds of billions of parameters possible. Understanding these techniques allows you to work with state-of-the-art models and contribute to pushing the boundaries of what's possible in natural language processing.
Before diving into the theory, get a feel for why distribution is unavoidable. This interactive explorer breaks down where GPU memory actually goes for a given model:
Tip: Try this first. Open the TrainingExplorer below and increase the model size while watching how the memory breakdown shifts between parameters, optimizer states, and activations. At 10B+ parameters these components no longer fit on one device, and that is the whole reason distributed training exists. Come back to the theory once you've seen it move.
Learning Objectives
After completing this lesson, you will be able to:
- Understand the fundamental challenges of distributed training for large language models
- Identify and implement appropriate parallelism strategies based on model and hardware constraints
- Set up and configure popular distributed training frameworks like DeepSpeed and PyTorch FSDP
- Optimize distributed training performance through proper hyperparameter tuning
- Implement effective monitoring and debugging strategies for distributed training
- Design resilient training systems that can recover from failures
The Need for Distributed Training
Scale Drives Progress
The past few years have demonstrated a clear trend: larger models, trained on more data, perform better on a wide range of tasks. This scaling law presents a technical challenge—how do we train these enormous models efficiently?
Analogy: Distributed Training as Coordinated Construction
Think of training a large language model like constructing a massive skyscraper:
- Single-device Training: One construction team trying to build the entire structure—impossible beyond a certain size
- Distributed Training: Multiple specialized teams working on different sections simultaneously
- Coordination Overhead: Teams need to communicate, synchronize, and integrate their work
- Resource Planning: Different tasks require different equipment and expertise
Just as a skyscraper can only reach new heights through coordinated teamwork, today's largest language models can only be trained through sophisticated distributed systems.
The Fundamental Challenges
The scaling challenges become apparent when we examine how resource requirements grow with model size:
| Model Size | Memory Required* | Training Time** | Single GPU Feasible? |
|---|---|---|---|
| 1B parameters | 24 GB | 1x (baseline) | ✅ High-end GPUs |
| 10B parameters | 240 GB | 15x | ❌ Requires distributed |
| 100B parameters | 2.4 TB | 225x | ❌ Requires sophisticated sharding |
| 1T parameters | 24 TB | 3,375x | ❌ Requires massive infrastructure |
*Including model parameters, optimizer states (Adam), and activations
**Relative to 1B parameter baseline
Key Observations:
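- Memory requirements scale linearly with parameter count; with Adam, the optimizer states alone take about 4x the memory of the FP16 weights (8 bytes versus 2 bytes per parameter)
- Training time scales super-linearly in this table (about 15x for every 10x increase in parameters) due to communication overhead and longer sequences
- At around 10B parameters and above, training no longer fits on a single GPU, so distributed training becomes mandatory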
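To make "super-linearly" concrete, the training-time column can be read as a power law. The sketch below only re-derives the table's own illustrative figures; the names `relative_time` and `predicted_time` are hypothetical helpers, not part of any library.

```python
import math

# Relative training times copied from the table above (1B baseline = 1x).
relative_time = {1e9: 1, 1e10: 15, 1e11: 225, 1e12: 3375}

# Every 10x increase in parameters multiplies the time by 15, so the table
# follows time ~ (N / 1e9) ** k with k = log10(15).
k = math.log10(15)


def predicted_time(num_params: float) -> float:
    """Relative training time implied by the table's power law."""
    return (num_params / 1e9) ** k


for n, t in relative_time.items():
    print(f"{n:.0e} params: table {t}x, power law {predicted_time(n):.0f}x")
print(f"implied exponent k = {k:.2f}")
```

An exponent of roughly 1.18 means that doubling the parameter count more than doubles the training time in this table. Real workloads will follow their own curve.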
Memory Constraints
A fundamental challenge in training large language models is memory:
- Model Parameters: FP16 parameters require 2 bytes each
  - 1B parameters = 2GB
  - 100B parameters = 200GB
  - 1T parameters = 2TB
- Optimizer States: Optimizers like Adam require additional memory
  - Adam needs 8 bytes per parameter (4x model size)
  - 1B parameters = 10GB total
  - 100B parameters = 1TB total
- Activation Memory: Forward pass outputs needed for backpropagation
  - Scales with batch size and sequence length
  - Can often exceed parameter memory for large batches
- Gradient Accumulation: Reduces memory but increases training time
Revisit the GPU-memory explorer at the top of this lesson to see exactly how these four components stack up as model size grows.
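The parameter and optimizer figures above follow from simple per-parameter arithmetic. Here is a minimal sketch that uses only the byte counts stated in this section (2 bytes for FP16 weights, 8 bytes for Adam states). The function name `static_memory_gb` is hypothetical. Gradients and activations are deliberately left out, which is one reason the earlier table reports larger totals.

```python
GB = 1e9  # decimal gigabytes, matching the figures in this section

BYTES_PER_PARAM_FP16 = 2  # FP16 weights (from the list above)
BYTES_PER_PARAM_ADAM = 8  # Adam optimizer states (from the list above)


def static_memory_gb(num_params: float) -> dict:
    """Memory for weights plus Adam states, ignoring gradients and activations."""
    params = num_params * BYTES_PER_PARAM_FP16 / GB
    optimizer = num_params * BYTES_PER_PARAM_ADAM / GB
    return {
        "params_gb": params,
        "optimizer_gb": optimizer,
        "total_gb": params + optimizer,
    }


for n in (1e9, 100e9, 1e12):
    m = static_memory_gb(n)
    print(f"{n:.0e} params: weights {m['params_gb']:.0f} GB, "
          f"Adam {m['optimizer_gb']:.0f} GB, total {m['total_gb']:.0f} GB")
```

For 1B parameters this gives 2 GB of weights and 10 GB in total, and for 100B parameters 200 GB and 1,000 GB (1 TB), in line with the list.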
Parallelism Strategies: Dividing the Problem
Types of Parallelism
To overcome these challenges, we use multiple forms of parallelism:
🔄 Data Parallelism - Split the data across devices
- DDP (Distributed Data Parallel): Replicate model, sync gradients
- ZeRO: Partition optimizer states, gradients, and parameters
🧩 Model Parallelism - Split the model across devices
- Tensor Parallelism: Split individual layers (matrix operations)
- Pipeline Parallelism: Split model into sequential stages
📦 3D Parallelism - Combine all approaches for maximum scale
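As an illustration of what "split individual layers" means for tensor parallelism, the sketch below splits a weight matrix column-wise across simulated devices, lets each one compute its slice of the output, and concatenates the results. It is plain Python with no GPU involved. Column-wise splitting is only one possible layout, and all function names are hypothetical.

```python
def matmul(x, w):
    """Multiply x (n x k) by w (k x m); both are lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def split_columns(w, num_shards):
    """Split w column-wise into num_shards equal blocks, one per device."""
    cols = len(w[0])
    assert cols % num_shards == 0, "columns must divide evenly across shards"
    step = cols // num_shards
    return [[row[i * step:(i + 1) * step] for row in w] for i in range(num_shards)]


def tensor_parallel_matmul(x, w, num_shards):
    """Each simulated device multiplies x by its own column block of w."""
    shards = split_columns(w, num_shards)
    partials = [matmul(x, shard) for shard in shards]  # one result per device
    # Stand-in for an all-gather: concatenate the partial outputs column-wise.
    return [sum((p[r] for p in partials), []) for r in range(len(x))]


x = [[1.0, 2.0], [3.0, 4.0]]
w = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0]]
assert tensor_parallel_matmul(x, w, num_shards=2) == matmul(x, w)
```

Each device stores only its own column block of the weights, which is where the memory saving comes from. The cost is the communication step that reassembles the output.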
Data Parallelism
In data parallelism, the entire model is replicated across devices, but each processes different batches of data:
- Each device maintains a complete copy of the model
- Each device processes different data samples
- Gradients are synchronized across devices
- Model weights are updated synchronously
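The four points above can be sketched without any GPU. This toy example assumes a one-parameter model `y = w * x` with a squared-error loss, and uses a plain average as a stand-in for the all-reduce that real frameworks perform. All function names are hypothetical.

```python
def local_gradient(w, shard):
    """Mean gradient of the squared error (w*x - y)**2 over one data shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    """Stand-in for an all-reduce: every worker receives the same mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def data_parallel_step(weights, shards, lr):
    """One synchronous step: local gradients, sync, identical update everywhere."""
    grads = [local_gradient(w, s) for w, s in zip(weights, shards)]
    synced = all_reduce_mean(grads)
    return [w - lr * g for w, g in zip(weights, synced)]


# Toy dataset for y = 3x, split evenly across 4 simulated devices.
data = [(float(x), 3.0 * x) for x in range(1, 9)]
world_size = 4
shards = [data[r::world_size] for r in range(world_size)]

weights = [0.0] * world_size  # every replica starts from the same weights
new_weights = data_parallel_step(weights, shards, lr=0.01)

# With equal shard sizes, the synced update matches a single full-batch step.
full_batch = 0.0 - 0.01 * local_gradient(0.0, data)
assert all(abs(w - full_batch) < 1e-12 for w in new_weights)
```

Because every replica applies the same averaged gradient, the copies of the model stay identical after each step. That is what makes the update synchronous.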
Distributed Data Parallel (DDP)
The standard approach to data parallelism in PyTorch:
```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def setup(rank, world_size):
    """Initialize distributed process group."""
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:12355',
        world_size=world_size,
        rank=rank,
    )


def cleanup():
    """Clean up distributed process group."""
    dist.destroy_process_group()


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(768, 3072),
            nn.GELU(),
            nn.Linear(3072, 768)
        )

    def forward(self, x):
        return self.layers(x)


def train(rank, world_size, num_epochs=10):
    # Initialize distributed training
    setup(rank, world_size)

    # Create model and move to GPU
    model = SimpleModel().to(rank)

    # Wrap model with DDP
    ddp_model = DDP(model, device_ids=[rank])

    # Create optimizer and loss function
    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=5e-5)
    criterion = torch.nn.CrossEntropyLoss()

    # Create dataset and sampler
    # Replace with your actual dataset
    dataset = torch.utils.data.TensorDataset(
        torch.randn(1000, 768),  # Example input data
        torch.randint(0, 10, (1000,))  # Example labels
    )
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

    # Training loop
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Important for proper shuffling
        for batch in dataloader:
            inputs, labels = batch
            inputs, labels = inputs.to(rank), labels.to(rank)

            # Forward pass
            outputs = ddp_model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    cleanup()

# Launch with: torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)
```
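To see why the sampler and the `set_epoch` call matter, here is a simplified stand-in for how a distributed sampler can hand each rank a disjoint slice of a shuffled index list. It is a sketch of the idea in plain Python, not PyTorch's actual implementation. The function `shard_indices` and its seeding scheme are assumptions made for illustration.

```python
import random


def shard_indices(dataset_len, world_size, rank, epoch, seed=0):
    """Return the sample indices one rank would process in a given epoch."""
    # Seeding with the epoch gives every rank the same permutation within an
    # epoch, and a different permutation from one epoch to the next.
    rng = random.Random(seed + epoch)
    indices = list(range(dataset_len))
    rng.shuffle(indices)

    # Pad by repeating early indices so every rank gets the same count.
    per_rank = -(-dataset_len // world_size)  # ceiling division
    total = per_rank * world_size
    indices += indices[: total - dataset_len]

    # Strided slice: rank r takes positions r, r + world_size, ...
    return indices[rank:total:world_size]


world_size = 4
epoch0 = [shard_indices(1000, world_size, r, epoch=0) for r in range(world_size)]
print([len(s) for s in epoch0])  # samples per rank
```

With 1,000 samples and 4 ranks, each rank receives 250 indices and no sample is seen twice within an epoch. Without a per-epoch reseed, every epoch would reuse the same ordering, which is the problem the `set_epoch` call in the training loop is there to avoid.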
Zero Redundancy Optimizer (ZeRO)
ZeRO enhances data parallelism by partitioning optimizer states, gradients, and even parameters across GPUs:
- ZeRO Stage 1: Partitions optimizer states (momentum, variance)
- ZeRO Stage 2: Partitions gradients as well
- ZeRO Stage 3: Partitions model parameters too
This strategy dramatically reduces memory overhead while maintaining the simplicity of data parallelism.
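A rough way to compare the three stages is to count bytes per parameter on each GPU. The sketch below reuses this lesson's figures (2 bytes for FP16 weights, 8 bytes for Adam states) and assumes 2 bytes per parameter for FP16 gradients, which the lesson does not state. Activations and communication buffers are ignored, and the function names are hypothetical.

```python
PARAM_BYTES = 2  # FP16 weights (from this lesson)
GRAD_BYTES = 2   # assumption: gradients stored in FP16
OPTIM_BYTES = 8  # Adam states (from this lesson)


def zero_bytes_per_param(stage: int, num_gpus: int) -> float:
    """Per-GPU bytes per parameter for ZeRO stage 0 (plain DDP) through 3."""
    params, grads, optim = PARAM_BYTES, GRAD_BYTES, OPTIM_BYTES
    if stage >= 1:
        optim /= num_gpus   # Stage 1: partition optimizer states
    if stage >= 2:
        grads /= num_gpus   # Stage 2: partition gradients as well
    if stage >= 3:
        params /= num_gpus  # Stage 3: partition parameters too
    return params + grads + optim


def zero_memory_gb(num_params: float, stage: int, num_gpus: int) -> float:
    """Per-GPU memory in decimal GB, excluding activations."""
    return num_params * zero_bytes_per_param(stage, num_gpus) / 1e9


for stage in range(4):
    print(f"stage {stage}: {zero_memory_gb(10e9, stage, num_gpus=8):.1f} GB per GPU")
```

Under these assumptions, a 10B-parameter model on 8 GPUs drops from 120 GB per GPU with plain data parallelism to 50 GB, 32.5 GB, and 15 GB at stages 1, 2, and 3.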