How to Use Jamba: The Hybrid SSM-Transformer LLM

May 7, 2024 | Educational

Welcome to the exciting world of Jamba, a cutting-edge model that brings together the best of SSM (State Space Model) and Transformer architectures. Whether you’re a developer, researcher, or simply an AI enthusiast, this guide will walk you through using Jamba effectively.

What is Jamba?

Jamba is a state-of-the-art hybrid SSM-Transformer LLM. It delivers higher throughput than traditional Transformer-based models while competing with the leading models in its size class. It has 12 billion active parameters and a total of 52 billion parameters across all of its experts, supports a context length of 256K tokens, and fits up to 140K tokens on a single 80GB GPU. This makes it well suited to applications that need to process long inputs.
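To make "active" versus "total" parameters concrete, here is a toy top-k mixture-of-experts layer in plain Python. Every expert's weights count toward the total, but a router sends each token to only the k highest-scoring experts, so only those experts' weights are used for that token. The scalar experts, router weights, sizes, and function names are hypothetical illustrations, not Jamba's actual configuration or code.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_moe(token, experts, router_weights, k):
    """Route a scalar `token` to the k highest-scoring experts.

    experts: list of callables (hypothetical toy experts).
    router_weights: one router weight per expert; score = weight * token.
    Returns (output, indices of the experts that actually ran).
    """
    scores = [w * token for w in router_weights]
    # Pick the k experts with the largest router scores
    chosen = sorted(range(len(experts)), key=lambda i: scores[i], reverse=True)[:k]
    # Renormalize the gate values over the chosen experts only
    gates = softmax([scores[i] for i in chosen])
    output = sum(g * experts[i](token) for g, i in zip(gates, chosen))
    return output, chosen

# Four toy experts; each one scales its input by a different factor
experts = [lambda x, s=s: s * x for s in (1.0, 2.0, 3.0, 4.0)]
router_weights = [0.1, 0.9, 0.5, 0.3]
out, used = top_k_moe(1.0, experts, router_weights, k=2)
print(used)  # → [1, 2]: only 2 of the 4 experts run for this token
```

Applying the same idea to Jamba's figures, 12B active out of 52B total means roughly 23% of its weights are used for any given token, even though all 52B must be loaded.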

Getting Started with Jamba

Follow these steps to install and run Jamba smoothly on your machine:

Prerequisites

  • You need transformers version 4.40.0 or higher:
  • pip install "transformers>=4.40.0"
  • For the optimized Mamba implementation, install mamba-ssm and causal-conv1d:
  • pip install mamba-ssm "causal-conv1d>=1.2.0"
  • The optimized Mamba kernels run on CUDA, so make sure your model is on a CUDA device.
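Before loading the model, you can check these prerequisites from Python. This minimal sketch uses only the standard library: importlib.metadata reads installed package versions, and the hypothetical helpers parse and meets_minimum compare dotted version strings numerically. It is a simplification that ignores pre-release and post-release suffixes such as rc1 or .post2; the packaging library handles those properly.

```python
import re
from importlib.metadata import version, PackageNotFoundError

# Minimum versions from the prerequisites list (None = any version)
REQUIREMENTS = {"transformers": "4.40.0", "causal-conv1d": "1.2.0", "mamba-ssm": None}

def parse(v):
    """Turn '4.40.0' (or '4.40.0.dev0') into a tuple of leading integers, e.g. (4, 40, 0)."""
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if not m:
            break  # stop at the first non-numeric component
        parts.append(int(m.group()))
    return tuple(parts)

def meets_minimum(installed, required):
    """True if `installed` is at least `required` (None means any version is fine)."""
    if required is None:
        return True
    a, b = parse(installed), parse(required)
    # Pad with zeros so '4.40' compares equal to '4.40.0'
    n = max(len(a), len(b))
    return a + (0,) * (n - len(a)) >= b + (0,) * (n - len(b))

def check_requirements(reqs=REQUIREMENTS):
    """Return a dict mapping each package name to a short status string."""
    report = {}
    for name, minimum in reqs.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        report[name] = "ok" if meets_minimum(installed, minimum) else f"too old ({installed})"
    return report

if __name__ == "__main__":
    print(check_requirements())
```

To confirm that a CUDA device is visible, torch.cuda.is_available() returns True when PyTorch can use a GPU.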

Run the Model

Now that you’ve set up the prerequisites, running the Jamba model is a breeze:

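from transformers import AutoModelForCausalLM, AutoTokenizer

# device_map="auto" (requires the accelerate package) spreads the weights across the available GPUs, offloading to CPU if needed
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

# Tokenize the prompt and move the input IDs to the model's device
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
# Generate up to 216 new tokens after the prompt
outputs = model.generate(input_ids, max_new_tokens=216)

# The decoded text contains the prompt followed by the generated continuation
print(tokenizer.batch_decode(outputs))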
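Conceptually, generate repeatedly predicts one token and appends it to the sequence until it has produced max_new_tokens tokens or emits an end-of-sequence token. The sketch below shows that loop for greedy decoding, which is what generate does when sampling and beam search are disabled. The function toy_next_token_scores is a hypothetical stand-in for Jamba's forward pass over a five-token vocabulary, and greedy_generate is an illustrative name, not a transformers API.

```python
def toy_next_token_scores(ids):
    """Hypothetical stand-in for a model forward pass (vocabulary size 5).

    Favors the token after the last one (cycling through 1..4), and favors
    EOS (id 0) once the sequence has 6 tokens.
    """
    vocab_size = 5
    scores = [0.0] * vocab_size
    if len(ids) >= 6:
        scores[0] = 1.0  # EOS
    else:
        scores[(ids[-1] % 4) + 1] = 1.0
    return scores

def greedy_generate(input_ids, next_token_scores, max_new_tokens, eos_token_id):
    """Append the highest-scoring token at each step (greedy decoding)."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        scores = next_token_scores(ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == eos_token_id:
            break  # stop early once the model emits end-of-sequence
    return ids

print(greedy_generate([1, 2], toy_next_token_scores, max_new_tokens=10, eos_token_id=0))
# → [1, 2, 3, 4, 1, 2, 0]
```

As with generate, the returned sequence starts with the prompt tokens, which is why batch_decode(outputs) prints the prompt followed by the continuation.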

Understanding the Code: An Analogy

Imagine you are a chef in a large kitchen (the model) where various ingredients (the tokens) are available to create marvelous dishes (the output). When you load Jamba, it’s like having your high-tech kitchen apparatus ready: the oven, mixer, and blender (the installed libraries and the model weights) organized and operational.

In our snippet, you prepare a dish by gathering ingredients (input tokens from the prompt), and the magic happens when you use the oven (the generate method) to combine everything, producing a delectable meal (the output text).

Loading the Model in Special Modes

Jamba can be loaded in reduced-precision modes to save memory. Half precision (bfloat16) halves the memory needed for the weights compared with float32, and 8-bit quantization reduces it further:

Loading the Model in Half Precision

from transformers import AutoModelForCausalLM
import torch

# Load the weights in bfloat16 (2 bytes per parameter instead of 4 for float32)
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16)
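A quick way to see why precision matters is to estimate the memory taken by the weights alone: parameter count times bytes per parameter. This back-of-the-envelope sketch uses Jamba's 52 billion total parameters (all experts must be loaded, not just the active ones). It ignores activations, the attention cache, the Mamba state, and any modules kept in higher precision, so real usage is higher. The names BYTES_PER_PARAM and weight_memory_gb are illustrative.

```python
# Bytes needed to store one parameter in each format
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "int8": 1}

def weight_memory_gb(num_params, dtype):
    """Approximate weight memory in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

TOTAL_PARAMS = 52e9  # every expert counts, since all of them are loaded

for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>8}: ~{weight_memory_gb(TOTAL_PARAMS, dtype):.0f} GB")
```

This gives roughly 208 GB in float32, 104 GB in bfloat16, and 52 GB in int8. Even the bfloat16 weights exceed a single 80GB GPU, while 8-bit weights leave headroom for long sequences.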

Loading the Model in 8-Bit Precision

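Quantizing the weights to 8-bit lets you fit sequences of up to 140K tokens on a single 80GB GPU. The Mamba blocks are excluded from quantization (llm_int8_skip_modules=["mamba"]) to avoid degrading model quality. You will also need the bitsandbytes package installed: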

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantize the weights to 8-bit, but keep the Mamba blocks in their original precision
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
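Conceptually, 8-bit quantization stores each weight as an integer in [-127, 127] together with a floating-point scale. The sketch below shows plain absmax quantization of a single row of weights in pure Python. It is a simplification: the LLM.int8() scheme that bitsandbytes uses for load_in_8bit also keeps outlier features in higher precision, which this toy omits. The function names and sample weights are illustrative.

```python
def absmax_quantize(row):
    """Quantize a list of floats to the int8 range using one absmax scale."""
    absmax = max(abs(x) for x in row)
    scale = 127 / absmax if absmax > 0 else 1.0
    q = [round(x * scale) for x in row]  # integers in [-127, 127]
    return q, scale

def dequantize(q, scale):
    """Map the int8 codes back to approximate floats."""
    return [v / scale for v in q]

weights = [0.5, -1.27, 0.02, 1.0]
q, scale = absmax_quantize(weights)
print(q)                     # → [50, -127, 2, 100]
print(dequantize(q, scale))  # close to the original weights
```

The rounding error per weight is at most absmax / 254. Excluding the Mamba blocks via llm_int8_skip_modules means those weights skip this rounding step entirely.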

Fine-Tuning Jamba

Jamba is a pretrained base model, so you can fine-tune it into chat or instruct versions for your use case. Here’s a basic example that trains LoRA adapters (via peft) with trl’s SFTTrainer on the Abirate/english_quotes dataset:

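import torch
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
# Load in bfloat16 and let accelerate place the weights on the available GPUs
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)
# A small dataset of quotes; the training text is in its "quote" column
dataset = load_dataset("Abirate/english_quotes", split="train")

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)

# Attach LoRA adapters to the embeddings and the Mamba projection layers
lora_config = LoraConfig(
    r=8,
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)

# Newer trl releases move dataset_text_field into SFTConfig and rename the
# tokenizer argument to processing_class; adjust these for your trl version.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()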
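LoRA keeps each targeted weight matrix W frozen and trains a low-rank update instead, so a layer computes W·x + (alpha / r)·B·A·x, where A is r × d_in and B is d_out × r. The pure-Python sketch below shows that forward pass on a tiny rank-1 example and counts the trainable parameters for one layer. All matrix values and the 4096 × 4096 layer size are hypothetical; the LoraConfig in the example above sets r=8 and leaves lora_alpha at peft’s default.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, alpha, r):
    """Frozen base weight W plus the scaled low-rank update B @ A."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))  # A: r x d_in, B: d_out x r
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]

def lora_param_counts(d_in, d_out, r):
    """(full fine-tuning params, LoRA params) for one d_out x d_in linear layer."""
    return d_in * d_out, r * (d_in + d_out)

# Toy rank-1 example: W is 2 x 3, A is 1 x 3, B is 2 x 1
W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]
A = [[0.5, 0.5, 0.5]]
x = [1.0, 2.0, 3.0]

# B starts at zero, so training begins from the base model's output
B_init = [[0.0], [0.0]]
print(lora_forward(W, A, B_init, x, alpha=2, r=1))  # → [7.0, 2.0], same as W @ x

# Once training makes B nonzero, the update changes the output
B = [[1.0], [0.0]]
print(lora_forward(W, A, B, x, alpha=2, r=1))  # → [13.0, 2.0]

# Trainable parameters for a hypothetical 4096 x 4096 projection with r=8
print(lora_param_counts(4096, 4096, 8))  # → (16777216, 65536)
```

With r=8, the adapter for such a layer has about 0.4% as many trainable parameters as the full matrix, so gradients and optimizer state stay small even though the whole base model is still loaded.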

Troubleshooting Tips

If you encounter any issues while using Jamba, consider the following troubleshooting tips:

  • Make sure all required packages are installed at the versions listed under Prerequisites.
  • If generation is slow, check that the model is on a CUDA device and that mamba-ssm and causal-conv1d are installed.
  • If half-precision or 8-bit loading fails, confirm that the necessary libraries are installed (for example, bitsandbytes for 8-bit quantization).
  • Jamba is an early-stage base model, so make sure your prompts begin with the BOS (beginning-of-sequence) token; outputs can be inaccurate without it.
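A simple guard for the BOS tip is to check the first token ID before generating and prepend the BOS ID only if it is missing. The helper ensure_bos below is a hypothetical sketch over plain lists of token IDs; with a real tokenizer you would pass tokenizer.bos_token_id, and the ID 1 used here is an assumption for illustration, not necessarily Jamba’s BOS ID.

```python
def ensure_bos(input_ids, bos_token_id):
    """Return input_ids with bos_token_id at the front, adding it only if missing."""
    if input_ids and input_ids[0] == bos_token_id:
        return list(input_ids)
    return [bos_token_id] + list(input_ids)

BOS = 1  # hypothetical BOS id for illustration
print(ensure_bos([523, 9, 77], BOS))     # → [1, 523, 9, 77]
print(ensure_bos([1, 523, 9, 77], BOS))  # → [1, 523, 9, 77] (unchanged)
```

Checking first, rather than always prepending, avoids adding a second BOS token when the tokenizer has already inserted one.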

For more insights, updates, or to collaborate on AI development projects, stay connected with fxis.ai.

Conclusion

With Jamba, the potential for AI-driven applications is expansive. At fxis.ai, we believe that such advancements are crucial for the future of AI, as they enable more comprehensive and effective solutions. Our team is continually exploring new methodologies to push the envelope in artificial intelligence, ensuring that our clients benefit from the latest technological innovations.
