How to Scale Your Model: A Systems View of LLMs on TPUs

Jacob Austin, Sholto Douglas, Roy Frostig, Anselm Levskaya, Charlie Chen, Reiner Pope, et al. (Google DeepMind)

Online book (jax-ml.github.io/scaling-book), Google DeepMind · 2025 · ★★★★½ 4.5/5

My reading notes

Why it matters

Directly serves Arun's ML-systems track: it gives a quantitative, back-of-envelope framework for choosing parallelism schemes, estimating training and inference cost, and reasoning about whether a workload is communication-bound or compute-bound. These are the exact skills behind MIRROR's Slurm training and Kubernetes serving. The inference chapters (KV cache, continuous batching, prefix caching, prefill/generation split) map straight onto his interest in LLM inference serving and scheduling.

Summary

This is a 141-page online book from Google DeepMind that treats LLM performance as a tractable engineering problem rather than black magic. The core thesis is that a handful of first-principles tools, primarily roofline analysis (is a given operation bound by compute, memory bandwidth, or inter-chip communication?), let you predict how a Transformer will behave on real hardware and pick the right way to split it across many chips. The recurring goal is "strong scaling": adding chips and getting a near-linear throughput gain, which fails once communication overtakes the shrinking per-device compute you could otherwise use to hide it.
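To make the roofline idea concrete for myself, here is a minimal sketch of the check for a single matmul. The function name `matmul_roofline` and the hardware figures (`PEAK_FLOPS`, `HBM_BW`) are my own illustrative assumptions, not values taken from the book, and the model ignores inter-chip communication and any overlap effects.

```python
def matmul_roofline(b, d, f, peak_flops, hbm_bw, bytes_per_elem=2):
    """Simple roofline estimate for X[b, d] @ W[d, f] on one chip.

    Assumes every operand is read from (or written to) HBM exactly once.
    """
    flops = 2 * b * d * f                               # one multiply + one add per MAC
    bytes_moved = bytes_per_elem * (b * d + d * f + b * f)  # read X, read W, write output
    t_compute = flops / peak_flops
    t_memory = bytes_moved / hbm_bw
    return {
        "intensity": flops / bytes_moved,               # FLOPs per byte for this op
        "critical_intensity": peak_flops / hbm_bw,      # FLOPs per byte the chip can sustain
        "t_lower_bound": max(t_compute, t_memory),      # best case with perfect overlap
        "bound": "compute" if t_compute >= t_memory else "memory",
    }


# Assumed, rounded hardware numbers (hypothetical accelerator, bf16).
PEAK_FLOPS = 2e14   # FLOPs/s
HBM_BW = 8e11       # bytes/s

small = matmul_roofline(8, 8192, 8192, PEAK_FLOPS, HBM_BW)     # tiny batch
large = matmul_roofline(2048, 8192, 8192, PEAK_FLOPS, HBM_BW)  # large batch

print(small["bound"], round(small["intensity"], 1))
print(large["bound"], round(large["intensity"], 1))
```

With these assumed numbers the small-batch matmul comes out memory-bound and the large-batch one compute-bound, which is the batch-size intuition the roofline framing is meant to give.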

The book builds bottom-up. Early chapters cover rooflines, the internals of TPUs (systolic arrays, memory hierarchy, ICI/inter-chip networking) with a comparison appendix on GPUs, and a unified notation for sharded matrix multiplications plus the collective operations they require (AllGather, ReduceScatter, AllToAll, AllReduce). A "Transformer math" chapter teaches careful counting of parameters and FLOPs for every matmul, including MoE sparsity, gradient checkpointing, and KV caching. The two central chapters analyze training and inference. The training chapter covers data parallelism, FSDP/ZeRO, tensor parallelism, pipelining, and memory-saving techniques, each with the explicit crossover point where it becomes communication-bound. The inference chapter covers latency vs throughput, prefill vs generation, KV-cache sharding, speculative sampling, continuous batching, prefix caching, and an inference-engine design discussion referencing JetStream.
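As a note to self on the FLOP-counting style, the sketch below uses the common approximation of about 6 FLOPs per parameter per training token (roughly 2 forward, 4 backward), which ignores attention FLOPs and any recomputation from gradient checkpointing. The function names, chip count, per-chip peak, and utilization figure are all hypothetical values I picked for illustration; for an MoE model the parameter count would be the active parameters per token.

```python
def training_flops(active_params, tokens):
    """Approximate total training FLOPs: ~2 forward + ~4 backward per param per token."""
    return 6 * active_params * tokens


def training_days(active_params, tokens, n_chips, peak_flops_per_chip, mfu):
    """Wall-clock estimate given an assumed model FLOPs utilization (mfu in 0..1)."""
    sustained = n_chips * peak_flops_per_chip * mfu   # FLOPs/s actually achieved
    seconds = training_flops(active_params, tokens) / sustained
    return seconds / 86400


# Hypothetical scenario: dense 70e9-parameter model, 15e12 tokens,
# 8192 chips at an assumed 4.6e14 bf16 FLOPs/s each, 40% utilization.
total = training_flops(70e9, 15e12)
days = training_days(70e9, 15e12, n_chips=8192, peak_flops_per_chip=4.6e14, mfu=0.4)

print(f"{total:.2e} FLOPs, about {days:.0f} days")
```

Under these assumptions the estimate is about 6.3e24 FLOPs and roughly 48 days; the point is the shape of the calculation, not the specific figures.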

Two hands-on chapters apply the whole framework to LLaMA-3 on TPU v5p/v5e pods, focusing mainly on the 70B model for both training and serving (the 8B and 405B variants are largely left to the exercises). The final chapters are practical JAX: parallelism via jax.jit and shard_map, plus how to read the XLA/JAX profiler (trace viewer, graph viewer, memory profile) to debug real workloads. Throughout, the authors pose worked problems with solutions, making it usable as a self-study course.
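For the 70B serving case, a quick KV-cache sizing sketch helps me see why the cache dominates memory planning. The configuration below (80 layers, 8 KV heads, head dimension 128, bf16) is my assumption of a LLaMA-3-70B-like model, not something recorded in these notes, and the helper name `kv_bytes_per_token` is hypothetical.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes of KV cache stored per token: one K and one V vector per layer per KV head."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


# Assumed LLaMA-3-70B-like configuration (grouped-query attention, bf16 cache).
per_token = kv_bytes_per_token(n_layers=80, n_kv_heads=8, head_dim=128)

# Hypothetical serving load: 32 concurrent sequences, 8192 tokens each.
batch, context = 32, 8192
total_bytes = per_token * batch * context

print(per_token, "bytes per token")
print(total_bytes / 2**30, "GiB for the whole batch")
```

Under these assumptions the cache costs 327,680 bytes per token and 80 GiB for the batch, which is the same order of magnitude as the bf16 weights themselves and explains why KV-cache sharding and prefix caching matter for serving.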

Key ideas

Takeaways for my work

LLM scaling · distributed training · inference serving · TPU/GPU systems · parallelism