Getting Started with the NVIDIA Transformer Engine

Mar 4, 2024 | Data Science

The NVIDIA Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs. It supports 8-bit floating point (FP8) precision, which significantly reduces memory usage while maintaining accuracy in both training and inference. In this guide, we will walk you through installing Transformer Engine, work through example code in PyTorch and JAX, and offer troubleshooting tips to ensure a smooth experience.

What is Transformer Engine?

Think of Transformer Engine as a high-speed expressway for deep learning models. Just like cars travel faster and consume less fuel on well-constructed roads, Transformer Engine enables models like BERT, GPT, and T5 to operate efficiently on NVIDIA GPUs. By utilizing FP8 precision, which is akin to a new type of fuel that allows for faster speeds without sacrificing accuracy, TE helps your models reach their destination quickly without using up all the resources.

Installation

To get started with Transformer Engine, you’ll need to take care of a few prerequisites (a quick way to verify them follows the list):

  • Linux x86_64
  • CUDA 12.0+ for Hopper and CUDA 12.1+ for Ada
  • NVIDIA Driver supporting CUDA 12.0 or later
  • cuDNN 8.1 or later
  • For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later
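
If you are unsure whether your environment meets these requirements, the snippet below is one quick way to inspect them from Python (it assumes PyTorch is already installed; the values in the comments are examples only):

import torch

print(torch.version.cuda)              # CUDA version PyTorch was built with, e.g. 12.1
print(torch.backends.cudnn.version())  # cuDNN version as an integer, e.g. 8902 for 8.9.2
print(torch.cuda.get_device_name(0))   # GPU name, e.g. an H100 (Hopper) or L4 (Ada)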

Quick Start with Docker

The easiest way to get started is through Docker. Run the following command to launch the NGC PyTorch container, which ships with Transformer Engine preinstalled:

docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3
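
Once the container is running, you can quickly confirm that Transformer Engine is importable:

python -c "import transformer_engine.pytorch as te; print('Transformer Engine is available')"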

Installation with pip

To install the latest stable version of Transformer Engine, use:

pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

This command detects the deep learning frameworks installed on your system (such as PyTorch or JAX) and builds the corresponding Transformer Engine components.
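
If you would rather not rely on autodetection, recent releases also let you pin the target frameworks explicitly via the NVTE_FRAMEWORK environment variable, for example:

NVTE_FRAMEWORK=pytorch pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable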

Example Code

Let’s delve into a couple of examples to illustrate how to implement FP8 support using both PyTorch and JAX. Imagine building a house; just as the frame gives structure, the code below provides the foundation for your models.

Example in PyTorch

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048

# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device='cuda')

# Create an FP8 recipe. Format.E4M3 uses the E4M3 format for all FP8 tensors
# (all arguments are optional).
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

# Compute the loss and run the backward pass outside the FP8 autocast context.
loss = out.sum()
loss.backward()
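
To turn this into a training step, you can wrap the same pattern with a standard PyTorch optimizer. The sketch below is a minimal illustration; the optimizer choice and learning rate are arbitrary:

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(10):
    optimizer.zero_grad()
    # Only the forward pass runs under the FP8 autocast context.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp)
    loss = out.sum()
    loss.backward()
    optimizer.step()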

Example in JAX

import flax
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

BATCH = 32
SEQLEN = 128
HIDDEN = 1024

# Initialize RNG and inputs.
rng = jax.random.PRNGKey(0)
init_rng, data_rng = jax.random.split(rng)
inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

# Create an FP8 recipe. Format.HYBRID uses E4M3 during the forward pass
# and E5M2 during the backward pass.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    model = te_flax.DenseGeneral(features=HIDDEN)
    
    def loss_fn(params, other_vars, inp):
        out = model.apply({'params': params, **other_vars}, inp)
        return jnp.mean(out)
    
    # Initialize the model variables and split out the trainable parameters.
    variables = model.init(init_rng, inp)
    other_variables, params = flax.core.pop(variables, 'params')
    
    # Construct the forward and backward function
    fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))
    for _ in range(10):
        loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
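
The returned gradients are ordinary JAX pytrees, so you can update the parameters with any optimizer. A bare-bones SGD step is sketched below; in practice you would typically use a library such as Optax, and the learning rate here is illustrative:

# Plain SGD update applied across the parameter pytree.
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, param_grads)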

Troubleshooting

Encountering issues during installation or usage can be frustrating. Here are a few troubleshooting ideas:

  • Ensure your CUDA and cuDNN versions meet the necessary requirements.
  • FlashAttention-2 compilation is memory-intensive, so check that you have enough RAM if its installation fails.
  • If you run into memory issues while compiling from source, try setting MAX_JOBS=1 (see the example after this list).
  • For other errors or to connect with other developers, consider reaching out to the community.
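
For reference, the MAX_JOBS workaround above caps the number of parallel compilation jobs during a source build. It can be set inline with the install command:

MAX_JOBS=1 pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable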

For more insights, updates, or to collaborate on AI development projects, stay connected with fxis.ai.

Conclusion

The NVIDIA Transformer Engine is a powerful tool for accelerating Transformer model training and inference. By leveraging advanced technologies like FP8 precision, you can optimize your deep learning projects for better performance and efficiency.

At fxis.ai, we believe that such advancements are crucial for the future of AI, as they enable more comprehensive and effective solutions. Our team is continually exploring new methodologies to push the envelope in artificial intelligence, ensuring that our clients benefit from the latest technological innovations.
