Enhancing Stability for Large Models Training in Constrained Bandwidth Networks

Yun Dai, Tejas Dharamsi, Byron Hsu, Tao Song, Hamed Firooz (LinkedIn)

ICML 2024 ES-FoMo workshop (short paper), PMLR 235 · 2024 · ★★★½☆ 3.5/5

My reading notes

Why it matters

Directly relevant to Arun's distributed-training and ML-systems track: a concrete, debuggable example of an async D2D-copy vs AllGather race that produces NaN divergence, the kind of correctness pitfall that surfaces when training large models on commodity clusters like MSI/Slurm. Reinforces that throughput optimizations (hpZ) can quietly corrupt convergence unless synchronization is explicit.

Summary

This short LinkedIn paper diagnoses why DeepSpeed ZeRO++'s hierarchical partitioning scheme (hpZ) causes large transformer fine-tuning to diverge to NaN on bandwidth-constrained, multi-node GPU clusters. hpZ keeps a secondary copy of the weights partitioned across the GPUs of each node, so every node holds a full copy and the backward-pass AllGather stays intra-node (high NVLink bandwidth) instead of crossing slow inter-node links. The secondary copy is allocated as an uninitialized (torch.empty) tensor and filled by an asynchronous device-to-device (D2D) memcpy. Because ZeRO prefetches AllGather kernels for upcoming modules, the AllGather on the secondary copy can be enqueued and launched before the async memcpy has finished writing it, so the collective gathers whatever uninitialized values happen to be in the buffer. The corrupted parameters are then propagated across devices, and training becomes unstable and diverges.

The fix is deliberately minimal: insert an explicit CUDA synchronization point between the D2D memcpy and the AllGather, guaranteeing the secondary copy is fully materialized before any collective reads it. This removes the race without changing the partitioning strategy or memory layout.
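To make the race and the fix concrete for myself, here is a minimal pure-Python toy, not the paper's or DeepSpeed's code. It models the async D2D copy as writing one element per tick into a NaN-filled buffer (standing in for torch.empty), and the prefetched AllGather as a read that launches at a given tick. The function name gather_after_async_copy and its parameters are hypothetical. On a real GPU the ordering would come from a CUDA synchronization call or an event wait rather than a flag.

```python
import math


def gather_after_async_copy(primary, gather_launch_tick, synchronize):
    """Toy model: return what an AllGather would read from the secondary copy.

    primary            -- source shard values (the already-correct weights)
    gather_launch_tick -- tick at which the prefetched gather is ready to launch
    synchronize        -- if True, the gather waits until the copy has finished
    """
    # Stand-in for torch.empty: contents are undefined until the copy lands.
    secondary = [math.nan] * len(primary)
    copied = 0  # number of elements the async copy has written so far
    tick = 0
    while True:
        copy_done = copied == len(primary)
        # Without synchronization the gather fires as soon as it is scheduled,
        # regardless of how far the copy has progressed.
        if tick >= gather_launch_tick and (copy_done or not synchronize):
            return list(secondary)
        if not copy_done:
            secondary[copied] = primary[copied]  # copy advances one element per tick
            copied += 1
        tick += 1


shard = [1.0, 2.0, 3.0, 4.0]
racy = gather_after_async_copy(shard, gather_launch_tick=2, synchronize=False)
safe = gather_after_async_copy(shard, gather_launch_tick=2, synchronize=True)
print("without sync:", racy)
print("with sync:   ", safe)
```

The unsynchronized call reads a half-written buffer, which is the mechanism behind the NaN divergence. The synchronized call pays a short wait and always sees the fully materialized copy.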

Empirically, on A100 nodes (8 GPUs/node, 600 GB/s NVLink intra-node, deliberately throttled to a single 12.5 Gbps Ethernet NIC per node to emulate commodity networking), full-parameter fine-tuning of Llama-2 (7B/13B/70B) and Falcon-40B on MMLU diverges with stock hpZ but converges reliably with the modified algorithm. The fix preserves hpZ's throughput benefit, reporting up to roughly 98% higher tokens/s/node versus disabling hpZ (for Falcon-40B, measured against a qgZ-only baseline), and validation-loss-per-optimization-step curves are essentially identical to a stable no-hpZ baseline, so reliability is restored at negligible cost to optimization quality.
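A rough back-of-envelope on why keeping the AllGather intra-node matters in this setup. The bandwidth figures are the ones quoted above. The 2 bytes per parameter (16-bit weights) and the idea of moving one full weight copy over a single link are my simplifying assumptions, not measurements from the paper, so this is an order-of-magnitude sketch only.

```python
# Figures from the setup described above; everything else is an assumption.
NVLINK_BYTES_PER_S = 600e9          # 600 GB/s intra-node NVLink
ETHERNET_BYTES_PER_S = 12.5e9 / 8   # 12.5 Gbps NIC converted to bytes/s
BYTES_PER_PARAM = 2                 # assumption: 16-bit weights


def transfer_seconds(num_params, bytes_per_s):
    """Idealized time to move one full copy of the weights over a link."""
    return num_params * BYTES_PER_PARAM / bytes_per_s


bandwidth_ratio = NVLINK_BYTES_PER_S / ETHERNET_BYTES_PER_S
print(f"NVLink / Ethernet bandwidth ratio: {bandwidth_ratio:.0f}x")

for name, params in [("Llama-2 7B", 7e9), ("Llama-2 13B", 13e9),
                     ("Falcon-40B", 40e9), ("Llama-2 70B", 70e9)]:
    slow = transfer_seconds(params, ETHERNET_BYTES_PER_S)
    fast = transfer_seconds(params, NVLINK_BYTES_PER_S)
    print(f"{name}: {slow:.1f} s over Ethernet vs {fast:.3f} s over NVLink")
```

Even with these crude assumptions the gap is a few hundred times, which is why the hpZ speedup is worth keeping rather than simply turning hpZ off to avoid the race.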

Key ideas

Takeaways for my work

distributed-training · ml-systems · ZeRO++/DeepSpeed · LLM-training-stability · GPU-synchronization