How to Harness the Power of Jamba: A Guide to Using the Hybrid SSM-Transformer LLM

May 9, 2024 | Educational

Welcome to your user-friendly guide to Jamba, a state-of-the-art hybrid SSM-Transformer LLM developed by AI21. Jamba offers higher throughput than traditional Transformer models while also posting strong results on standard benchmarks. Whether you’re a developer or an AI enthusiast, this guide will help you set up and use Jamba effectively!

Getting Started with Jamba

Before diving into model usage, let’s first understand what Jamba is. Think of Jamba as a high-performance sports car equipped with advanced technology. It has the capability to accelerate faster while ensuring stability on the road (or in the case of Jamba, in advanced AI tasks).

Requirements to Run Jamba

  • Python Packages: You need `transformers` version 4.40.0 or later, plus the optimized Mamba kernel packages `mamba-ssm` and `causal-conv1d`.
  • Installation Commands (note the quotes, so your shell doesn’t interpret the version specifiers):
    • For `transformers`:
      pip install "transformers>=4.40.0"
    • For Mamba implementations:
      pip install mamba-ssm "causal-conv1d>=1.2.0"
  • CUDA Device: Ensure you have a CUDA-capable GPU to run the model optimally.

To run the model without the optimized Mamba kernels, you must specify `use_mamba_kernels=False` when loading it; be warned, however, that this results in noticeably slower inference.
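
As a minimal sketch (useful on machines where `mamba-ssm` and `causal-conv1d` are not installed), loading Jamba with the kernels disabled looks like this:

from transformers import AutoModelForCausalLM

# Fall back to the pure-PyTorch Mamba implementation (no custom CUDA kernels required)
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", use_mamba_kernels=False)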

Running Jamba: The Basic Steps

Now that you have Jamba poised for action, let’s take a look at how to run it.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model weights and the matching tokenizer from the Hugging Face Hub
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

# Tokenize the prompt and move it to the same device as the model
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]

# Generate up to 216 new tokens continuing the prompt
outputs = model.generate(input_ids, max_new_tokens=216)

print(tokenizer.batch_decode(outputs))

This code snippet initializes Jamba and generates a continuation for the prompt “In the recent Super Bowl LVIII.” It’s as simple as turning the ignition on your sports car and letting it cruise!
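
By default, `batch_decode` keeps special tokens such as BOS in the output. If you would rather see clean text, a small variation of the decode call (using the standard `skip_special_tokens` flag of the Hugging Face tokenizer API) is:

print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])  # first sequence, without special tokens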

Advanced Techniques: Using Half Precision and 8-Bit Mode

Loading in Half Precision

If you’re looking to optimize memory usage, you can load Jamba in half precision (bfloat16):

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16)
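
If you also have FlashAttention-2 installed and want to spread the model across all available GPUs, a variant of the same call (adding the standard `attn_implementation` and `device_map` arguments) looks like this:

from transformers import AutoModelForCausalLM
import torch

# bfloat16 weights, FlashAttention-2 attention kernels, and automatic device placement
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)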

Loading in 8-Bit Precision

To leverage 8-bit precision, which lets even longer context sequences fit in GPU memory, you can quantize the model with bitsandbytes while keeping the Mamba blocks in full precision so that model quality is not degraded:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantize to 8-bit, excluding the Mamba blocks from quantization
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16, quantization_config=quantization_config)

Fine-Tuning Jamba

Just as a driver might customize a car’s features for optimal performance, you can fine-tune Jamba for specific tasks:

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", device_map='auto')
# A small dataset of English quotes; the "quote" column provides the training text
dataset = load_dataset("Abirate/english_quotes", split="train")

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)

# LoRA adapters targeting the embedding layer and the Mamba projection layers
lora_config = LoraConfig(
    r=8,
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)

trainer.train()
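
Once training finishes, you will typically want to keep the trained adapter. A minimal follow-up (assuming the standard Trainer API; the output directory name here is just an example) is:

trainer.save_model("./jamba-lora-adapter")  # writes the trained LoRA adapter weights and config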

Performance Benchmark

Jamba has been evaluated on a range of standard benchmarks, with the following reported scores:

Benchmark       Score
HellaSwag       87.1%
ARC Challenge   64.4%
WinoGrande      82.5%
PIQA            83.2%
MMLU            67.4%
BBH             45.4%
TruthfulQA      46.4%
GSM8K (CoT)     59.9%
Remember to add the ‘BOS’ token to all prompts, as it may not be enabled by default in all evaluation frameworks.
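
A quick way to verify this (assuming the standard Hugging Face tokenizer API) is to check that the first token ID of an encoded prompt matches the tokenizer’s BOS token:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
ids = tokenizer("Example prompt", add_special_tokens=True)["input_ids"]
print(ids[0] == tokenizer.bos_token_id)  # should print True when BOS is prepended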

Troubleshooting Tips

If you encounter issues while running Jamba, consider the following steps:

  • Ensure that you have the correct versions of all required libraries installed.
  • Check that your CUDA device is set up and properly recognized by your environment (see the quick check below).
  • Review your code for typos or incorrect configurations, especially in the installation of dependencies.
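
A minimal environment sanity check (assuming PyTorch and transformers are already installed) might look like this:

import torch
import transformers

print(transformers.__version__)       # should be 4.40.0 or later
print(torch.cuda.is_available())      # True if a CUDA-capable GPU is visible
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the first GPU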

For more insights, updates, or to collaborate on AI development projects, stay connected with fxis.ai.

Wrapping Up

At fxis.ai, we believe that such advancements are crucial for the future of AI, as they enable more comprehensive and effective solutions. Our team is continually exploring new methodologies to push the envelope in artificial intelligence, ensuring that our clients benefit from the latest technological innovations.

Conclusion

Jamba opens exciting avenues for research and application within AI. By following this guide, you’ll be well-equipped to make the most of this powerful model. Happy coding!
