How to Fine-tune XLSR Wav2Vec2 for Swedish Speech Recognition

Feb 5, 2024 | Educational

In this guide, we walk you through running and evaluating an XLSR Wav2Vec2 model that has been fine-tuned for automatic speech recognition (ASR) in Swedish, and outline how such a model is trained. Fine-tuning adapts the multilingual pretrained model to Swedish, which improves how accurately it transcribes Swedish audio. Let's dive in!

Understanding the Model and Dataset

Before we jump into the code, it’s essential to understand the components:

  • Model: XLSR Wav2Vec2 (XLSR-53), a speech model from Facebook AI pretrained on unlabeled audio in 53 languages. Its cross-lingual representations make it a strong starting point for low-resource languages. We use KBLab/wav2vec2-large-xlsr-53-swedish, a version fine-tuned for Swedish.
  • Dataset: We use Mozilla's Common Voice dataset, which contains crowd-sourced recordings of read Swedish (sv-SE) sentences. We evaluate the model with Word Error Rate (WER) and Character Error Rate (CER); lower is better for both.
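Both metrics are edit distances divided by the length of the reference: WER counts the word substitutions, deletions, and insertions needed to turn the prediction into the reference, and CER does the same over characters. The plain-Python sketch below only illustrates that definition (the function names are our own); in practice you would rely on a library implementation such as the `wer` metric used later in this guide.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance: minimum substitutions, deletions, and insertions turning ref into hyp."""
    prev = list(range(len(hyp) + 1))  # distances from an empty ref prefix
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution (or match when cost is 0)
            )
        prev = curr
    return prev[-1]


def wer(reference, hypothesis):
    """Word Error Rate: word-level edits divided by the number of reference words."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def cer(reference, hypothesis):
    """Character Error Rate: character-level edits divided by the reference length."""
    return edit_distance(reference, hypothesis) / len(reference)


print(wer("jag heter anna", "jag heter ana"))  # one wrong word out of three
print(cer("jag heter anna", "jag heter ana"))  # one missing character out of fourteen
```

A single dropped letter costs a whole word in WER but only one character in CER, which is why reporting both gives a fuller picture of transcription quality.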

Setup and Dependencies

To get started, make sure you have the following Python packages installed:

  • torch
  • torchaudio
  • datasets
  • transformers
  • evaluate and jiwer (used to compute the WER and CER metrics)

You can install them via pip:

pip install torch torchaudio datasets transformers evaluate jiwer

Loading the Model

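Let's load a slice of the Swedish Common Voice test set, together with the fine-tuned model and its processor (which bundles the feature extractor and the tokenizer):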

import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Load the first 2% of the Swedish test split
# (legacy dataset id; recent datasets releases may require a different Common Voice source)
test_dataset = load_dataset("common_voice", "sv-SE", split="test[:2%]")

# Load the processor (feature extractor + tokenizer) and the CTC model
processor = Wav2Vec2Processor.from_pretrained("KBLab/wav2vec2-large-xlsr-53-swedish")
model = Wav2Vec2ForCTC.from_pretrained("KBLab/wav2vec2-large-xlsr-53-swedish")

Preprocessing the Audio Data

The model expects mono audio sampled at 16 kHz, so each clip must be loaded and resampled before it is fed to the model:

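def speech_file_to_array_fn(batch):
    # Assumes "path" holds a local file path, as in older Common Voice releases.
    # torchaudio.load returns a (channels, frames) tensor and the file's sampling rate.
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    # Average the channels to get mono audio, then resample from the file's own rate to 16 kHz
    speech_array = torchaudio.functional.resample(speech_array.mean(dim=0), sampling_rate, 16_000)
    batch["speech"] = speech_array.numpy()
    return batch

# Apply the function to every example in the dataset
test_dataset = test_dataset.map(speech_file_to_array_fn)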
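The processor call used when making predictions does more than convert arrays to tensors. With `padding=True` it pads every clip in a batch to the length of the longest one and returns an `attention_mask` marking which samples are real audio, and when its feature extractor is configured with `do_normalize=True` it also scales each clip to zero mean and unit variance. The plain-Python sketch below only illustrates those two ideas; the function names are hypothetical, and the exact behavior (epsilon, padding value) depends on the checkpoint's preprocessor configuration.

```python
def zero_mean_unit_var(samples, eps=1e-7):
    """Normalize one utterance to zero mean and (approximately) unit variance."""
    n = len(samples)
    mean = sum(samples) / n
    var = sum((x - mean) ** 2 for x in samples) / n
    return [(x - mean) / (var + eps) ** 0.5 for x in samples]


def pad_batch(batch, pad_value=0.0):
    """Right-pad utterances to the longest one and build a matching attention mask."""
    max_len = max(len(x) for x in batch)
    input_values, attention_mask = [], []
    for x in batch:
        n_pad = max_len - len(x)
        input_values.append(list(x) + [pad_value] * n_pad)
        # 1 marks real audio samples, 0 marks padding the model should ignore
        attention_mask.append([1] * len(x) + [0] * n_pad)
    return input_values, attention_mask


utterances = [[0.1, 0.3, -0.2, 0.0], [0.5, -0.5]]
normalized = [zero_mean_unit_var(u) for u in utterances]
values, mask = pad_batch(normalized)
print(mask)  # [[1, 1, 1, 1], [1, 1, 0, 0]]
```

Normalizing each clip before padding keeps the padded zeros out of the mean and variance, so a short clip in a batch of long ones is scaled the same way it would be on its own.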

Making Predictions

With the model and dataset ready, we can now run predictions:

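# Extract features for the first two clips; padding=True pads them to a common length
inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

# Run the model without tracking gradients, then pick the most likely token for each frame
with torch.no_grad():
    logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)

# batch_decode turns the per-frame token ids into text
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset[:2]["sentence"])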
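The argmax gives one token id per audio frame (roughly every 20 ms), so a single spoken letter usually spans several frames. A CTC decoder turns that frame sequence into text by collapsing repeated ids, dropping the blank token, and mapping the word delimiter to a space; for Wav2Vec2 CTC tokenizers the padding token serves as the blank and `|` is the default word delimiter. The sketch below shows greedy decoding on a hypothetical toy vocabulary; it is an illustration, not the tokenizer's actual implementation.

```python
def ctc_greedy_decode(frame_ids, id_to_token, blank_id, word_delimiter="|"):
    """Collapse repeated frame predictions, drop blanks, and map the delimiter to spaces."""
    tokens = []
    previous = None
    for idx in frame_ids:
        # A repeated id only counts once unless a blank (or other token) separates the repeats
        if idx != previous and idx != blank_id:
            tokens.append(id_to_token[idx])
        previous = idx
    return "".join(tokens).replace(word_delimiter, " ").strip()


# Hypothetical toy vocabulary; a real checkpoint defines its own in vocab.json
vocab = {0: "<pad>", 1: "|", 2: "h", 3: "e", 4: "j", 5: "d", 6: "å"}
# Per-frame argmax ids for one utterance, as torch.argmax(logits, dim=-1) might produce
frames = [2, 2, 0, 3, 4, 4, 1, 1, 5, 0, 6]
print(ctc_greedy_decode(frames, vocab, blank_id=0))  # hej då
```

The blank is what lets CTC spell double letters: the "nn" in "anna" survives decoding only if the model emits a blank frame between the two n's.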

Evaluating the Model

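To evaluate the whole test slice, run the model in batches (on a GPU if one is available) and compare its predictions with the reference sentences using WER and CER: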

# Load the metrics from the evaluate library (both rely on jiwer)
import evaluate

wer = evaluate.load("wer")
cer = evaluate.load("cer")

# Reuse the processor and model loaded above; run on the GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Named evaluate_batch so it does not shadow the evaluate module
def evaluate_batch(batch):
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to(device), attention_mask=inputs.attention_mask.to(device)).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch

result = test_dataset.map(evaluate_batch, batched=True, batch_size=8)
print("WER:", 100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"]))
print("CER:", 100 * cer.compute(predictions=result["pred_strings"], references=result["sentence"]))
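Common Voice reference sentences keep their capitalization and punctuation, while a CTC model can only output the characters in its vocabulary. If the checkpoint's vocabulary is lowercase and punctuation-free (check its vocab.json; this is an assumption we have not verified for this model), every "Hej," versus "hej" mismatch is counted as an error and the scores above are inflated. A common remedy is to normalize both sides before scoring; the function below and its punctuation set are hypothetical examples to adapt to your data.

```python
import re

# Hypothetical punctuation set; extend it with whatever your references contain
CHARS_TO_IGNORE = re.compile(r'[,?.!\-;:"“”]')


def normalize_transcript(text):
    """Lowercase, strip the ignored punctuation, and collapse repeated whitespace."""
    text = CHARS_TO_IGNORE.sub("", text.lower())
    return " ".join(text.split())


reference = "Hej, hur mår du?"
prediction = "hej hur mår du"
print(reference.split() == prediction.split())  # False: "Hej," and "du?" do not match
print(normalize_transcript(reference).split() == normalize_transcript(prediction).split())  # True
```

To use it, apply the same function to both lists before scoring, for example `references=[normalize_transcript(s) for s in result["sentence"]]` and likewise for `result["pred_strings"]`. Letters such as å, ä, and ö are untouched because they are not in the ignored set.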

Training the Model

The model can be trained in stages on several datasets: first on a corpus of 1,000 hours of spoken Swedish collected from different radio stations, then fine-tuned on NST Swedish Dictation and Common Voice to improve performance.
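The checkpoint used in this guide already ships with its own vocabulary. If you fine-tune from the base facebook/wav2vec2-large-xlsr-53 checkpoint yourself, one of the first steps in a common recipe is to build a character vocabulary from the (normalized) training transcripts, replace the space with a visible word delimiter, and add unknown and padding tokens, where the padding token doubles as the CTC blank. The sketch below shows that step with hypothetical function names; model training itself is not covered here.

```python
def build_ctc_vocab(transcripts, word_delimiter="|"):
    """Map every character seen in the training transcripts to an integer id."""
    chars = sorted(set("".join(transcripts)))
    vocab = {c: i for i, c in enumerate(chars)}
    # Replace the space with a visible word delimiter, keeping its id
    if " " in vocab:
        vocab[word_delimiter] = vocab.pop(" ")
    vocab["[UNK]"] = len(vocab)
    vocab["[PAD]"] = len(vocab)  # also used as the CTC blank token
    return vocab


def encode(text, vocab, word_delimiter="|"):
    """Turn a transcript into the label ids a CTC head is trained on."""
    unk = vocab["[UNK]"]
    return [vocab.get(word_delimiter if c == " " else c, unk) for c in text]


train_texts = ["hej då", "god morgon"]
vocab = build_ctc_vocab(train_texts)
print(encode("hej då", vocab))
```

The resulting dictionary could then be written to a vocab.json file with `json.dump` and loaded with, for example, `Wav2Vec2CTCTokenizer("vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")`. Building the vocabulary from every character in the training text (rather than a fixed alphabet) ensures Swedish letters such as å, ä, and ö get their own ids.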

Troubleshooting

If you run into errors or the predictions are inaccurate, consider these troubleshooting tips:

  • Make sure your audio is mono and sampled at 16 kHz; resample any audio recorded at a different rate.
  • Check that the dataset loads without errors and that every file path it references exists.
  • Monitor GPU memory during evaluation; if you run out of memory, lower the batch_size passed to map.
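For your own recordings, a quick header check can catch the most common input problems (wrong sampling rate, stereo audio) before they reach the model; note that calling `.squeeze()` on a stereo tensor from `torchaudio.load` leaves two channels rather than producing mono audio. The sketch below uses Python's standard `wave` module, which only reads uncompressed PCM WAV files, so it does not apply to the MP3 clips in Common Voice, where the rate returned by `torchaudio.load` serves the same purpose. The helper name and the generated test file are hypothetical.

```python
import os
import tempfile
import wave

TARGET_RATE = 16_000


def check_wav(path, target_rate=TARGET_RATE):
    """Report a WAV file's format and whether it needs conversion before inference."""
    with wave.open(path, "rb") as wav:
        rate = wav.getframerate()
        channels = wav.getnchannels()
        duration = wav.getnframes() / rate
    return {
        "sample_rate": rate,
        "channels": channels,
        "duration_s": duration,
        "needs_resampling": rate != target_rate,
        "needs_downmix": channels != 1,
    }


# Write one second of 48 kHz stereo silence to a temporary file and inspect it
path = os.path.join(tempfile.mkdtemp(), "example.wav")
with wave.open(path, "wb") as out:
    out.setnchannels(2)
    out.setsampwidth(2)  # 16-bit samples
    out.setframerate(48_000)
    out.writeframes(b"\x00\x00" * 2 * 48_000)  # 2 channels x 48,000 frames
print(check_wav(path))
```

Running a check like this over a folder of files before batch inference is cheaper than discovering a mismatched rate through unexpectedly poor transcriptions.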

