Posted February 8, 2023

Authors: Daramfon Akpan (Technical Program Manager, Microsoft), Leopold Cambier (Software Engineer, NVIDIA), and Jon Shelley (Principal TPM Manager, Microsoft)

Summary:

- Results on Azure’s NDm A100 v4-series virtual machines (VMs) are in line with the NVIDIA DGX A100 system.
- JAX scales efficiently from 1 to 16 VMs.
- The Azure NDm A100 v4-series is the right VM series to handle your training needs for large deep learning models.

JAX

JAX [1] is a Python machine learning framework initially developed by Google. It brings together NumPy-style array programming, automatic differentiation, distributed computation, and just-in-time compilation with operator fusion in a single high-level framework for high-performance machine learning and other scientific workloads. Many libraries have been built on top of JAX. For instance, FLAX [2] implements many neural network primitives on top of JAX and is used by many machine learning models. T5X [3] is one such library: it supports training, evaluation, and inference with JAX models at many scales, with a focus on Transformer-based language models. T5X has been successfully used to train language models with hundreds of billions of parameters on very large datasets such as the Pile [5].
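As a concrete illustration (a minimal sketch, not taken from the benchmark or T5X code), the snippet below shows the programming model JAX provides: ordinary NumPy-style math, transformed by jax.grad for automatic differentiation and compiled and fused by jax.jit. The function and variable names are illustrative only.

```python
# Minimal JAX sketch (illustrative, not part of the benchmark code):
# NumPy-style math, automatic differentiation, and JIT compilation.
import jax
import jax.numpy as jnp

def loss(w, x, y):
    pred = jnp.dot(x, w)              # NumPy-style linear model
    return jnp.mean((pred - y) ** 2)  # mean-squared error

# grad turns loss into a function returning d(loss)/d(w);
# jit compiles and fuses the whole computation with XLA.
grad_fn = jax.jit(jax.grad(loss))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))    # toy inputs
y = jnp.ones((8,))
w = jnp.zeros((3,))

print(grad_fn(w, x, y))               # gradient w.r.t. w, a length-3 array
```

These functional transformations compose with JAX's parallelization primitives (for example, pmap), which is what lets libraries such as T5X scale the same model code from a single GPU to many nodes.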
Reaching new horizons on Azure

We selected the Azure NDm A100 v4-series to run the training benchmarks for the T5 model [4] with the NVIDIA JAX container. The NDm A100 v4-series is Azure’s flagship GPU offering for AI training and deep learning workloads. These virtual machines are powered by NVIDIA A100 80GB Tensor Core GPUs, have the most GPU memory capacity and network bandwidth available on Azure, and are backed by NVIDIA InfiniBand HDR connections to support scaling up and out.

T5 Large and XLarge model description

The Large T5 model has 770 million parameters, 24 encoder and decoder layers, and 16 attention heads. The XLarge (XL) T5 model is much larger, with 3 billion parameters and 32 heads. The results below show data from training both models on the wmt_t2t_ende_v003 dataset, a reasonably small English-to-German machine translation dataset. This allows for faster download and pre-processing and is enough to extract relevant performance metrics. In practice, larger datasets such as the Pile [5] are used for end-to-end training.

Figure 1: Performance results for training the T5 “Large” model with the NVIDIA JAX container on Azure.

Figure 2: Performance results for training the T5 “XLarge” model with the NVIDIA JAX container on Azure.

The results highlight good scaling from 1 to 16 nodes for both the Large and XLarge T5 models running with JAX on Azure. The Large T5 model reaches a scaling efficiency of 84% at 16 nodes (128 GPUs), while the XL T5 model reaches 82% at 16 nodes (128 GPUs). Throughput is within 5% of the NVIDIA DGX A100 data reported here.

Customers can now use the JAX framework on Azure to train Large Language Models (LLMs) with solid scaling performance. We invite you to learn more about how Azure can help you accelerate your JAX workloads using the links below.

- How to replicate the results above on Azure
- How to replicate T5 benchmarks on NVIDIA DGX A100
- Azure NDm A100 v4-series
- HPC + AI Azure blog

References

[1] Bradbury, James, et al. “JAX: composable transformations of Python+NumPy programs.” Available at https://github.com/google/jax (2018).

[2] Heek, Jonathan, et al. “Flax: A Neural Network Library and Ecosystem for JAX.” Available at https://github.com/google/flax (2020).

[3] Roberts, Adam, et al. “Scaling Up Models and Data with T5X and seqio.” arXiv preprint arXiv:2203.17189 (2022).

[4] Raffel, Colin, et al. “Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer.” J. Mach. Learn. Res. 21.140 (2020): 1-67.

[5] Gao, Leo, et al. “The Pile: An 800GB Dataset of Diverse Text for Language Modeling.” arXiv preprint arXiv:2101.00027 (2020).