Unlocking Estonian Speech Recognition with XLSR Wav2Vec2

Do you want to explore Automatic Speech Recognition (ASR) for Estonian? This guide shows how to transcribe Estonian audio with a Wav2Vec2 model fine-tuned on the language, and how to measure its accuracy.

Understanding Wav2Vec2 and Its Capabilities

Wav2Vec2 does not rely on hand-crafted acoustic features. Instead, it learns speech representations directly from the raw sound wave. Imagine a friend who has listened to so much speech, in so many languages, that they pick up on sounds, dialects, and accents before learning a single word. The XLSR-53 variant used here was pretrained on unlabeled speech in 53 languages. It was then fine-tuned on Estonian speech with a CTC (Connectionist Temporal Classification) output layer, so it maps Estonian audio directly to text.
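The model does not predict text sample by sample. A stack of strided 1-D convolutions (the feature encoder) first compresses the waveform into frames, and the CTC layer then predicts one token per frame. The following is a minimal sketch of how many frames a clip produces. It assumes the default kernel sizes and strides of the Hugging Face `Wav2Vec2Config`: `conv_kernel=(10, 3, 3, 3, 3, 2, 2)` and `conv_stride=(5, 2, 2, 2, 2, 2, 2)`.

```python
# Kernel sizes and strides of the convolutional feature encoder
# (assumed to match the default Wav2Vec2Config values)
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(num_samples: int) -> int:
    """Return how many frames the feature encoder produces for num_samples input samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        # Output length of a 1-D convolution without padding
        length = (length - kernel) // stride + 1
    return length


if __name__ == "__main__":
    # One second of 16 kHz audio
    print(num_output_frames(16000))  # 49
```

One second of 16 kHz audio yields 49 frames, which is roughly one prediction every 20 ms. The strides multiply to 320 samples.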

Getting Started

To use Wav2Vec2 for Estonian speech recognition, start by preparing your environment:

  • Set Up Your Environment:
    • Ensure you have Python and the required libraries installed: torch, torchaudio, and Hugging Face's transformers and datasets.
    • Make sure you can load the Estonian ("et") configuration of the Common Voice dataset through the datasets library, which downloads the audio on first use.
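Before running the examples, a quick check can confirm that the libraries are importable. This minimal sketch uses only the standard library, and the helper name `missing_packages` is hypothetical. Anything it reports can be installed with pip, for example `pip install torch torchaudio transformers datasets`.

```python
import importlib.util

# Libraries used by the examples in this guide
REQUIRED = ("torch", "torchaudio", "transformers", "datasets")


def missing_packages(names=REQUIRED):
    """Return the names from `names` that cannot be imported in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All required libraries are installed.")
```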

Usage Instructions

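Once your environment is set up, the following code loads a small slice (2%) of the Common Voice Estonian test split. It transcribes the first two clips and prints them next to their reference sentences: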

import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Load data
test_dataset = load_dataset("common_voice", "et", split="test[:2%]")

# Load processor and model
processor = Wav2Vec2Processor.from_pretrained("vasilis/wav2vec2-large-xlsr-53-Estonian")
model = Wav2Vec2ForCTC.from_pretrained("vasilis/wav2vec2-large-xlsr-53-Estonian")

# Common Voice clips are recorded at 48 kHz; the model expects 16 kHz input
resampler = torchaudio.transforms.Resample(48000, 16000)

# Load each clip and resample it to 16 kHz
def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)

# Batch the first two clips, padding them to a common length
inputs = processor(test_dataset["speech"][:2], sampling_rate=16000, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
    # Greedy decoding: take the most likely token at each frame
    predicted_ids = torch.argmax(logits, dim=-1)

print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])
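Taking the argmax gives one token ID per frame, and `batch_decode` turns that frame sequence into text. Greedy CTC decoding roughly works in three steps. It merges consecutive repeats, drops the blank token (the padding token in Wav2Vec2's tokenizer), and turns the word delimiter `|` into a space. The following is a simplified, standalone illustration with a hypothetical toy vocabulary, not the tokenizer's actual implementation:

```python
from itertools import groupby

# Hypothetical toy vocabulary: ID 0 is the CTC blank, "|" separates words
TOY_VOCAB = {0: "<pad>", 1: "|", 2: "t", 3: "e", 4: "r", 5: "k", 6: "l"}


def ctc_greedy_decode(frame_ids, id_to_token, blank_id=0, delimiter="|"):
    """Collapse a per-frame ID sequence into text using greedy CTC rules."""
    # 1. Merge runs of identical IDs into a single ID
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # 2. Drop blanks, which only separate frames
    tokens = [id_to_token[i] for i in collapsed if i != blank_id]
    # 3. Map the word delimiter to a space and tidy up whitespace
    text = "".join(" " if tok == delimiter else tok for tok in tokens)
    return " ".join(text.split())


if __name__ == "__main__":
    frames = [0, 2, 2, 3, 0, 4, 4, 3, 1, 1, 2, 3, 4, 3, 0]
    print(ctc_greedy_decode(frames, TOY_VOCAB))  # tere tere
```

Double letters only survive when a blank separates identical IDs. For example, `[5, 3, 6, 0, 6]` decodes to "kell", while `[5, 3, 6, 6]` collapses to "kel".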

How the Code Works: An Analogy

Think of using the Wav2Vec2 model like inviting a chef into your kitchen. The chef (the model) specializes in one cuisine (the Estonian language). To create a dish (a transcription), you need ingredients (audio data) prepared in a specific way (sampled at 16 kHz). The chef turns these ingredients into a meal (the predicted text) based on what they learned from the recipe book (fine-tuning on the Common Voice dataset).

Evaluation of the Model

To measure the model's accuracy, the following script transcribes the full Common Voice Estonian test split. It then computes the word error rate (WER) against the reference sentences:

import re
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Load the full test split and the WER metric
test_dataset = load_dataset("common_voice", "et", split="test")
wer = load_metric("wer")

processor = Wav2Vec2Processor.from_pretrained("vasilis/wav2vec2-large-xlsr-53-Estonian")
model = Wav2Vec2ForCTC.from_pretrained("vasilis/wav2vec2-large-xlsr-53-Estonian")
# Run on a GPU when one is available, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Characters to strip from the reference sentences
# Note: inside [...], "!-;" is a range, so this also strips digits, quotes, and parentheses
chars_to_ignore_regex = "[,?.!-;:“%‘”]"

# Resamplers from the sampling rates expected in the dataset to 16 kHz
resampler = {
    48000: torchaudio.transforms.Resample(48000, 16000),
    44100: torchaudio.transforms.Resample(44100, 16000),
    32000: torchaudio.transforms.Resample(32000, 16000),
}

# Normalize each reference sentence and load its audio at 16 kHz
def speech_file_to_array_fn(batch):
    batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).lower()
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler[sampling_rate](speech_array).squeeze().numpy()
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)

# Transcribe a batch of clips with greedy decoding
def evaluate(batch):
    inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to(device), attention_mask=inputs.attention_mask.to(device)).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch

result = test_dataset.map(evaluate, batched=True, batch_size=8)

print("WER: {:.2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
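Inside a regular-expression character class, an unescaped hyphen between two characters defines a range. In `chars_to_ignore_regex`, `!-;` therefore matches every ASCII character from `!` to `;`. That range includes digits, quotes, and parentheses, not just the listed punctuation. Below is a minimal sketch comparing that pattern with a variant whose hyphen is escaped. The sample sentences are made up for illustration.

```python
import re

# Pattern as used in the evaluation script: "!-;" forms a character range
RANGE_PATTERN = "[,?.!-;:“%‘”]"
# Same characters with the hyphen escaped, so "-" is matched literally
LITERAL_PATTERN = r"[,?.!\-;:“%‘”]"


def normalize(sentence: str, pattern: str) -> str:
    """Strip characters matched by `pattern` and lowercase, as the evaluation script does."""
    return re.sub(pattern, "", sentence).lower()


if __name__ == "__main__":
    sample = "Kell on 5, mitte 6!"
    print(repr(normalize(sample, RANGE_PATTERN)))    # 'kell on  mitte '
    print(repr(normalize(sample, LITERAL_PATTERN)))  # 'kell on 5 mitte 6'
```

Which behavior you want depends on how the training transcripts were normalized. The references and the model's vocabulary should be cleaned the same way. Otherwise, characters the model never learned to emit count as errors.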

Troubleshooting

If you encounter issues, here are some troubleshooting tips to consider:

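  • The model expects 16 kHz audio, so resample any other rate before calling the processor. In the evaluation script, a clip whose sampling rate is not a key of the resampler dictionary raises a KeyError.
  • Verify the model ID (vasilis/wav2vec2-large-xlsr-53-Estonian) and the dataset name and configuration ("common_voice", "et") to avoid common misconfigurations.
  • If the dataset format changes, adapt the code accordingly. Newer Common Voice releases are published on the Hugging Face Hub under separate, versioned dataset IDs. Recent versions of datasets also no longer provide load_metric. The evaluate library offers evaluate.load("wer") instead; if you import it, rename the script's evaluate function to avoid a name clash.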

Model Performance

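On the Common Voice Estonian test split, the evaluation reports a Word Error Rate (WER) of about 30.66%. That is roughly 31 word errors (substitutions, deletions, and insertions) for every 100 reference words, which leaves room for improvement through further training or tuning.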
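The following is a minimal, dependency-free sketch of how a corpus-level WER figure is computed. It sums word-level edit distances over all sentences and divides by the total number of reference words. The evaluation script itself relies on the `wer` metric loaded through the datasets library. This sketch illustrates the idea and is not that metric's implementation. The example sentences are made up.

```python
def edit_distance(ref_words, hyp_words):
    """Minimum number of word substitutions, deletions, and insertions."""
    # prev[j] holds the distance between the previous reference prefix and hyp_words[:j]
    prev = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        curr = [i] + [0] * len(hyp_words)
        for j, hyp_word in enumerate(hyp_words, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr[j] = min(prev[j] + 1,         # deletion
                          curr[j - 1] + 1,     # insertion
                          prev[j - 1] + cost)  # substitution or match
        prev = curr
    return prev[-1]


def word_error_rate(predictions, references):
    """Corpus-level WER: total word edits divided by total reference words."""
    errors = sum(edit_distance(r.split(), p.split()) for p, r in zip(predictions, references))
    total = sum(len(r.split()) for r in references)
    return errors / total


if __name__ == "__main__":
    refs = ["kell on viis", "tere hommikust"]
    preds = ["kell on kuus", "tere hommikust"]
    print("WER: {:.2f}".format(100 * word_error_rate(preds, refs)))  # WER: 20.00
```

Applied to the script's outputs, `word_error_rate(result["pred_strings"], result["sentence"])` should land close to the metric's value. Tokenization details may differ slightly.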

Conclusion

At fxis.ai, we believe that such advancements are crucial for the future of AI, as they enable more comprehensive and effective solutions. Our team is continually exploring new methodologies to push the envelope in artificial intelligence, ensuring that our clients benefit from the latest technological innovations.

© 2024 All Rights Reserved
