With container preloading, we measured the image pull operation for a 16GB container to be about 29X faster than the baseline (image pull from container registry).
Persistent compilation cache
Just-in-time compilation and system-aware optimizations are among the key enablers of an XLA compiler-based computation stack. In most performant training loops, computation graphs are compiled once and executed many times with different input data, and a compilation cache prevents recompilation as long as the graph shapes stay the same. In the event of a failure or interruption, this cache may be lost, slowing down training resumption and adversely affecting Runtime Goodput. A persistent compilation cache solves this problem by letting users save the compilation cache to Cloud Storage so that it persists across restart events.
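For Jax users, this can be as simple as pointing the compilation cache at a Cloud Storage bucket. The following is a minimal sketch; the bucket path is a placeholder and the flag names reflect recent Jax releases.

```python
import jax
import jax.numpy as jnp

# Point the compilation cache at a Cloud Storage bucket (placeholder path)
# so compiled executables persist across job restarts.
jax.config.update("jax_compilation_cache_dir", "gs://my-bucket/jax-cache")
# Optionally skip caching programs that compile very quickly.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 1.0)

@jax.jit
def train_step(params, batch):
    # Stand-in computation: the first call compiles and writes to the cache;
    # later runs, including after a restart, reuse the cached executable.
    return params - 0.01 * jnp.mean(batch) * params

params = jnp.ones((1024,))
batch = jnp.ones((32, 1024))
params = train_step(params, batch)
```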
Furthermore, GKE, the recommended orchestration layer for AI Hypercomputer, has recently improved job-scheduling throughput by 3X, helping reduce the time to resume (trm).
Maximizing Program Goodput
Program Goodput, or Model Flop Utilization, depends on efficient utilization of the underlying compute as the training program makes forward progress. The distribution strategy, efficient compute-communication overlap, optimized memory access, and efficient pipeline design all contribute to Program Goodput. The XLA compiler is one of the core components of AI Hypercomputer designed to help you maximize Program Goodput through out-of-the-box optimizations and simple, performant scaling APIs such as GSPMD, which lets users easily express a wide range of parallelism strategies to leverage scale efficiently. We recently introduced three key features to help Jax and PyTorch/XLA users maximize Program Goodput.
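To illustrate the GSPMD-style scaling API, here is a minimal Jax sketch; the mesh axis names, array shapes, and device layout are illustrative assumptions, not a prescribed configuration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 2D device mesh: one axis for data parallelism, one for model parallelism.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# Shard the batch along the "data" axis and the weights along the "model" axis;
# GSPMD propagates these annotations through the rest of the computation.
batch = jax.device_put(jnp.ones((128, 512)), NamedSharding(mesh, P("data", None)))
weights = jax.device_put(jnp.ones((512, 1024)), NamedSharding(mesh, P(None, "model")))

@jax.jit
def forward(x, w):
    # The compiler inserts the collectives needed for the chosen parallelism.
    return jnp.dot(x, w)

out = forward(batch, weights)
print(out.sharding)
```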
Custom Kernel with XLA
In compiler-driven computation optimization, we often need an “escape hatch” that lets users write more efficient implementations of complex computation blocks from fundamental primitives, pushing past the default performance. Jax/Pallas is the library built to support custom kernels for Cloud TPUs and GPUs, and it works with both Jax and PyTorch/XLA. Examples of custom kernels written with Pallas include Flash Attention and block-sparse kernels. The Flash Attention kernel helps improve Program Goodput, or Model Flop Utilization, at larger sequence lengths, with the benefit more pronounced at sequence lengths of 4K and above.
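As a flavor of what a Pallas kernel looks like, here is a toy element-wise kernel (not the Flash Attention kernel itself), written as a minimal sketch against the Pallas API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Read the input blocks, add them, and write the result back.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    # pallas_call lowers the kernel to a custom TPU/GPU kernel via Mosaic/Triton.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(1024, dtype=jnp.float32)
y = jnp.ones(1024, dtype=jnp.float32)
print(add(x, y)[:4])
```

Production kernels such as Flash Attention use the same mechanism, with block specs and grid definitions to tile the computation for the accelerator's memory hierarchy.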
Host offload
For large-scale model training, accelerator memory is a limited resource, and we often make trade-offs such as activation rematerialization, which spends compute cycles to conserve accelerator memory. Host offload is another technique we recently introduced in the XLA compiler: it uses host DRAM to offload activations computed during the forward pass and reuses them during the backward pass for gradient computation, saving activation recomputation cycles and therefore improving Program Goodput.
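One way this can be expressed from Jax is through a rematerialization policy that offloads tagged activations to pinned host memory. The sketch below assumes a recent Jax version; the layer, activation name, and shapes are illustrative.

```python
from functools import partial
import jax
import jax.numpy as jnp

def layer(x, w):
    # Tag the projection output so the offload policy can identify it.
    y = jax.ad_checkpoint.checkpoint_name(jnp.dot(x, w), "proj_out")
    return jax.nn.relu(y)

# Offload the tagged forward-pass activation to host memory instead of
# recomputing it during the backward pass.
offload_policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["proj_out"],
    offload_src="device",
    offload_dst="pinned_host",
)

@partial(jax.remat, policy=offload_policy)
def forward(x, w):
    return layer(x, w)

def loss(x, w):
    return jnp.sum(forward(x, w))

x = jnp.ones((128, 256))
w = jnp.ones((256, 256))
grads = jax.grad(loss, argnums=1)(x, w)
```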
Int8 Mixed Precision Training using AQT
Accurate Quantized Training (AQT) is another technique that maps a subset of matrix multiplications in the training step to int8, boosting training efficiency and therefore Program Goodput without compromising convergence.
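The sketch below is based on the open-source AQT library's Flax integration, where a quantized dot_general is injected into existing layers; the MLP block, shapes, and config call are illustrative assumptions and may differ across AQT versions.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2.flax import aqt_flax

class MlpBlock(nn.Module):
    # When config is None the block runs in full precision.
    config: aqt_config.DotGeneral | None = None

    @nn.compact
    def __call__(self, inputs):
        # Inject AQT's quantized dot_general into standard Dense layers.
        dot_general = aqt_flax.AqtDotGeneral(self.config)
        x = nn.Dense(dot_general=dot_general, features=inputs.shape[-1] * 4)(inputs)
        x = nn.relu(x)
        return nn.Dense(dot_general=dot_general, features=inputs.shape[-1])(x)

# Quantize both forward- and backward-pass matmuls to int8.
int8_config = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)
model = MlpBlock(config=int8_config)

x = jnp.ones((16, 128))
params = model.init(jax.random.PRNGKey(0), x)
out = model.apply(params, x)
```

In MaxText, the same capability is exposed through a configuration option rather than hand-written model code.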
The following benchmark shows the aforementioned techniques used in conjunction to boost Program Goodput for a 128B dense LLM implementation in MaxText.