We managed TPU capacity with Google Kubernetes Engine (GKE) and used XPK on top of GKE to orchestrate ML jobs. XPK handles creating clusters, resizing them as necessary, submitting jobs to the GKE Kueue system as JobSets, managing those JobSets, and providing visibility into the state of the cluster.
To accelerate model training, we used the Accurate Quantized Training (AQT) library to train with INT8 quantization. As of October 2023, this approach delivers a 1.2X to 1.4X increase in steps per second while producing a convergence gap smaller than the one normally associated with quantizing a BF16-trained model to INT8.
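For intuition, here is a minimal conceptual sketch of an INT8 quantized matmul in JAX. This is not the AQT API: AQT additionally handles calibration, rounding choices, and gradient flow through the quantized ops, but the core idea of scaling operands into the INT8 range, multiplying with INT32 accumulation, and rescaling the result is the same.

```
import jax
import jax.numpy as jnp

# Conceptual sketch only (not the AQT API): per-tensor scales, int8 matmul with
# int32 accumulation, then dequantization. Assumes non-zero inputs; real quantized
# training also handles gradients (e.g., straight-through estimation).
def int8_matmul(x, w):
    sx = 127.0 / jnp.max(jnp.abs(x))                    # activation scale
    sw = 127.0 / jnp.max(jnp.abs(w))                    # weight scale
    xq = jnp.clip(jnp.round(x * sx), -127, 127).astype(jnp.int8)
    wq = jnp.clip(jnp.round(w * sw), -127, 127).astype(jnp.int8)
    # Accumulate in int32 to avoid overflow, then rescale back to float.
    out = jax.lax.dot(xq, wq, preferred_element_type=jnp.int32)
    return out.astype(jnp.float32) / (sx * sw)
```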
How we scaled the largest distributed LLM training job
As we scaled our TPU compute cluster, we began to push the limits of the stack.
Orchestration
Managing over 50,000 accelerator chips working on a single training job requires a well-designed orchestration solution, one that lets different users submit smaller jobs for experiments while also supporting a full-scale job that runs on the entire cluster. This functionality is provided through GKE's JobSet and Kueue features. As we pushed the limits on the number of VMs that GKE could handle, we optimized the management of internal IP addresses, precached Docker images, designed clusters for scale, and enabled high-throughput scheduling. We also optimized GKE to push VM scaling limits in areas such as pod IP exhaustion, Domain Name Service (DNS) scalability, and control-plane node limits. We packaged and documented these solutions alongside XPK to make this a repeatable process for customers training at this massive scale.
Performance
JAX is powered by XLA (Accelerated Linear Algebra), a compiler-based linear algebra execution engine that optimizes workloads for ML accelerators like TPUs and GPUs to deliver supercomputer-like performance. The key parallelism technique behind XLA is SPMD (single program, multiple data), where the same computation is run in parallel on different devices. XLA leverages GSPMD, which simplifies SPMD programming by letting the user write a program as if for a single giant machine and then automatically parallelizing the computation across devices based on a few user annotations. Running at large scale exposed the need for optimizations that only become necessary with a large number of slices. For example, each worker VM needs to communicate over DCN with the worker VMs of the same rank in other slices. Originally, this caused slowdowns due to excess device-to-host and host-to-device transfers that scaled linearly with the number of slices. By optimizing the XLA runtime, we were able to prevent these transfers from becoming the bottleneck.
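As a small illustration of the GSPMD programming model, the JAX sketch below annotates the inputs of a single matmul with named shardings over an illustrative two-axis mesh; the program itself is written as if for one machine, and XLA inserts the per-device partitioning and collectives. The mesh shape, axis names, and array sizes are assumptions for the example, not our production configuration.

```
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative two-axis mesh: 'data' across slices (over DCN), 'model' within a
# slice (over ICI). This reshape simply lines all devices up on the 'data' axis;
# real runs choose axis sizes to match the topology.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=('data', 'model'))

# The program is written as if for one machine: a single matmul, no communication code.
@jax.jit
def layer(x, w):
    return jnp.dot(x, w)

# Sharding annotations are the only parallelism hints GSPMD needs.
x = jax.device_put(jnp.ones((2048, 1024)),
                   NamedSharding(mesh, P('data', None)))   # batch split over 'data'
w = jax.device_put(jnp.ones((1024, 4096)),
                   NamedSharding(mesh, P(None, 'model')))  # weights split over 'model'
y = layer(x, w)  # XLA/GSPMD partitions the matmul and inserts any needed collectives
```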
Storage
Interacting with persistent storage is a crucial aspect of training. Our 199-pod cluster had 1 Tb/s of bandwidth to Google Cloud Storage (GCS), 1,270 Tb/s of inter-slice DCN bandwidth, and 73,400 Tb/s of intra-slice ICI bandwidth. We optimized the interaction with persistent storage across loading Docker images, loading training data, and reading and writing checkpoints.
We found that data loading from GCS began to affect performance starting at the 64-pod scale. We have since mitigated this limit with a distributed data-loading strategy that alleviates pressure on GCS by having only a subset of hosts load data.
We also found limits due to checkpointing. By default, each data-parallel replica loads the full checkpoint from GCS. Consider checkpoint loading for a 128B-parameter model sharded with cross-pod data parallelism. With a traditional optimizer state of three numbers per parameter (4 bytes per number), each pod separately loads a checkpoint of ~1.536 TB (in this case, across 199 pods). That amounts to 199 pods * 1.536 TB/pod, or approximately 300 TB read from GCS. At a reasonable persistent-storage throughput of 1 Tb/s, this would take approximately 2,400 seconds (40 minutes). However, we needed much lower start and restart times, so we had to take a different approach.
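The arithmetic behind those numbers, using the values stated above, works out as follows:

```
# Back-of-the-envelope check of the naive checkpoint-loading cost (values from the text).
params = 128e9                      # 128B-parameter model
ckpt_bytes = params * 3 * 4         # 3 numbers/param * 4 bytes/number -> ~1.536e12 B
pods = 199
total_bytes = pods * ckpt_bytes     # every pod reads its own full copy -> ~3.06e14 B (~300 TB)
gcs_bandwidth_bits = 1e12           # 1 Tb/s to GCS
seconds = total_bytes * 8 / gcs_bandwidth_bits
print(ckpt_bytes / 1e12, total_bytes / 1e12, seconds)   # ~1.536 TB, ~306 TB, ~2,445 s (~40 min)
```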
To alleviate the problem, we added features that enable a single pod to load the checkpoint and broadcast it to the other replicas. A single pod reads the checkpoint and then broadcasts the optimizer state to the other pods, leveraging the flexibility of JAX. In principle, this should take 1.536 TB / 1 Tb/s = ~12 seconds to load the checkpoint and (2 * 1.536 TB/pod) / (64 VMs/pod * 100 Gb/s/VM) = ~4 seconds to gather the optimizer state across the cluster, for a total of ~16 seconds and a roughly 150x speedup. Similar optimizations are needed when writing checkpoint data and loading training data. At write time, a single leader replica writes the entire checkpoint, avoiding excessive QPS to GCS.
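As a rough sketch of the read-and-broadcast idea, simplified to host granularity (the production path broadcasts one pod's shards to its peer pods over DCN), something like the following can be written with JAX's multihost utilities. The names read_checkpoint_from_gcs and abstract_ckpt are hypothetical placeholders.

```
import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils

# Simplified, host-granularity sketch (assumed helper names): one host reads the
# checkpoint from GCS; every other host builds a same-shaped pytree of zeros and
# receives the real values over the network instead of re-reading from GCS.
def load_and_broadcast(read_checkpoint_from_gcs, abstract_ckpt):
    is_source = jax.process_index() == 0
    if is_source:
        ckpt = read_checkpoint_from_gcs()   # hypothetical GCS reader; only one host touches GCS
    else:
        # broadcast_one_to_all expects the same pytree structure and shapes on every host
        ckpt = jax.tree_util.tree_map(lambda a: jnp.zeros(a.shape, a.dtype), abstract_ckpt)
    return multihost_utils.broadcast_one_to_all(ckpt, is_source=is_source)
```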
How did we measure training performance?
Training performance is measured in terms of Model FLOPs Utilization (MFU) and Effective Model FLOPs Utilization (EMFU). For an N-parameter decoder-only model, each token seen requires 6N matmul FLOPs for the learnable weights and 12LHQT matmul FLOPs for attention, where L, H, Q, and T are the number of layers, the number of heads, the head dimension, and the sequence length, respectively (see Appendix B of the PaLM paper for more details). Knowing the TFLOPs required for each token, we can express the throughput of a step as observed TFLOP/chip/s, computed as the total TFLOPs required for the tokens seen by each chip in the step, divided by the step time.
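In code form, with illustrative helper names, the throughput calculation above looks like this:

```
# Worked form of the throughput definitions above (helper names are illustrative).
def matmul_tflops_per_token(n_params, n_layers, n_heads, head_dim, seq_len):
    # 6N FLOPs per token for the learnable weights plus 12*L*H*Q*T for attention
    return (6 * n_params + 12 * n_layers * n_heads * head_dim * seq_len) / 1e12

def observed_tflops_per_chip_per_s(tflops_per_token, tokens_per_step, n_chips, step_time_s):
    # total TFLOPs for the step's tokens, divided across chips and by the step time
    return tflops_per_token * tokens_per_step / (n_chips * step_time_s)

# MFU is then this observed rate divided by the chip's peak TFLOP/s.
```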