In this article, we'll walk through instruction-tuning the Stable Diffusion model. You'll learn how to set up your environment, prepare your data, train the model, and generate edited images from specific instructions.
Motivation
Instruction-tuning teaches a machine learning model to follow natural-language commands. Inspired by Google's FLAN and by InstructPix2Pix, this method fine-tunes Stable Diffusion so that it applies a specific edit to an input image based on a user-defined instruction.
Data Preparation
Data preparation is crucial for training our models. Following FLAN's lead, we curate a dataset for each task and organize it around instructions.
- Cartoonization: Please refer to the data_preparation directory for detailed instructions.
- Low-level image processing: Information is available in the dataset card.
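To make the instruction-based framing concrete, here is a minimal sketch of how instruction-tuning examples could be assembled: each example pairs an input image with an instruction and the edited target image, and the instruction is sampled from a small pool of paraphrases. The template strings, the `EditExample` fields, and the file names are all illustrative assumptions, not the contents of the actual datasets.

```python
import random
from dataclasses import dataclass

# Hypothetical instruction paraphrases for one task (not the real prompt list).
CARTOONIZATION_TEMPLATES = [
    "Cartoonize the following image",
    "Generate a cartoonized version of the natural image",
    "Turn this photo into a cartoon",
]


@dataclass(frozen=True)
class EditExample:
    """One instruction-tuning sample: source image, instruction, target image."""
    original_image: str  # path to the unedited input image
    edit_prompt: str     # natural-language instruction
    edited_image: str    # path to the image after the edit


def build_examples(pairs, templates, seed=0):
    """Attach a randomly sampled instruction to every (source, target) pair."""
    rng = random.Random(seed)  # seeded so the sampling is reproducible
    return [
        EditExample(source, rng.choice(templates), target)
        for source, target in pairs
    ]


# Placeholder file names standing in for real image pairs.
pairs = [("photo_0.png", "cartoon_0.png"), ("photo_1.png", "cartoon_1.png")]
examples = build_examples(pairs, CARTOONIZATION_TEMPLATES)
for example in examples:
    print(example)
```

Sampling from several paraphrases, rather than repeating one fixed prompt, is one way to keep the model from overfitting to a single wording of the instruction.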
Training
A proper setup is essential for effective training. Here’s how you can create an optimal environment:
Dev Environment Setup
We recommend using a Python virtual environment and installing PyTorch first (we suggest version 1.13.1 with CUDA 11.6). Because the right PyTorch build depends on your hardware, install it by following the official docs.
Once PyTorch is ready, install the remaining dependencies:
pip install -r requirements.txt
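Additionally, we recommend installing xformers for memory-efficient training. It is not necessary if you're using PyTorch 2.0; in that case, also drop the --enable_xformers_memory_efficient_attention flag from the training command below, since that flag requires xformers to be installed.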
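The xformers decision above can be expressed as a small check. This is a sketch only: `needs_xformers` and `installed_torch_version` are hypothetical helpers, and the logic simply assumes that any PyTorch release with major version 2 or higher can do without xformers, as the text suggests.

```python
from importlib import metadata


def needs_xformers(torch_version: str) -> bool:
    """Return True when the given PyTorch version predates 2.0."""
    # Version strings look like "1.13.1+cu116"; only the major number matters here.
    major = int(torch_version.split(".")[0])
    return major < 2


def installed_torch_version():
    """Return the installed PyTorch version string, or None if it is missing."""
    try:
        return metadata.version("torch")
    except metadata.PackageNotFoundError:
        return None


version = installed_torch_version()
if version is None:
    print("PyTorch is not installed yet.")
elif needs_xformers(version):
    print(f"PyTorch {version}: consider installing xformers.")
else:
    print(f"PyTorch {version}: xformers is optional.")
```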
Launching Training
The training code builds on diffusers, accelerate, and transformers, so you can extend it to fit your project's needs. Here's how you can train using the InstructPix2Pix methodology:
export MODEL_ID=runwayml/stable-diffusion-v1-5
export DATASET_ID=instruction-tuning-sd/cartoonization
export OUTPUT_DIR=cartoonization-scratch

accelerate launch --mixed_precision=fp16 train_instruct_pix2pix.py \
  --pretrained_model_name_or_path=$MODEL_ID \
  --dataset_name=$DATASET_ID \
  --use_ema \
  --enable_xformers_memory_efficient_attention \
  --resolution=256 \
  --random_flip \
  --train_batch_size=2 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --checkpointing_steps=5000 \
  --checkpoints_total_limit=1 \
  --learning_rate=5e-05 \
  --lr_warmup_steps=0 \
  --mixed_precision=fp16 \
  --val_image_url=https://huggingface.co/datasets/diffusers/images-docs/resolved/main/mountain.png \
  --validation_prompt="Generate a cartoonized version of the natural image" \
  --seed=42 \
  --output_dir=$OUTPUT_DIR \
  --report_to=wandb \
  --push_to_hub
Once training is complete, the directory named by OUTPUT_DIR will hold the intermediate checkpoints and the final outputs.
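A little arithmetic helps when reading the flags in the training command. The sketch below computes the effective batch size and the steps at which checkpoints are written. It assumes a single GPU (the number of processes is not stated in the command) and that --max_train_steps counts optimizer updates, which is how the diffusers example scripts typically behave; treat both as assumptions.

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_processes=1):
    """Samples contributing to one optimizer update."""
    return per_device_batch * grad_accum_steps * num_processes


def checkpoint_steps(max_train_steps, checkpointing_steps):
    """Steps at which a checkpoint is saved during training."""
    return list(range(checkpointing_steps, max_train_steps + 1, checkpointing_steps))


batch = effective_batch_size(per_device_batch=2, grad_accum_steps=4)
steps = checkpoint_steps(max_train_steps=15000, checkpointing_steps=5000)

print(batch)          # 8
print(batch * 15000)  # 120000 samples seen over the whole run
print(steps)          # [5000, 10000, 15000]
```

With --checkpoints_total_limit=1, only the most recent of these checkpoints is kept on disk.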
Models, Datasets, Demo
Models:
- instruction-tuning-sd/scratch-low-level-img-proc
- instruction-tuning-sd/scratch-cartoonizer
- instruction-tuning-sd/cartoonizer
- instruction-tuning-sd/low-level-img-proc
Datasets:
Try out the models interactively with no setup required: Demo
Inference
To perform inference on cartoonization, you can use the following code:
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

# Load the fine-tuned cartoonization pipeline in half precision on the GPU.
# use_auth_token=True relies on a token stored by a prior `huggingface-cli login`.
model_id = "instruction-tuning-sd/cartoonizer"
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, use_auth_token=True
).to("cuda")

# Fetch the input image, apply the instruction to it, and save the result.
image_path = "https://huggingface.co/datasets/diffusers/images-docs/resolved/main/mountain.png"
image = load_image(image_path)
image = pipeline("Cartoonize the following image", image=image).images[0]
image.save("image.png")
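For low-level image processing, the inference code is the same apart from two changes: set model_id to instruction-tuning-sd/low-level-img-proc and pass an instruction that describes the operation you want.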
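As a sketch of what that swap could look like, the snippet below keeps one model ID and one instruction per task and loads the pipeline lazily. The model IDs come from the Models list above; the task keys, the `resolve_task` and `run_edit` helpers, and the low-level instruction "Derain the image" are assumptions for illustration, so check the model card for the prompts the model was actually trained on.

```python
# Task name -> (model ID, instruction). The low-level prompt is an assumed example.
TASKS = {
    "cartoonization": ("instruction-tuning-sd/cartoonizer", "Cartoonize the following image"),
    "low_level": ("instruction-tuning-sd/low-level-img-proc", "Derain the image"),
}


def resolve_task(task):
    """Look up the model ID and instruction for a task, failing loudly on typos."""
    if task not in TASKS:
        raise ValueError(f"Unknown task {task!r}; expected one of {sorted(TASKS)}")
    return TASKS[task]


def run_edit(task, image_url, output_path="image.png"):
    """Download an image, apply the task's instruction, and save the result."""
    # Imported here so the lookup above works without the heavy dependencies.
    import torch
    from diffusers import StableDiffusionInstructPix2PixPipeline
    from diffusers.utils import load_image

    model_id, prompt = resolve_task(task)
    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to("cuda")  # half precision assumes a CUDA GPU is available
    edited = pipeline(prompt, image=load_image(image_url)).images[0]
    edited.save(output_path)
    return output_path
```

Calling `run_edit("low_level", image_url)` with the URL of an input image would then run the low-level model; it needs a GPU and network access, so nothing is downloaded until that call is made.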
Results
The results for cartoonization and low-level image processing are promising. For more discussion of the results and the open questions, refer to our blog post.
Acknowledgements
Special thanks to Alara Dirik and Zhengzhong Tu for their valuable contributions and discussions.
Troubleshooting
If you encounter any challenges during installation or training, consider the following tips:
- Verify that your Python environment has all the packages listed in requirements.txt installed.
- Check for compatibility issues with PyTorch and associated libraries based on your hardware.
- Review the training logs for any error messages that could indicate configuration issues.
- For more advanced scenarios, refer to the specific library documentation for detailed troubleshooting steps.
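The first two checks can be partly automated. The following sketch reports which of the libraries mentioned in this article are installed and at what version; the package list is an assumption based on the libraries named above, not the contents of requirements.txt.

```python
from importlib import metadata

# Libraries named in this article; requirements.txt may list more.
PACKAGES = ["torch", "diffusers", "accelerate", "transformers", "xformers"]


def installed_versions(packages):
    """Map each package name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = installed_versions(PACKAGES)
for name, version in report.items():
    print(f"{name:<14}{version or 'NOT INSTALLED'}")
```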

