Audio Seq2seq Model using Transformers

Last Updated : 23 Jul, 2025

This article introduces the Seq2Seq model and its applications to audio tasks, and then shows how it can be used in practice by fine-tuning a pre-trained speech model.

What is Seq2Seq model?

Seq2Seq (sequence-to-sequence) models are encoder-decoder models that allow inputs and outputs of different lengths: the encoder processes the input sequence, and the decoder generates the output sequence. They are typically used for Automatic Speech Recognition (ASR), speech-to-speech translation, and speech synthesis.
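The point about different input and output lengths can be illustrated with a toy sketch. The names `encode`, `greedy_decode` and `toy_next_token` are hypothetical, and `toy_next_token` is only a stand-in for a trained decoder; nothing here is the Whisper implementation.

```python
from typing import Callable, List, Sequence

EOS = "<eos>"

def encode(frames: Sequence[float]) -> List[float]:
    """Toy encoder: one state per input frame (here, a running mean)."""
    states, total = [], 0.0
    for i, x in enumerate(frames, start=1):
        total += x
        states.append(total / i)
    return states

def greedy_decode(states: List[float],
                  next_token: Callable[[List[float], List[str]], str],
                  max_len: int = 10) -> List[str]:
    """Autoregressive loop: each step sees the encoder states and the
    tokens produced so far, and stops at EOS or max_len."""
    tokens: List[str] = []
    for _ in range(max_len):
        tok = next_token(states, tokens)
        if tok == EOS:
            break
        tokens.append(tok)
    return tokens

def toy_next_token(states: List[float], tokens: List[str]) -> str:
    """Hypothetical stand-in for a trained decoder: emits one token per
    three encoder states, then EOS."""
    budget = len(states) // 3
    return f"tok{len(tokens)}" if len(tokens) < budget else EOS

states = encode([0.1, 0.4, 0.3, 0.9, 0.2, 0.5, 0.7])   # 7 input frames
output = greedy_decode(states, toy_next_token)          # 2 output tokens
print(len(states), output)
```

The encoder produces as many states as there are input frames, while the decoder decides on its own when to stop, so the two lengths are independent.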

For a more detailed explanation of how audio transformers work, see the article Audio Transformer.

In this article, we focus on implementing a seq2seq model using a transformer. We will use the pre-trained Whisper model from Hugging Face and fine-tune it.

Audio Seq2seq Model Implementation using Transformers

Install the Necessary Libraries

Install the libraries below if they are not already available in your environment. They are required to run the subsequent code.

!pip install datasets
!pip install transformers
!pip install torch
!pip install evaluate
!pip install jiwer
!pip install transformers[torch]
!pip install numpy

Step 1: Import the Necessary Libraries

Then import the libraries into your notebook.

```python
# Imports required
import numpy as np
from datasets import load_dataset, Audio, DatasetDict
import torch
import evaluate
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# Use the GPU when one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
```

Step 2: Loading Dataset

About the PolyAI/minds14 dataset: MINDS-14 is a training and evaluation resource for intent detection with spoken data. It covers 14 intents extracted from a commercial system in the e-banking domain, with spoken examples in 14 diverse language varieties.

```python
# Load the first 80 examples of the en-US subset of the PolyAI dataset.
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train[:80]")

# Remove unnecessary columns (remove_columns returns a new dataset,
# so the result has to be assigned back)
dataset = dataset.remove_columns(
    ['path', 'english_transcription', 'intent_class', 'lang_id'])

# Split the dataset into train and test
dataset = dataset.train_test_split(test_size=0.2, shuffle=False)
```
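To make the split concrete, here is a minimal pure-Python sketch of what an unshuffled 80/20 split does to 80 examples. The helper `unshuffled_split` is hypothetical and only mimics the behaviour; it is not the `datasets` implementation.

```python
import math
from typing import List, Sequence, Tuple

def unshuffled_split(items: Sequence, test_size: float) -> Tuple[List, List]:
    """Keep the original order: the head becomes train, the tail becomes test."""
    n_test = math.ceil(len(items) * test_size)
    n_train = len(items) - n_test
    return list(items[:n_train]), list(items[n_train:])

examples = list(range(80))      # stand-in for the 80 loaded utterances
train, test = unshuffled_split(examples, test_size=0.2)
print(len(train), len(test))    # 64 16
```

Because nothing is shuffled, the test set is simply the last 16 utterances, which makes the split reproducible across runs.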

Step 3: Data Pre-Processing and Tokenization

We first resample our audio data from 8 kHz to 16 kHz by casting the audio column with the Audio feature of the datasets library, since the Whisper seq2seq model is trained on 16 kHz audio.

```python
dataset['train'] = dataset['train'].cast_column("audio", Audio(sampling_rate=16000))
dataset['test'] = dataset['test'].cast_column("audio", Audio(sampling_rate=16000))
```
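As a rough illustration of what upsampling from 8 kHz to 16 kHz means, the sketch below doubles the number of samples by linear interpolation. The function `resample_linear` is hypothetical and deliberately simplistic; the library performs proper resampling internally, so treat this only as intuition.

```python
from typing import List, Sequence

def resample_linear(samples: Sequence[float], src_rate: int, dst_rate: int) -> List[float]:
    """Resample by linear interpolation (illustrative, not band-limited)."""
    if not samples:
        return []
    n_out = int(len(samples) * dst_rate / src_rate)
    out = []
    for i in range(n_out):
        pos = i * src_rate / dst_rate            # position in source samples
        left = int(pos)
        right = min(left + 1, len(samples) - 1)  # clamp at the last sample
        frac = pos - left
        out.append(samples[left] * (1 - frac) + samples[right] * frac)
    return out

one_second_8k = [0.0] * 8000                     # one second of silence at 8 kHz
upsampled = resample_linear(one_second_8k, 8000, 16000)
print(len(upsampled))                            # 16000
```

One second of audio stays one second long; only the number of samples describing it changes.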

Let us import the Whisper model and processor from Hugging Face using the Transformers library.

```python
from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained(
    "openai/whisper-tiny.en", task="transcribe", model_max_length=225)
model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-tiny.en')
model.to(device)

# Prepare a function to process the entire dataset
def prepare_dataset(batch):
    audio = batch["audio"]

    # Compute the model's input features from the raw audio array
    batch["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # Encode the target text (the "transcription" column) into label ids
    batch["labels"] = processor.tokenizer(batch["transcription"]).input_ids
    return batch

# Workaround kept from the original for libraries that still reference
# the np.object alias removed in recent NumPy releases
np.object = object

encoded_dataset = dataset.map(
    prepare_dataset, remove_columns=dataset.column_names["train"], num_proc=4)
```
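It can help to know what shape `input_features` ends up with. The sketch below only does the arithmetic, assuming the default Whisper feature-extractor settings (30-second windows, a hop of 160 samples, 80 mel bins); the helper `pad_or_trim` is a hypothetical name used for illustration, not the library function.

```python
SAMPLING_RATE = 16_000   # Hz, the rate Whisper expects
CHUNK_SECONDS = 30       # assumed default window length
HOP_LENGTH = 160         # assumed default hop between frames, in samples
N_MELS = 80              # assumed default number of mel bins

def pad_or_trim(samples, target_len):
    """Cut long clips and zero-pad short ones to a fixed length."""
    samples = list(samples[:target_len])
    return samples + [0.0] * (target_len - len(samples))

target_len = SAMPLING_RATE * CHUNK_SECONDS   # 480000 samples
clip = [0.1] * (SAMPLING_RATE * 4)           # a 4-second utterance
fixed = pad_or_trim(clip, target_len)
n_frames = len(fixed) // HOP_LENGTH
print((N_MELS, n_frames))                    # (80, 3000)
```

Under these assumptions every utterance, however short, becomes an array of the same size, which is why the inputs need no length-dependent padding later on.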

Step 4: Preparing data collator class

Create a data collator class that pads the input features and the labels separately.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    padding: Union[bool, str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths
        # and need different padding methods
        input_features = [{"input_features": feature["input_features"]}
                          for feature in features]
        label_features = [{"input_ids": feature["labels"]}
                          for feature in features]

        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt")

        labels_batch = self.processor.tokenizer.pad(
            label_features, return_tensors="pt")

        # replace padding with -100 so that the loss ignores it
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
```
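The label handling in the collator can be sketched without any tensors. The helpers `pad_labels` and `mask_padding` below are hypothetical names; they mimic padding to the longest sequence and then replacing padded positions with -100 using the attention mask.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # label value that the loss skips

def pad_labels(label_seqs: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad every sequence to the longest one and build an attention mask."""
    max_len = max(len(seq) for seq in label_seqs)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in label_seqs]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in label_seqs]
    return input_ids, attention_mask

def mask_padding(input_ids: List[List[int]], attention_mask: List[List[int]]) -> List[List[int]]:
    """Replace positions where the mask is not 1 with IGNORE_INDEX."""
    return [[tok if m == 1 else IGNORE_INDEX for tok, m in zip(row, mask_row)]
            for row, mask_row in zip(input_ids, attention_mask)]

# pad_id (9) deliberately equals a real token id in the data
ids, mask = pad_labels([[5, 6, 9], [8, 9]], pad_id=9)
labels = mask_padding(ids, mask)
print(labels)  # [[5, 6, 9], [8, 9, -100]]
```

Masking by attention mask rather than by token id matters when the padding id coincides with a real token: the genuine 9s in the example survive, and only the padded position is ignored.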

Step 5: Model Evaluation

We will evaluate our model using the word error rate (WER).

```python
# Evaluation metric
import evaluate

metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
```
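For intuition about the metric itself: WER is the number of word substitutions, deletions and insertions needed to turn the prediction into the reference, divided by the number of reference words. The function `word_error_rate` below is a hypothetical, minimal sketch using a word-level edit distance; it skips any text normalisation and is not the implementation behind `evaluate.load("wer")`.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """(substitutions + deletions + insertions) / number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev[j] holds the edit distance between ref[:i-1] and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1] / len(ref)

print(word_error_rate("how much do I have in my account",
                      "'m much do I have in my account"))  # 0.125
```

One wrong word out of eight gives 0.125. Since insertions also count, the value can exceed 1 when the prediction is much longer than the reference.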

Step 6: Define our trainer

```python
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

training_args = Seq2SeqTrainingArguments(
    output_dir="seqtoseq-trained",
    gradient_checkpointing=True,
    per_device_train_batch_size=2,
    learning_rate=1e-5,
    warmup_steps=2,
    max_steps=2000,
    fp16=True,  # requires a CUDA GPU; set to False when running on CPU
    optim='adafactor',
    group_by_length=True,
    predict_with_generate=True,
    evaluation_strategy="steps",  # named eval_strategy in newer transformers releases
    per_device_eval_batch_size=2,
    eval_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # a lower WER is better
    report_to=["tensorboard"],
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
```
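The values `learning_rate=1e-5`, `warmup_steps=2` and `max_steps=2000` interact through the learning-rate schedule. The sketch below assumes the trainer's default linear schedule (linear warmup, then linear decay to zero); `linear_schedule_lr` is a hypothetical helper written for illustration, not a library call.

```python
def linear_schedule_lr(step: int, base_lr: float = 1e-5,
                       warmup_steps: int = 2, max_steps: int = 2000) -> float:
    """Linear warmup to base_lr, then linear decay to 0 at max_steps."""
    if step < warmup_steps:
        return base_lr * (step / max(1, warmup_steps))
    remaining = max(0, max_steps - step)
    return base_lr * (remaining / max(1, max_steps - warmup_steps))

for step in (0, 1, 2, 1001, 2000):
    print(step, linear_schedule_lr(step))
```

With only two warmup steps the peak rate is reached almost immediately, and under this assumption the rate has halved by step 1001.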

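Step 6 (continued): Define the Computation Metric

This variant of compute_metrics reports the WER as a percentage instead of a fraction. The trainer keeps a reference to the function it was constructed with, so define this version before creating the trainer (or re-create the trainer afterwards) if you want it to take effect.

```python
import evaluate

metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # scale the fraction to a percentage
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
```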

To start training, run the command below.

```python
# Requires GPU for training
trainer.train()
```

Output:

Logs of training:

Step    Training Loss    Validation Loss    Wer
100     No log           0.525059           4.193548
200     No log           0.532363           1.846774
300     No log           0.553872           1.161290
400     No log           0.568876           1.161290
500     0.000000         0.590014           1.169355
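The log can also be read programmatically. The sketch below copies the five evaluation rows from the log above into a list and picks the best step by WER, treating lower as better; the variable names are illustrative only.

```python
# Evaluation rows copied from the training log above
log = [
    {"step": 100, "eval_loss": 0.525059, "wer": 4.193548},
    {"step": 200, "eval_loss": 0.532363, "wer": 1.846774},
    {"step": 300, "eval_loss": 0.553872, "wer": 1.161290},
    {"step": 400, "eval_loss": 0.568876, "wer": 1.161290},
    {"step": 500, "eval_loss": 0.590014, "wer": 1.169355},
]

# Lower is better for both criteria; min() keeps the earliest row on ties
best_by_wer = min(log, key=lambda row: row["wer"])
best_by_loss = min(log, key=lambda row: row["eval_loss"])
print(best_by_wer["step"], best_by_loss["step"])  # 300 100
```

The two criteria disagree here: validation loss is lowest at step 100, while WER is lowest at step 300, so the choice of the selection metric changes which checkpoint counts as best.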

Step 7: Drawing inferences

Let us check the output of our model after training.

```python
# Getting test data
inputs = processor(dataset['test'][8]["audio"]["array"], sampling_rate=16000,
                   return_tensors="pt").to(device).input_features
print(f"The input test audio is: {dataset['test'][8]['transcription']}")

# Generate token ids for the audio and decode them back to text
generated_ids = model.generate(inputs=inputs)

transcription = processor.batch_decode(
    generated_ids, skip_special_tokens=True)[0]
print(f'The output prediction is : {transcription}')
```

Output:

The input test audio is: how much do I have in my account
The output prediction is : 'm much do I have in my account

Conclusion:

In this article, we saw how to fine-tune an audio seq2seq model using the transformers library.