🚀 PyTorch Parallelism Explained: DP vs DDP vs FSDP
In distributed training, choosing the right parallelism strategy is key to optimizing memory and compute efficiency. Here’s a breakdown of the most common methods:
🔹 Data Parallel (DP)
- What: Legacy, single-process, multi-GPU.
- How: Replicates the entire model on each GPU and splits the input batch across them (minimal sketch after this list).
- Sync: Gradients are averaged after the backward pass.
- Downside: Slow and memory-inefficient: the single process is limited by Python's GIL, the model is re-replicated every iteration, and inputs and outputs are scattered and gathered through one primary GPU.
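For context, here's a minimal DP sketch (the toy model, tensor sizes, and multi-GPU setup are assumptions for illustration):

```python
import torch
import torch.nn as nn

# Toy model and batch; real code would use your own model and DataLoader.
model = nn.Linear(512, 10).cuda()

if torch.cuda.device_count() > 1:
    # Single process: DataParallel replicates the model onto every visible GPU
    # and splits each input batch along dim 0.
    model = nn.DataParallel(model)

x = torch.randn(64, 512).cuda()
loss = model(x).sum()
loss.backward()  # gradients are gathered back onto the primary GPU (cuda:0)
```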
🔹 DistributedDataParallel (DDP)
- What: Modern standard for data parallelism across multiple processes (and GPUs).
- How: One process per GPU; the model is replicated in each process (see the example below).
- Sync: Gradients are all-reduced across processes during the backward pass, in buckets, so communication overlaps with computation.
Pros:
- Fast, scalable.
- Uses the NCCL backend for efficient GPU-to-GPU communication.
- Best for: Most multi-GPU training cases today.
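Here's a minimal DDP training-step sketch, assuming it is launched with torchrun (e.g., `torchrun --nproc_per_node=4 train.py`); the model, batch, and optimizer choice are placeholders:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each process.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = nn.Linear(512, 10).cuda()            # placeholder model
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    x = torch.randn(64, 512).cuda()              # placeholder batch
    loss = model(x).sum()
    loss.backward()   # gradient buckets are all-reduced across processes here
    optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

In a real run you would also wrap your dataset with a DistributedSampler so each rank sees a distinct slice of the data.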
🔹 Fully Sharded Data Parallel (FSDP)
- What: Sharded data-parallel training developed by Meta and now built into PyTorch (torch.distributed.fsdp).
How:
- Model weights, gradients, and optimizer states are sharded across GPUs.
- Full parameters are materialized (all-gathered) only for the layer currently running its forward/backward pass, then freed; gradients are reduce-scattered back to their shards (see the sketch after this list).
Pros:
- Huge memory savings (e.g., training 65B-parameter models on 8x A100 GPUs).
- Supports mixed precision and activation checkpointing.
- Best for: Very large models (billions of parameters).
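And a minimal FSDP sketch under the same assumptions (torchrun launch, placeholder model); real large-model runs would add an auto_wrap_policy, a mixed precision config, and activation checkpointing:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Placeholder model; large models would pass an auto_wrap_policy so each
    # transformer block becomes its own FSDP unit.
    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).cuda()
    model = FSDP(model)  # weights, grads, and optimizer state are sharded across ranks

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    x = torch.randn(64, 512).cuda()   # placeholder batch
    loss = model(x).sum()
    loss.backward()                   # gradients are reduce-scattered to shards
    optimizer.step()                  # each rank updates only its own shard

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```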
🔹 Summary Table
| Method | Model Copy per GPU | Gradient Sync | Memory Usage | Scalability |
|---|---|---|---|---|
| DP | Yes (single process) | After backward | High | Poor |
| DDP | Yes (one per process) | During backward | Moderate | Excellent |
| FSDP | Sharded | During backward | Low | Excellent |
💡 TL;DR:
- Use DDP for most multi-GPU workloads.
- Use FSDP when training very large models and you need to save memory.
- Avoid DP — it’s outdated.