Wav2Vec2-Conformer

PyTorch

Overview

The Wav2Vec2-Conformer model was added to an updated version of fairseq S2T: Fast Speech-to-Text Modeling with fairseq by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, and Juan Pino.

The official results of the model can be found in Table 3 and Table 4 of the paper.

The Wav2Vec2-Conformer weights were released by the Meta AI team within the Fairseq library.

This model was contributed by patrickvonplaten. The original code can be found here.

Note: Meta (FAIR) released a newer model, Wav2Vec2-BERT 2.0, pretrained on 4.5M hours of audio. We especially recommend using it for fine-tuning tasks, e.g. as per this guide.

Usage tips

Resources

Wav2Vec2ConformerConfig

class transformers.Wav2Vec2ConformerConfig


( vocab_size = None hidden_size = 768 num_hidden_layers = 12 num_attention_heads = 12 intermediate_size = 3072 hidden_act = 'gelu' hidden_dropout = 0.1 activation_dropout = 0.1 attention_dropout = 0.1 feat_proj_dropout = 0.0 feat_quantizer_dropout = 0.0 final_dropout = 0.1 layerdrop = 0.1 initializer_range = 0.02 layer_norm_eps = 1e-05 feat_extract_norm = 'group' feat_extract_activation = 'gelu' conv_dim = (512, 512, 512, 512, 512, 512, 512) conv_stride = (5, 2, 2, 2, 2, 2, 2) conv_kernel = (10, 3, 3, 3, 3, 2, 2) conv_bias = False num_conv_pos_embeddings = 128 num_conv_pos_embedding_groups = 16 apply_spec_augment = True mask_time_prob = 0.05 mask_time_length = 10 mask_time_min_masks = 2 mask_feature_prob = 0.0 mask_feature_length = 10 mask_feature_min_masks = 0 num_codevectors_per_group = 320 num_codevector_groups = 2 contrastive_logits_temperature = 0.1 num_negatives = 100 codevector_dim = 256 proj_codevector_dim = 256 diversity_loss_weight = 0.1 ctc_loss_reduction = 'sum' ctc_zero_infinity = False use_weighted_layer_sum = False classifier_proj_size = 256 tdnn_dim = (512, 512, 512, 512, 1500) tdnn_kernel = (5, 3, 3, 1, 1) tdnn_dilation = (1, 2, 3, 1, 1) xvector_output_dim = 512 pad_token_id = 0 bos_token_id = 1 eos_token_id = 2 add_adapter = False adapter_kernel_size = 3 adapter_stride = 2 num_adapter_layers = 3 output_hidden_size = None position_embeddings_type = 'relative' rotary_embedding_base = 10000 max_source_positions = 5000 conv_depthwise_kernel_size = 31 conformer_conv_dropout = 0.1 **kwargs )


This is the configuration class to store the configuration of a Wav2Vec2ConformerModel. It is used to instantiate a Wav2Vec2Conformer model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2Conformer facebook/wav2vec2-conformer-rel-pos-large architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import Wav2Vec2ConformerConfig, Wav2Vec2ConformerModel

# Initializing a Wav2Vec2Conformer facebook/wav2vec2-conformer-rel-pos-large style configuration
configuration = Wav2Vec2ConformerConfig()

# Initializing a model (with random weights) from the configuration
model = Wav2Vec2ConformerModel(configuration)

# Accessing the model configuration
configuration = model.config
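The configuration also selects how positional information is injected into the Conformer blocks via the position_embeddings_type argument shown in the signature above ('relative' by default, with rotary_embedding_base only used when rotary embeddings are selected). As a minimal sketch, assuming you want rotary rather than relative position embeddings:

from transformers import Wav2Vec2ConformerConfig, Wav2Vec2ConformerModel

# Sketch: swap the default relative position embeddings for rotary ones.
# "rotary" is assumed to be an accepted value alongside the default "relative".
rotary_config = Wav2Vec2ConformerConfig(position_embeddings_type="rotary")
rotary_model = Wav2Vec2ConformerModel(rotary_config)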

Wav2Vec2Conformer specific outputs

class transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForPreTrainingOutput


( loss: typing.Optional[torch.FloatTensor] = None projected_states: typing.Optional[torch.FloatTensor] = None projected_quantized_states: typing.Optional[torch.FloatTensor] = None codevector_perplexity: typing.Optional[torch.FloatTensor] = None hidden_states: typing.Optional[typing.Tuple[torch.FloatTensor]] = None attentions: typing.Optional[typing.Tuple[torch.FloatTensor]] = None contrastive_loss: typing.Optional[torch.FloatTensor] = None diversity_loss: typing.Optional[torch.FloatTensor] = None )


Output type of Wav2Vec2ConformerForPreTraining, with potential hidden states and attentions.

Wav2Vec2ConformerModel

class transformers.Wav2Vec2ConformerModel

( config: Wav2Vec2ConformerConfig )

The bare Wav2Vec2 Conformer Model outputting raw hidden-states without any specific head on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_values: typing.Optional[torch.Tensor] attention_mask: typing.Optional[torch.Tensor] = None mask_time_indices: typing.Optional[torch.FloatTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.Wav2Vec2BaseModelOutput or tuple(torch.FloatTensor)


Returns a transformers.modeling_outputs.Wav2Vec2BaseModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Wav2Vec2ConformerConfig) and inputs.

The Wav2Vec2ConformerModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
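The bare model has no example of its own on this page; as a minimal sketch of extracting hidden states (checkpoint and dataset follow the other examples on this page):

import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor, Wav2Vec2ConformerModel

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
sampling_rate = dataset.features["audio"].sampling_rate

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
model = Wav2Vec2ConformerModel.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")

# raw waveform -> padded input_values
inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# (batch, frames, hidden_size) hidden states of the last Conformer layer
last_hidden_states = outputs.last_hidden_state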

Wav2Vec2ConformerForCTC

class transformers.Wav2Vec2ConformerForCTC

( config target_lang: typing.Optional[str] = None )

Wav2Vec2Conformer Model with a language modeling head on top for Connectionist Temporal Classification (CTC).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_values: typing.Optional[torch.Tensor] attention_mask: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None labels: typing.Optional[torch.Tensor] = None ) → transformers.modeling_outputs.CausalLMOutput or tuple(torch.FloatTensor)


Returns a transformers.modeling_outputs.CausalLMOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Wav2Vec2ConformerConfig) and inputs.

The Wav2Vec2ConformerForCTC forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoProcessor, Wav2Vec2ConformerForCTC
from datasets import load_dataset
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

processor = AutoProcessor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
model = Wav2Vec2ConformerForCTC.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")

# audio file is decoded on the fly
inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)

# transcribe speech
transcription = processor.batch_decode(predicted_ids)
transcription[0]
...

# compute CTC loss
inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
loss = model(**inputs).loss
round(loss.item(), 2)
...
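When adapting the CTC head to a new vocabulary, the relevant options are listed in the configuration signature above (vocab_size, ctc_loss_reduction, ctc_zero_infinity). A minimal sketch with illustrative values (the vocabulary size here is an assumption and must match your tokenizer):

from transformers import Wav2Vec2ConformerConfig, Wav2Vec2ConformerForCTC

# Illustrative values: vocab_size must match the CTC tokenizer's vocabulary;
# "mean" loss reduction and zero_infinity are common fine-tuning choices.
config = Wav2Vec2ConformerConfig(vocab_size=32, ctc_loss_reduction="mean", ctc_zero_infinity=True)
model = Wav2Vec2ConformerForCTC(config)  # randomly initialized model with a 32-token CTC head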

Wav2Vec2ConformerForSequenceClassification

class transformers.Wav2Vec2ConformerForSequenceClassification

( config )

Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB Keyword Spotting.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_values: typing.Optional[torch.Tensor] attention_mask: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None labels: typing.Optional[torch.Tensor] = None ) → transformers.modeling_outputs.SequenceClassifierOutput or tuple(torch.FloatTensor)


Returns a transformers.modeling_outputs.SequenceClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Wav2Vec2ConformerConfig) and inputs.

The Wav2Vec2ConformerForSequenceClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example of single-label classification:

import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForSequenceClassification

# this is an audio model, so the inputs are raw waveforms rather than text
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
sampling_rate = dataset.features["audio"].sampling_rate

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
model = Wav2Vec2ConformerForSequenceClassification.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")

inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]

# to train on a single-label classification task, pass integer class labels
num_labels = len(model.config.id2label)
model = Wav2Vec2ConformerForSequenceClassification.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large", num_labels=num_labels)

labels = torch.tensor([1])
loss = model(**inputs, labels=labels).loss
round(loss.item(), 2)

Example of multi-label classification:

import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForSequenceClassification

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
sampling_rate = dataset.features["audio"].sampling_rate

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
model = Wav2Vec2ConformerForSequenceClassification.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large", problem_type="multi_label_classification")

inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]

# to train on a multi-label classification task, pass multi-hot float labels
num_labels = len(model.config.id2label)
model = Wav2Vec2ConformerForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-conformer-rel-pos-large", num_labels=num_labels, problem_type="multi_label_classification"
)

labels = torch.sum(
    torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
).to(torch.float)
loss = model(**inputs, labels=labels).loss

Wav2Vec2ConformerForAudioFrameClassification

class transformers.Wav2Vec2ConformerForAudioFrameClassification

( config )

The Wav2Vec2 Conformer Model with a frame classification head on top for tasks like Speaker Diarization.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_values: typing.Optional[torch.Tensor] attention_mask: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.TokenClassifierOutput or tuple(torch.FloatTensor)


Returns a transformers.modeling_outputs.TokenClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Wav2Vec2ConformerConfig) and inputs.

The Wav2Vec2ConformerForAudioFrameClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForAudioFrameClassification
from datasets import load_dataset
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
model = Wav2Vec2ConformerForAudioFrameClassification.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")

# audio file is decoded on the fly
inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
with torch.no_grad():
    logits = model(**inputs).logits

probabilities = torch.sigmoid(logits[0])

# labels is a binary array of per-frame predictions
labels = (probabilities > 0.5).long()
labels[0].tolist()
...

Wav2Vec2ConformerForXVector

class transformers.Wav2Vec2ConformerForXVector

( config )

Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_values: typing.Optional[torch.Tensor] attention_mask: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None labels: typing.Optional[torch.Tensor] = None ) → transformers.modeling_outputs.XVectorOutput or tuple(torch.FloatTensor)


Returns a transformers.modeling_outputs.XVectorOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Wav2Vec2ConformerConfig) and inputs.

The Wav2Vec2ConformerForXVector forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForXVector
from datasets import load_dataset
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
model = Wav2Vec2ConformerForXVector.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")

# audio files are decoded on the fly
inputs = feature_extractor(
    [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
)
with torch.no_grad():
    embeddings = model(**inputs).embeddings

embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()

# the resulting embeddings can be used for cosine similarity-based speaker verification
cosine_sim = torch.nn.CosineSimilarity(dim=-1)
similarity = cosine_sim(embeddings[0], embeddings[1])
threshold = 0.7  # the optimal threshold is dataset-dependent
if similarity < threshold:
    print("Speakers are not the same!")
round(similarity.item(), 2)
...

Wav2Vec2ConformerForPreTraining

class transformers.Wav2Vec2ConformerForPreTraining

( config: Wav2Vec2ConformerConfig )

Wav2Vec2Conformer Model with a quantizer and VQ head on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_values: typing.Optional[torch.Tensor] attention_mask: typing.Optional[torch.Tensor] = None mask_time_indices: typing.Optional[torch.BoolTensor] = None sampled_negative_indices: typing.Optional[torch.BoolTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForPreTrainingOutput or tuple(torch.FloatTensor)


Returns a transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForPreTrainingOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Wav2Vec2ConformerConfig) and inputs.

The Wav2Vec2ConformerForPreTraining forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

import torch
from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices, _sample_negative_indices
from datasets import load_dataset

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values  # Batch size 1

# compute masked indices and sample negatives for the contrastive task
batch_size, raw_sequence_length = input_values.shape
sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
mask_time_indices = _compute_mask_indices(
    shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
)
sampled_negative_indices = _sample_negative_indices(
    features_shape=(batch_size, sequence_length),
    num_negatives=model.config.num_negatives,
    mask_time_indices=mask_time_indices,
)
mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
sampled_negative_indices = torch.tensor(
    data=sampled_negative_indices, device=input_values.device, dtype=torch.long
)

with torch.no_grad():
    outputs = model(input_values, mask_time_indices=mask_time_indices)

# compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)

# show that the cosine similarity is much higher than random
cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
tensor(True)

# for contrastive loss training the model should be put into train mode
model = model.train()
loss = model(
    input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
).loss
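The returned Wav2Vec2ConformerForPreTrainingOutput also exposes the individual loss terms. Continuing the train-mode forward pass above, a short sketch (the exact combination is an assumption based on the wav2vec 2.0 objective, where the diversity term is weighted by config.diversity_loss_weight):

# continuing from the train-mode example above
outputs = model(
    input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
)
contrastive_loss = outputs.contrastive_loss
diversity_loss = outputs.diversity_loss
# assumption: total loss = contrastive term + diversity_loss_weight * diversity term
reconstructed_loss = contrastive_loss + model.config.diversity_loss_weight * diversity_loss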
