
Performance deep dive of Gemma on Google Cloud


Earlier this year we announced Gemma, an open-weights model family built to let developers rapidly experiment with, adapt, and productionize models on Google Cloud. Gemma models can run on your laptop, workstation, or on Google Cloud through either Vertex AI or Google Kubernetes Engine (GKE), using your choice of Cloud GPUs or Cloud TPUs. This includes training, fine-tuning, and inference using PyTorch and JAX, with vLLM, Hugging Face TGI, and TensorRT-LLM on Cloud GPUs, as well as JetStream and Hugging Face TGI (Optimum-TPU) on Cloud TPUs.

Our benchmarks indicate up to 3X training efficiency (better performance per dollar) for Gemma models using Cloud TPU v5e compared to our baseline of Llama-2 training performance. Earlier this week, we released JetStream, a new cost-efficient and high-performance inference engine. We analyzed Gemma inference performance on Cloud TPU and found a 3X inference efficiency gain (more inferences per dollar) when serving Gemma on JetStream compared to the prior TPU inference stack that we used as the baseline.

In this post, we review the training and inference performance of Gemma models on Google Cloud accelerators. The results we present are snapshots in time as of April 2024. We anticipate that the infrastructure efficiency and quality of these models will continue to evolve and improve through the contributions of the open-source community, our enterprise users, and the teams at Google. 

Background: Gemma model architecture details

The Gemma family of models includes two variants, Gemma 2B and Gemma 7B, both with a dense decoder architecture. We pre-trained Gemma with 2 trillion and 6 trillion tokens for the 2B and 7B models, respectively, with a context length of 8,192 tokens. Both models use a head dimension of 256 and Rotary Positional Embeddings (RoPE).

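| Model    | d_model | q_heads | kv_heads | d_ff   | n_layers |
|----------|---------|---------|----------|--------|----------|
| Gemma 2B | 2,048   | 8       | 1        | 16,384 | 18       |
| Gemma 7B | 3,072   | 16      | 16       | 24,576 | 28       |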

While the Gemma 7B model uses multi-head attention, Gemma 2B uses multi-query attention, in which all query heads share a single key/value head. This reduces memory bandwidth requirements during inference, which can be advantageous for Gemma 2B in on-device inference scenarios, where memory bandwidth is often limited.
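To make the memory effect concrete, the sketch below estimates the key/value cache size implied by the table above. It assumes bf16 cache entries (2 bytes per value) and a cache that stores one key and one value vector per layer and per KV head; the `AttentionConfig` class and `kv_cache_bytes_per_token` helper are hypothetical names introduced for illustration, not part of any Gemma library.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionConfig:
    """Attention shape taken from the architecture table (illustrative only)."""
    name: str
    n_layers: int
    kv_heads: int
    head_dim: int = 256  # both Gemma variants use a head dimension of 256


def kv_cache_bytes_per_token(cfg: AttentionConfig, bytes_per_value: int = 2) -> int:
    """Bytes of KV cache added per generated token.

    Factor 2 covers keys and values; bytes_per_value=2 assumes bf16 storage.
    """
    return 2 * cfg.n_layers * cfg.kv_heads * cfg.head_dim * bytes_per_value


GEMMA_2B = AttentionConfig("Gemma 2B", n_layers=18, kv_heads=1)   # multi-query
GEMMA_7B = AttentionConfig("Gemma 7B", n_layers=28, kv_heads=16)  # multi-head

if __name__ == "__main__":
    context_length = 8192
    for cfg in (GEMMA_2B, GEMMA_7B):
        per_token = kv_cache_bytes_per_token(cfg)
        full_context_mib = per_token * context_length / 2**20
        print(f"{cfg.name}: {per_token:,} bytes/token, "
              f"{full_context_mib:,.0f} MiB at {context_length:,} tokens")
```

Under these assumptions, Gemma 2B needs 18,432 bytes of cache per token (about 144 MiB at the full 8,192-token context), while Gemma 7B needs 458,752 bytes per token (about 3.5 GiB). Every decoding step must read this cache, so a smaller cache means less memory traffic per generated token.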

Gemma training performance

To assess the training infrastructure for a given model or a category of similarly sized models, there are two important dimensions: 1) effective model FLOPs utilization; and 2) relative performance per dollar.

Effective model FLOPs utilization

Model FLOPs Utilization (MFU) is the ratio of the model throughput (the floating-point operations per second the model actually achieves) to the peak throughput of the underlying training infrastructure. To compute the model throughput, we divide the analytical estimate of the number of floating-point operations per training step by the measured step time (ref. PaLM). When applied to mixed-precision training settings (Int8), the resulting metric is called Effective Model FLOPs Utilization (EMFU). All else being equal, a higher (E)MFU indicates better performance per unit cost, so improvements in MFU translate directly into cost savings for training.
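The definition above can be sketched in a few lines of Python. This is a minimal illustration, not the exact accounting used for the results in this post: it assumes the common approximation of 6 FLOPs per parameter per token for a dense decoder (forward plus backward pass, ignoring the attention term), and all input values in the usage example, including the per-chip peak throughput, are hypothetical placeholders that you should replace with measured and published figures for your own setup.

```python
def train_flops_per_step(n_params: float, tokens_per_step: float) -> float:
    """Analytical FLOPs per training step.

    Assumption: ~6 FLOPs per parameter per token (forward + backward),
    which omits the attention-specific term.
    """
    return 6.0 * n_params * tokens_per_step


def mfu(n_params: float, tokens_per_step: float, step_time_s: float,
        n_chips: int, peak_flops_per_chip: float) -> float:
    """Model FLOPs Utilization: achieved model throughput over hardware peak."""
    achieved_flops_per_s = train_flops_per_step(n_params, tokens_per_step) / step_time_s
    peak_flops_per_s = n_chips * peak_flops_per_chip
    return achieved_flops_per_s / peak_flops_per_s


if __name__ == "__main__":
    # Hypothetical run: every number below is a placeholder, not a measurement.
    utilization = mfu(
        n_params=7e9,                # model size in parameters
        tokens_per_step=512 * 8192,  # global batch of 512 sequences of 8,192 tokens
        step_time_s=8.0,             # measured wall-clock time per step
        n_chips=256,                 # accelerators in the slice
        peak_flops_per_chip=197e12,  # assumed bf16 peak; check your hardware spec
    )
    print(f"MFU: {utilization:.1%}")
```

Because the numerator is an analytical estimate rather than a hardware counter, the metric rewards only useful model computation: recomputation or communication overhead lengthens the step time and lowers the result.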

Gemma training setup

Pre-training for Gemma models was done internally at Google using Cloud TPU v5e: two Cloud TPU v5e-256 for Gemma 2B and 16 Cloud TPU v5e-256 for Gemma 7B.

We measured the (E)MFU for Gemma models on Cloud TPU. We present the performance on both Cloud TPU v5e and Cloud TPU v5p because they are the latest Cloud TPU generations at the time of writing. Cloud TPU v5e is the most cost-efficient TPU to date on a performance-per-dollar basis. By contrast, Cloud TPU v5p is the most powerful and scalable TPU available, suited to more complex LLM architectures such as mixture of experts and to alternative workloads such as large ranking and recommendation systems.

The following graph presents the EMFU for the Gemma 2B and Gemma 7B training runs with bf16 precision and with mixed-precision (int8) training using AQT.

