Shagun Sodhani - Training large scale models using PyTorch | PyData Global 2023

Learn how to effectively train large-scale ML models with PyTorch, covering distributed training approaches, memory optimization techniques, and scaling best practices.

Key takeaways
  • PyTorch offers multiple approaches for training large-scale models:

    • Distributed Data Parallel (DDP) - For when the model fits on a single GPU but the dataset is too large to train quickly on one GPU (see the sketch after this list)
    • Model Parallelism - For splitting models vertically across multiple GPUs
    • Pipeline Parallelism - For handling multiple batches across GPUs in a pipeline fashion
    • Tensor Parallelism - For when individual layers are too large for a single GPU
    • Fully Sharded Data Parallel (FSDP) - For sharding model parameters, gradients, and optimizer state across GPUs
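
As a rough sketch of the first approach, a minimal DDP training script might look like the following. The linear model and random dataset are placeholders, and the script assumes it is launched with torchrun so that the RANK/LOCAL_RANK/WORLD_SIZE environment variables are set:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def main():
    # torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for each process
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Placeholder model and dataset; replace with your own
    model = torch.nn.Linear(1024, 10).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])

    dataset = TensorDataset(torch.randn(4096, 1024), torch.randint(0, 10, (4096,)))
    sampler = DistributedSampler(dataset)  # gives each rank a distinct shard of the data
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()  # gradients are all-reduced across ranks here
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()  # launch with: torchrun --nproc_per_node=<num_gpus> this_script.py
```

Each process keeps a full copy of the model; DistributedSampler splits the data across ranks, and the all-reduce during backward keeps the replicas in sync.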
  • Memory optimization techniques include (a combined sketch of mixed precision and activation checkpointing follows this list):

    • Mixed precision training (FP16/BF16 instead of FP32)
    • CPU offloading of parameters/gradients
    • Activation checkpointing/recomputation
    • Distributed optimizers
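
A minimal sketch combining two of these techniques, autocast-based mixed precision and activation checkpointing, on a single GPU with a placeholder two-stage model:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Placeholder two-stage model; real models would checkpoint larger blocks
stage1 = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.ReLU()).cuda()
stage2 = torch.nn.Linear(4096, 10).cuda()

optimizer = torch.optim.AdamW(list(stage1.parameters()) + list(stage2.parameters()), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()  # loss scaling is needed for FP16
loss_fn = torch.nn.CrossEntropyLoss()

x = torch.randn(64, 1024, device="cuda")
y = torch.randint(0, 10, (64,), device="cuda")

optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    # checkpoint() discards stage1's activations and recomputes them during backward,
    # trading extra compute for lower peak memory
    hidden = checkpoint(stage1, x, use_reentrant=False)
    loss = loss_fn(stage2(hidden), y)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

The GradScaler is only needed for FP16; BF16 training can usually skip loss scaling.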
  • Scaling axes for large models:

    • Computing power/resources
    • Dataset size
    • Model parameters
  • Key challenges addressed:

    • Models too large to fit on a single GPU
    • Long training times
    • GPU memory constraints
    • Handling GPU failures during training
  • FSDP advantages (see the sketch after this list):

    • Reduces memory usage through sharding
    • Manages parameters, gradients, and optimizer state together
    • Composes well with existing PyTorch constructs
    • Good for production workloads
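
A minimal FSDP sketch, again assuming a torchrun launch and a placeholder model; the mixed-precision and CPU-offload options shown are optional and assume suitable hardware:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import CPUOffload, MixedPrecision

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Placeholder model; in practice you would pass an auto_wrap_policy so that
# individual transformer blocks become separate FSDP shards
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.ReLU(), torch.nn.Linear(4096, 1024)
).cuda(local_rank)

model = FSDP(
    model,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),  # optional; assumes BF16-capable GPUs
    # cpu_offload=CPUOffload(offload_params=True),                # optional CPU offload of sharded params
)

# Construct the optimizer after wrapping so its state is sharded per rank too
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(8, 1024, device=f"cuda:{local_rank}")
loss = model(x).sum()   # dummy objective for illustration
loss.backward()         # gradients are reduce-scattered so each rank keeps only its shard
optimizer.step()

dist.destroy_process_group()
```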
  • Distributed elastic features allow (see the sketch after this list):

    • Continued training when GPUs fail
    • Dynamic scaling of the GPU count
    • Fault tolerance during training
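
A sketch of a restart-safe entry point that works with torchrun's elastic launch; the checkpoint path, step counts, and launch flags are illustrative and assume a shared filesystem:

```python
# Launch with elastic semantics, e.g.:
#   torchrun --nnodes=1:4 --nproc_per_node=8 --max_restarts=3 train.py
# On a worker failure torchrun restarts all processes, so the script must be
# able to resume from its latest checkpoint.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

CKPT_PATH = "checkpoint.pt"  # illustrative path

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = DDP(torch.nn.Linear(1024, 10).cuda(local_rank), device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    start_step = 0

    # Every (re)started worker loads the latest checkpoint and continues from there
    if os.path.exists(CKPT_PATH):
        ckpt = torch.load(CKPT_PATH, map_location=f"cuda:{local_rank}")
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_step = ckpt["step"] + 1

    for step in range(start_step, 1000):
        x = torch.randn(64, 1024, device=f"cuda:{local_rank}")
        loss = model(x).sum()  # dummy objective for illustration
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodic checkpointing from rank 0 is what makes restarts fault tolerant
        if step % 100 == 0 and dist.get_rank() == 0:
            torch.save({"model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "step": step}, CKPT_PATH)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```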