KD-Lib

Mar 21, 2024 | Data Science

A PyTorch model compression library containing easy-to-use methods for knowledge distillation, pruning, and quantization

Installation

From source (recommended)

git clone https://github.com/SforAiDl/KD_Lib.git
cd KD_Lib
python setup.py install

From PyPI

pip install KD-Lib

Example Usage

KD-Lib provides a framework for knowledge distillation, a process in which a smaller student model is trained to reproduce the behavior of a larger teacher model. Think of it like a wise elder teaching a younger apprentice: besides the correct answers (the ground-truth labels), the student learns from the teacher's full output distribution, including how plausible the teacher considers each wrong answer. This often lets the student approach the teacher's accuracy with far fewer parameters.
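The teacher's "full output distribution" is usually softened with a temperature T before the student sees it. The dependency-free sketch below divides some made-up logits by T before the softmax. The function name softmax_with_temperature and the logit values are illustrative assumptions, not part of KD-Lib. A higher T moves probability mass onto the non-target classes, which exposes how the teacher ranks the wrong answers.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [z / temperature for z in logits]
    # Subtract the max before exponentiating for numerical stability.
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for a 3-class problem (class 0 is the prediction).
teacher_logits = [6.0, 2.0, 1.0]

hard = softmax_with_temperature(teacher_logits, temperature=1.0)
soft = softmax_with_temperature(teacher_logits, temperature=4.0)

print([round(p, 3) for p in hard])  # → [0.976, 0.018, 0.007]
print([round(p, 3) for p in soft])  # → [0.604, 0.222, 0.173]
```

At T = 1 the teacher looks almost one-hot; at T = 4 the relative ordering of the two wrong classes becomes a usable training signal.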

Basic Knowledge Distillation

To implement the simplest version of knowledge distillation, from "Distilling the Knowledge in a Neural Network" (Hinton et al., 2015), use the VanillaKD class:

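import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD

# Define datasets and dataloaders (MNIST is downloaded into ./mnist_data)
mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # MNIST mean/std
)

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("mnist_data", train=True, download=True, transform=mnist_transform),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST("mnist_data", train=False, download=True, transform=mnist_transform),
    batch_size=32,
    shuffle=False,  # evaluation does not need shuffling
)

# Define models: these small MLPs are placeholders, replace them with your own networks.
# Teacher and student must be separate model instances; the student is usually smaller.
teacher_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1200), nn.ReLU(), nn.Linear(1200, 10))
student_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))

# Define optimizers
teacher_optimizer = optim.SGD(teacher_model.parameters(), 0.01)
student_optimizer = optim.SGD(student_model.parameters(), 0.01)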

# Use KD_Lib
distiller = VanillaKD(teacher_model, student_model, train_loader, test_loader, teacher_optimizer, student_optimizer)
distiller.train_teacher(epochs=5, plot_losses=True, save_model=True) # Train teacher
distiller.train_student(epochs=5, plot_losses=True, save_model=True) # Train student
distiller.evaluate(teacher=False) # Evaluate student
distiller.get_parameters() # Get number of parameters
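VanillaKD computes the distillation loss internally. To show what kind of objective is being optimized, here is a minimal, dependency-free sketch of the loss described by Hinton et al.: a weighted sum of cross-entropy on the true label and a KL divergence between temperature-softened teacher and student distributions, scaled by T². The names kd_loss, temperature and alpha are illustrative assumptions and are not KD-Lib's API; VanillaKD's own loss may differ in detail (for example, in the divergence it uses).

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, true_label, temperature=4.0, alpha=0.5):
    """Hinton-style distillation loss for a single example.

    (1 - alpha) * cross_entropy(student, true_label)
    + alpha * T^2 * KL(teacher_soft || student_soft)
    """
    # Hard-label term: ordinary cross-entropy at temperature 1.
    hard_loss = -math.log(softmax(student_logits)[true_label])

    # Soft-target term: KL divergence between the temperature-softened distributions.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    soft_loss = sum(t * math.log(t / s) for t, s in zip(p_teacher, p_student))

    # Multiplying by T^2 keeps the soft-term gradients on the same scale as the hard term.
    return (1 - alpha) * hard_loss + alpha * temperature ** 2 * soft_loss


teacher = [6.0, 2.0, 1.0]
matched = kd_loss([6.0, 2.0, 1.0], teacher, true_label=0)     # student mimics the teacher
mismatched = kd_loss([1.0, 2.0, 6.0], teacher, true_label=0)  # student disagrees
print(matched, mismatched)
```

When the student's logits equal the teacher's, the KL term vanishes and only the weighted cross-entropy remains; a student that disagrees with the teacher pays for both terms.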

Training Models in an Online Fashion

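In online distillation there is no pre-trained teacher: a cohort of models is trained together, and each model learns from the ground-truth labels and from its peers' predictions. KD-Lib implements this with the DML class, following "Deep Mutual Learning":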

import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import DML
from KD_Lib.models import ResNet18, ResNet50

# Define datasets, dataloaders, models, and optimizers (MNIST is downloaded into ./mnist_data)
mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # MNIST mean/std
)

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("mnist_data", train=True, download=True, transform=mnist_transform),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST("mnist_data", train=False, download=True, transform=mnist_transform),
    batch_size=32,
    shuffle=False,  # evaluation does not need shuffling
)

student_params = [4, 4, 4, 4, 4]
student_model_1 = ResNet50(student_params, 1, 10)
student_model_2 = ResNet18(student_params, 1, 10)
student_cohort = [student_model_1, student_model_2]
student_optimizer_1 = optim.SGD(student_model_1.parameters(), 0.01)
student_optimizer_2 = optim.SGD(student_model_2.parameters(), 0.01)
student_optimizers = [student_optimizer_1, student_optimizer_2]

# Use KD_Lib
distiller = DML(student_cohort, train_loader, test_loader, student_optimizers, log=True, logdir='.logs')
distiller.train_students(epochs=5)
distiller.evaluate()
distiller.get_parameters()
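In the Deep Mutual Learning paper, each student k minimizes cross-entropy on the true label plus the average KL divergence from every peer's predictions to its own, KL(p_peer || p_k). The sketch below computes these per-student losses for a single example in plain Python. It follows the paper's objective rather than KD-Lib's internal code (the DML class may use a different divergence), and the function name dml_losses and the logits are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for two discrete distributions with non-zero entries."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def dml_losses(cohort_logits, true_label):
    """Per-student DML loss for one example.

    Student k: cross-entropy on the label + mean over peers j != k of KL(p_j || p_k).
    """
    probs = [softmax(z) for z in cohort_logits]
    n = len(probs)
    losses = []
    for k, p_k in enumerate(probs):
        ce = -math.log(p_k[true_label])
        mimicry = sum(kl(p_j, p_k) for j, p_j in enumerate(probs) if j != k) / (n - 1)
        losses.append(ce + mimicry)
    return losses


# Two hypothetical students' logits for the same 3-class example (true class 0).
losses = dml_losses([[3.0, 1.0, 0.5], [1.0, 2.0, 0.5]], true_label=0)
print(losses)
```

During training each student's optimizer steps on its own entry of this list while the peers' predictions are treated as fixed targets, so the cohort improves together without a pre-trained teacher.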

Methods Implemented

Some benchmark results can be found in the logs file of the KD-Lib repository. Below is a summary of notable papers and methods implemented in the library:

  • Distilling the Knowledge in a Neural Network
  • Improved Knowledge Distillation via Teacher Assistant
  • Relational Knowledge Distillation
  • Distilling Knowledge from Noisy Teachers
  • Paying More Attention to Attention
  • Mean Teachers are Better Role Models
  • Distilling Task-Specific Knowledge from BERT

Troubleshooting

If you encounter any issues during installation or usage of KD-Lib, consider the following troubleshooting tips:

  • Ensure that Python and the required libraries are correctly installed. Use pip list to verify the installed packages.
  • Check that your installed PyTorch version is compatible with KD-Lib.
  • If accuracy is lower than expected, verify that your models are correctly defined (for example, that the output layer matches the number of classes) and that appropriate data preprocessing is applied.
  • Refer to the KD-Lib documentation for detailed explanations and examples.
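To check which versions of PyTorch, torchvision and KD-Lib are actually installed (for comparison against the versions KD-Lib supports), a short standard-library script can query the package metadata. This is a generic sketch: it does not know which version combinations are compatible, and the helper name installed_version is illustrative.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package_name):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


for name in ("torch", "torchvision", "KD-Lib"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

Run it inside the same environment (virtualenv or conda env) you use for training; a mismatch between the interpreter running your script and the one where KD-Lib was installed is a common source of import errors.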
