ICML 2024 ES-FoMo workshop (short paper), PMLR 235 · 2024
My reading notes
Directly relevant to Arun's distributed-training and ML-systems track: a concrete, debuggable example of an async D2D-copy vs AllGather race that produces NaN divergence, the kind of correctness pitfall that surfaces when training large models on commodity clusters like MSI/Slurm. Reinforces that throughput optimizations (hpZ) can quietly corrupt convergence unless synchronization is explicit.
This short paper from LinkedIn diagnoses why DeepSpeed ZeRO++'s hierarchical partitioning scheme (hpZ) causes large transformer fine-tuning to diverge to NaN on bandwidth-constrained, multi-node GPU clusters. hpZ keeps a full secondary copy of each layer's weights replicated within a node, so the backward-pass AllGather stays intra-node (over high-bandwidth NVLink) instead of crossing slow inter-node links. The secondary copy is allocated as an uninitialized (torch.empty) tensor and filled by an asynchronous device-to-device (D2D) memcpy. Because ZeRO prefetches AllGather kernels for upcoming modules, the AllGather on the secondary copy can be enqueued and launched before the async memcpy has finished writing it, so the collective gathers whatever arbitrary values the uninitialized buffer happens to hold. The result is corrupted parameters propagated across devices and unstable, divergent training.
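To make the ordering problem concrete for myself, here is a toy, CPU-only model of the race. It is not DeepSpeed code: `simulate_step`, the tick counters, and the use of NaN as a stand-in for uninitialized memory are all my own assumptions for illustration (a real torch.empty buffer holds arbitrary values, not necessarily NaN).

```python
import math

# Assumption: NaN stands in for the arbitrary contents of an uninitialized buffer.
GARBAGE = float("nan")

def simulate_step(primary_shards, copy_done_tick, allgather_tick):
    """Return what an all-gather sees when it fires at `allgather_tick`.

    Each rank's secondary buffer starts as GARBAGE and only holds the real
    shard once the (hypothetical) async copy lands at `copy_done_tick`.
    """
    gathered = []
    for shard in primary_shards:
        secondary = [GARBAGE] * len(shard)    # allocated but never initialized
        if allgather_tick >= copy_done_tick:  # copy landed before the collective read it
            secondary = list(shard)
        gathered.extend(secondary)            # all-gather concatenates every rank's buffer
    return gathered

shards = [[0.1, 0.2], [0.3, 0.4]]

# Prefetched all-gather fires at tick 3, but the copy only finishes at tick 5.
racy = simulate_step(shards, copy_done_tick=5, allgather_tick=3)
# All-gather fires once the copy has finished.
safe = simulate_step(shards, copy_done_tick=5, allgather_tick=5)

racy_has_nan = any(math.isnan(x) for x in racy)
print("racy gather contains garbage:", racy_has_nan)
print("ordered gather:", safe)
```

The only thing that differs between the two calls is the relative timing of the copy and the collective, which is the point: nothing in the data or the partitioning is wrong, only the ordering.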
The fix is deliberately minimal: insert an explicit CUDA synchronization point between the D2D memcpy and the AllGather, guaranteeing the secondary copy is fully materialized before any collective reads it. This removes the race without changing the partitioning strategy or memory layout.
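A minimal sketch of the "wait before the collective" idea, using Python threads as a stand-in for CUDA streams. The helper names (`async_copy`, `gather_with_sync`) are hypothetical, and `threading.Event` is only an analogy; on a GPU the corresponding primitive would be something like `torch.cuda.synchronize()` or a stream/event wait, and these notes do not record which one the authors actually use.

```python
import threading
import time

def async_copy(src, dst, done):
    """Emulate an asynchronous D2D memcpy that takes a little while to land."""
    time.sleep(0.05)   # pretend copy latency
    dst[:] = src       # fill the secondary buffer in place
    done.set()         # signal that the buffer is fully materialized

def gather_with_sync(primary_shards):
    """Launch one async copy per rank, then wait on all of them before gathering."""
    secondaries, events, threads = [], [], []
    for shard in primary_shards:
        dst = [float("nan")] * len(shard)  # stand-in for an uninitialized buffer
        done = threading.Event()
        worker = threading.Thread(target=async_copy, args=(shard, dst, done))
        worker.start()
        secondaries.append(dst)
        events.append(done)
        threads.append(worker)

    # Explicit synchronization point: block until every copy has completed.
    for done in events:
        done.wait()

    # Only now is it safe for the "collective" to read the secondary buffers.
    gathered = [x for dst in secondaries for x in dst]
    for worker in threads:
        worker.join()
    return gathered

result = gather_with_sync([[1.0, 2.0], [3.0, 4.0]])
print(result)
```

Dropping the `done.wait()` loop would let the gather read the buffers while the copies are still sleeping, which reproduces the garbage read. The buffers and their layout are untouched either way; only the ordering guarantee changes.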
Empirically, on A100 nodes (8 GPUs/node, 600 GB/s NVLink intra-node, with inter-node traffic deliberately throttled to a single 12.5 Gbps Ethernet NIC per node to emulate commodity networking), full-parameter fine-tuning of Llama-2 (7B/13B/70B) and Falcon-40B on MMLU diverges with stock hpZ but converges reliably with the modified algorithm. The fix preserves hpZ's throughput benefit: the paper reports up to about 98% higher tokens/s/node than with hpZ disabled (the Falcon-40B figure, measured against a qgZ-only baseline). Validation loss per optimization step is essentially identical to a stable no-hpZ baseline, so reliability is restored at negligible cost to optimization quality.
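A back-of-envelope check on why keeping the AllGather intra-node matters so much in this setup. The bandwidth figures are the nominal ones quoted above; the 2 bytes per parameter (16-bit weights) is my assumption, and the estimate ignores collective-algorithm details, protocol overhead, and compute overlap, so treat the times as order-of-magnitude only.

```python
NVLINK_GB_PER_S = 600.0      # nominal intra-node figure from the notes
ETHERNET_GBIT_PER_S = 12.5   # throttled inter-node NIC from the notes
BYTES_PER_PARAM = 2          # assumption: 16-bit weights
PARAMS = 7e9                 # Llama-2 7B

ethernet_gb_per_s = ETHERNET_GBIT_PER_S / 8            # bits -> bytes
bandwidth_ratio = NVLINK_GB_PER_S / ethernet_gb_per_s  # how much faster NVLink is

payload_gb = PARAMS * BYTES_PER_PARAM / 1e9            # one full copy of the weights
t_nvlink_s = payload_gb / NVLINK_GB_PER_S
t_ethernet_s = payload_gb / ethernet_gb_per_s

print(f"Ethernet: {ethernet_gb_per_s} GB/s, NVLink/Ethernet ratio: {bandwidth_ratio:.0f}x")
print(f"Moving {payload_gb:.0f} GB: {t_nvlink_s * 1e3:.1f} ms over NVLink, "
      f"{t_ethernet_s:.2f} s over Ethernet")
```

Under these assumptions the nominal gap is 384x: one full 7B weight copy is about 14 GB, which is on the order of tens of milliseconds over NVLink and roughly nine seconds over the throttled NIC.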