M-CTC-T

PyTorch

This model is in maintenance mode only, so we won’t accept any new PRs changing its code.

If you run into any issues running this model, please reinstall the last version that supported this model: v4.30.0. You can do so by running the following command: pip install -U transformers==4.30.0.

Overview

The M-CTC-T model was proposed in Pseudo-Labeling For Massively Multilingual Speech Recognition by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. The model is a 1B-parameter transformer encoder, with a CTC head over 8065 character labels and a language identification head over 60 language ID labels. It is trained on Common Voice (version 6.1, December 2020 release) and VoxPopuli. After training on Common Voice and VoxPopuli, the model is trained on Common Voice only. The labels are unnormalized character-level transcripts (punctuation and capitalization are not removed). The model takes as input Mel filterbank features from a 16 kHz audio signal.
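As a quick sanity check, the architecture numbers above (character vocabulary, hidden size, filterbank input) are reflected in the defaults of MCTCTConfig documented below; a minimal sketch:

from transformers import MCTCTConfig

config = MCTCTConfig()
print(config.vocab_size)  # 8065 character labels for the CTC head
print(config.hidden_size)  # 1536
print(config.num_hidden_layers)  # 36 encoder layers (~1B parameters)
print(config.input_feat_per_channel)  # 80 Mel filterbank features per frame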

The abstract from the paper is the following:

Semi-supervised learning through pseudo-labeling has become a staple of state-of-the-art monolingual speech recognition systems. In this work, we extend pseudo-labeling to massively multilingual speech recognition with 60 languages. We propose a simple pseudo-labeling recipe that works well even with low-resource languages: train a supervised multilingual model, fine-tune it with semi-supervised learning on a target language, generate pseudo-labels for that language, and train a final model using pseudo-labels for all languages, either from scratch or by fine-tuning. Experiments on the labeled Common Voice and unlabeled VoxPopuli datasets show that our recipe can yield a model with better performance for many languages that also transfers well to LibriSpeech.

This model was contributed by cwkeam. The original code can be found here.

Usage tips

The PyTorch version of this model is only available in torch 1.9 and higher.
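If you are unsure which torch build is installed, a small guard like the following (a sketch, not part of the library) fails fast:

import torch
from packaging import version

# Refuse to proceed on torch builds older than 1.9.
if version.parse(torch.__version__) < version.parse("1.9"):
    raise RuntimeError(f"M-CTC-T needs torch >= 1.9, found {torch.__version__}")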

Resources

MCTCTConfig

class transformers.MCTCTConfig


( vocab_size = 8065 hidden_size = 1536 num_hidden_layers = 36 intermediate_size = 6144 num_attention_heads = 4 attention_head_dim = 384 max_position_embeddings = 920 layer_norm_eps = 1e-05 layerdrop = 0.3 hidden_act = 'relu' initializer_range = 0.02 hidden_dropout_prob = 0.3 attention_probs_dropout_prob = 0.3 pad_token_id = 1 bos_token_id = 0 eos_token_id = 2 conv_glu_dim = 1 conv_dropout = 0.3 num_conv_layers = 1 conv_kernel = (7,) conv_stride = (3,) input_feat_per_channel = 80 input_channels = 1 conv_channels = None ctc_loss_reduction = 'sum' ctc_zero_infinity = False **kwargs )


This is the configuration class to store the configuration of an MCTCTModel. It is used to instantiate an M-CTC-T model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the M-CTC-T speechbrain/m-ctc-t-large architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import MCTCTConfig, MCTCTModel

configuration = MCTCTConfig()

model = MCTCTModel(configuration)

configuration = model.config
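The same configuration class can also define a smaller, randomly initialized variant; the sizes below are illustrative assumptions, not a released checkpoint. Note that the defaults tie hidden_size (1536) to num_attention_heads × attention_head_dim (4 × 384), so this sketch scales them together:

from transformers import MCTCTConfig, MCTCTModel

# Hypothetical smaller variant: 6 layers, 4 heads of dim 192 -> hidden size 768.
small_config = MCTCTConfig(
    num_hidden_layers=6,
    hidden_size=768,
    attention_head_dim=192,
    intermediate_size=3072,
)
small_model = MCTCTModel(small_config)  # randomly initialized, not pretrained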

MCTCTFeatureExtractor

class transformers.MCTCTFeatureExtractor

( feature_size = 80 sampling_rate = 16000 padding_value = 0.0 hop_length = 10 win_length = 25 win_function = 'hamming_window' frame_signal_scale = 32768.0 preemphasis_coeff = 0.97 mel_floor = 1.0 normalize_means = True normalize_vars = True return_attention_mask = False **kwargs )


Constructs an M-CTC-T feature extractor.

This feature extractor inherits from SequenceFeatureExtractor which contains most of the main methods. Users should refer to this superclass for more information regarding those methods. This code has been adapted from Flashlight’s C++ code. For more information about the implementation, one can refer to this notebook that takes the user step-by-step through the implementation.

__call__

( raw_speech: typing.Union[numpy.ndarray, typing.List[float], typing.List[numpy.ndarray], typing.List[typing.List[float]]] padding: typing.Union[bool, str, transformers.utils.generic.PaddingStrategy] = False max_length: typing.Optional[int] = None truncation: bool = False pad_to_multiple_of: typing.Optional[int] = None return_attention_mask: typing.Optional[bool] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None sampling_rate: typing.Optional[int] = None **kwargs )


Main method to featurize and prepare one or several sequence(s) for the model. It returns the log-mel spectrogram of the input audio, as implemented in the original Flashlight MFSC feature extraction code.
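For instance, featurizing one second of 16 kHz audio (the zeros below are a stand-in for a real waveform) yields a batch of 80-dimensional log-Mel frames:

import numpy as np
from transformers import MCTCTFeatureExtractor

feature_extractor = MCTCTFeatureExtractor.from_pretrained("speechbrain/m-ctc-t-large")

raw_speech = np.zeros(16000, dtype=np.float32)  # 1 s of silence as a placeholder
features = feature_extractor(raw_speech, sampling_rate=16000, return_tensors="pt")
print(features.input_features.shape)  # (batch, num_frames, 80)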

MCTCTProcessor

class transformers.MCTCTProcessor


( feature_extractor tokenizer )


Constructs an MCTCT processor which wraps an MCTCT feature extractor and an MCTCT tokenizer into a single processor.

MCTCTProcessor offers all the functionalities of MCTCTFeatureExtractor and AutoTokenizer. See __call__() and decode() for more information.

__call__

When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor’s __call__() and returns its output. If used in the context as_target_processor(), this method forwards all its arguments to AutoTokenizer’s __call__(). Please refer to the docstrings of the above two methods for more information.
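A short sketch of both paths (the dummy waveform and text are placeholders):

import numpy as np
from transformers import MCTCTProcessor

processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large")

# Normal mode: arguments go to the feature extractor.
audio = np.zeros(16000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

# Tokenizer path: text becomes label ids, as in the model examples below.
labels = processor(text="hello world", return_tensors="pt").input_ids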

from_pretrained


( pretrained_model_name_or_path: typing.Union[str, os.PathLike] cache_dir: typing.Union[str, os.PathLike, NoneType] = None force_download: bool = False local_files_only: bool = False token: typing.Union[bool, str, NoneType] = None revision: str = 'main' **kwargs )


Instantiate a processor associated with a pretrained model.

This class method is simply calling the feature extractor’s from_pretrained(), the image processor’s ImageProcessingMixin.from_pretrained() and the tokenizer’s ~tokenization_utils_base.PreTrainedTokenizer.from_pretrained() methods. Please refer to the docstrings of the methods above for more information.
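For example, loading the processor that ships with the checkpoint used throughout this page:

from transformers import MCTCTProcessor

processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large")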

save_pretrained


( save_directory push_to_hub: bool = False **kwargs )


Saves the attributes of this processor (feature extractor, tokenizer…) in the specified directory so that it can be reloaded using the from_pretrained() method.

This class method is simply calling the feature extractor’s save_pretrained() and the tokenizer’s save_pretrained(). Please refer to the docstrings of the methods above for more information.
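A minimal round trip, using an illustrative local directory name:

from transformers import MCTCTProcessor

processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large")
processor.save_pretrained("./m-ctc-t-processor")  # writes feature extractor and tokenizer files
reloaded = MCTCTProcessor.from_pretrained("./m-ctc-t-processor")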

batch_decode

This method forwards all its arguments to AutoTokenizer’s batch_decode(). Please refer to the docstring of this method for more information.

decode

This method forwards all its arguments to AutoTokenizer’s decode(). Please refer to the docstring of this method for more information.
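In practice these turn CTC predictions into text; the ids below are made up for illustration:

import torch
from transformers import MCTCTProcessor

processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large")

predicted_ids = torch.tensor([[5, 5, 0, 12, 12, 9]])  # hypothetical argmax of CTC logits
texts = processor.batch_decode(predicted_ids)  # one string per sequence in the batch
text = processor.decode(predicted_ids[0])  # a single sequence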

MCTCTModel

class transformers.MCTCTModel


( config )


The bare M-CTC-T Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_features: Tensor attention_mask: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutput or tuple(torch.FloatTensor)

Returns

A transformers.modeling_outputs.BaseModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (MCTCTConfig) and inputs.

The MCTCTModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoProcessor, MCTCTModel
from datasets import load_dataset
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

processor = AutoProcessor.from_pretrained("speechbrain/m-ctc-t-large")
model = MCTCTModel.from_pretrained("speechbrain/m-ctc-t-large")

# audio file is decoded on the fly
inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# the bare model returns hidden states, not CTC logits
last_hidden_states = outputs.last_hidden_state
list(last_hidden_states.shape)
# [1, 195, 1536]

MCTCTForCTC

class transformers.MCTCTForCTC


( config )


MCTCT Model with a language modeling head on top for Connectionist Temporal Classification (CTC). This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_features: Tensor attention_mask: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None labels: typing.Optional[torch.LongTensor] = None ) → transformers.modeling_outputs.CausalLMOutput or tuple(torch.FloatTensor)

Returns

A transformers.modeling_outputs.CausalLMOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (MCTCTConfig) and inputs.

The MCTCTForCTC forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoProcessor, MCTCTForCTC
from datasets import load_dataset
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

processor = AutoProcessor.from_pretrained("speechbrain/m-ctc-t-large")
model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large")

# audio file is decoded on the fly
inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)

# transcribe speech
transcription = processor.batch_decode(predicted_ids)
transcription[0]
# "Mr. Quilter is the apostle of the middle classes, and we're glad to welcome his gospel."

# compute loss
inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
loss = model(**inputs).loss
round(loss.item(), 2)
# 1885.65
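The loss above can drive fine-tuning. A minimal single-step sketch, with an illustrative optimizer, learning rate, and placeholder audio/transcript (real training would iterate over a labeled dataset):

import numpy as np
import torch
from transformers import AutoProcessor, MCTCTForCTC

processor = AutoProcessor.from_pretrained("speechbrain/m-ctc-t-large")
model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large")
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

audio = np.random.randn(16000).astype(np.float32)  # placeholder waveform
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
inputs["labels"] = processor(text="hello world", return_tensors="pt").input_ids

loss = model(**inputs).loss  # CTC loss, as in the example above
loss.backward()
optimizer.step()
optimizer.zero_grad()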
