Audio Spectrogram Transformer


Overview

The Audio Spectrogram Transformer model was proposed in AST: Audio Spectrogram Transformer by Yuan Gong, Yu-An Chung, James Glass. The Audio Spectrogram Transformer applies a Vision Transformer to audio, by turning audio into an image (spectrogram). The model obtains state-of-the-art results for audio classification.

The abstract from the paper is the following:

In the past decade, convolutional neural networks (CNNs) have been widely adopted as the main building block for end-to-end audio classification models, which aim to learn a direct mapping from audio spectrograms to corresponding labels. To better capture long-range global context, a recent trend is to add a self-attention mechanism on top of the CNN, forming a CNN-attention hybrid model. However, it is unclear whether the reliance on a CNN is necessary, and if neural networks purely based on attention are sufficient to obtain good performance in audio classification. In this paper, we answer the question by introducing the Audio Spectrogram Transformer (AST), the first convolution-free, purely attention-based model for audio classification. We evaluate AST on various audio classification benchmarks, where it achieves new state-of-the-art results of 0.485 mAP on AudioSet, 95.6% accuracy on ESC-50, and 98.1% accuracy on Speech Commands V2.

Audio Spectrogram Transformer architecture. Taken from the original paper.

This model was contributed by nielsr. The original code can be found here.

Usage tips

Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of torch.nn.functional. This function encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the official documentation or the GPU Inference page for more information.

SDPA is used by default for torch>=2.1.1 when an implementation is available, but you may also set attn_implementation="sdpa" in from_pretrained() to explicitly request SDPA to be used.

```python
from transformers import ASTForAudioClassification
import torch

model = ASTForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. torch.float16 or torch.bfloat16).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with float32 and the MIT/ast-finetuned-audioset-10-10-0.4593 model, we saw the following speedups during inference.

| Batch size | Average inference time (ms), eager mode | Average inference time (ms), SDPA | Speedup (eager / SDPA) |
|---|---|---|---|
| 1 | 27 | 6 | 4.5 |
| 2 | 12 | 6 | 2 |
| 4 | 21 | 8 | 2.62 |
| 8 | 40 | 14 | 2.86 |
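If you want to run a comparison like this on your own hardware, the sketch below shows one way to time eager vs. SDPA inference. It is only an illustration, not the harness used for the numbers above: it assumes a CUDA device is available and feeds a random spectrogram batch with the default input shape (1024 frames x 128 mel bins).

```python
import torch
from transformers import ASTForAudioClassification

def average_inference_ms(attn_implementation: str, batch_size: int, steps: int = 50) -> float:
    # Load the checkpoint with the requested attention backend (assumes a CUDA device).
    model = ASTForAudioClassification.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593",
        attn_implementation=attn_implementation,
    ).to("cuda").eval()

    # Random spectrogram batch with the default shape (max_length=1024 frames, num_mel_bins=128).
    input_values = torch.randn(batch_size, 1024, 128, device="cuda")

    with torch.no_grad():
        for _ in range(5):  # warmup
            model(input_values)
        torch.cuda.synchronize()
        start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(steps):
            model(input_values)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / steps  # average milliseconds per forward pass

for batch_size in (1, 2, 4, 8):
    eager = average_inference_ms("eager", batch_size)
    sdpa = average_inference_ms("sdpa", batch_size)
    print(f"batch {batch_size}: eager {eager:.1f} ms, sdpa {sdpa:.1f} ms, speedup {eager / sdpa:.2f}x")
```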

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with the Audio Spectrogram Transformer.

If you’re interested in submitting a resource to be included here, please feel free to open a Pull Request and we’ll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

ASTConfig

class transformers.ASTConfig


( hidden_size = 768 num_hidden_layers = 12 num_attention_heads = 12 intermediate_size = 3072 hidden_act = 'gelu' hidden_dropout_prob = 0.0 attention_probs_dropout_prob = 0.0 initializer_range = 0.02 layer_norm_eps = 1e-12 patch_size = 16 qkv_bias = True frequency_stride = 10 time_stride = 10 max_length = 1024 num_mel_bins = 128 **kwargs )

Parameters

This is the configuration class to store the configuration of an ASTModel. It is used to instantiate an AST model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the AST MIT/ast-finetuned-audioset-10-10-0.4593 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

```python
from transformers import ASTConfig, ASTModel

# Initializing an AST configuration with default values
configuration = ASTConfig()

# Initializing a model (with random weights) from that configuration
model = ASTModel(configuration)

# Accessing the model configuration
configuration = model.config
```
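The sequence length the model sees is determined by patch_size, frequency_stride, time_stride, max_length and num_mel_bins. As a rough sketch (assuming the usual strided-patch formula), the defaults give 12 patches along the frequency axis and 101 along the time axis, i.e. 1212 patches plus two special tokens, which matches the sequence length of 1214 in the ASTModel example below.

```python
from transformers import ASTConfig

config = ASTConfig()  # default values

# Patches per axis, assuming the standard strided formula (size - patch) // stride + 1.
freq_patches = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1  # (128 - 16) // 10 + 1 = 12
time_patches = (config.max_length - config.patch_size) // config.time_stride + 1         # (1024 - 16) // 10 + 1 = 101

# Two extra tokens ([CLS] and distillation) are prepended to the patch sequence.
sequence_length = freq_patches * time_patches + 2
print(freq_patches, time_patches, sequence_length)  # 12 101 1214
```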

ASTFeatureExtractor

class transformers.ASTFeatureExtractor

( feature_size = 1 sampling_rate = 16000 num_mel_bins = 128 max_length = 1024 padding_value = 0.0 do_normalize = True mean = -4.2677393 std = 4.5689974 return_attention_mask = False **kwargs )

Parameters

Constructs an Audio Spectrogram Transformer (AST) feature extractor.

This feature extractor inherits from SequenceFeatureExtractor which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

This class extracts mel-filter bank features from raw speech using TorchAudio if installed, or numpy otherwise, pads/truncates them to a fixed length, and normalizes them with a mean and standard deviation.

__call__

( raw_speech: typing.Union[numpy.ndarray, typing.List[float], typing.List[numpy.ndarray], typing.List[typing.List[float]]] sampling_rate: typing.Optional[int] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None **kwargs )

Parameters

Main method to featurize and prepare one or several sequence(s) for the model.
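As a quick illustration (not part of the original reference), the snippet below runs the feature extractor on one second of dummy 16 kHz audio; with the default settings the output is padded to 1024 frames of 128 mel bins.

```python
from transformers import ASTFeatureExtractor
import numpy as np

feature_extractor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

# One second of dummy 16 kHz audio; a real waveform (e.g. loaded via datasets or librosa) works the same way.
waveform = np.random.randn(16000).astype(np.float32)

inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(inputs["input_values"].shape)  # torch.Size([1, 1024, 128]) with the default max_length and num_mel_bins
```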

ASTModel

class transformers.ASTModel


( config: ASTConfig )

Parameters

The bare AST Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_values: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) β†’ transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ASTConfig) and inputs.

The ASTModel forward method, overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

```python
from transformers import AutoProcessor, ASTModel
import torch
from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state
list(last_hidden_states.shape)
# [1, 1214, 768]
```

ASTForAudioClassification

class transformers.ASTForAudioClassification


( config: ASTConfig )

Parameters

Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled output) e.g. for datasets like AudioSet, Speech Commands v2.

This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_values: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) β†’ transformers.modeling_outputs.SequenceClassifierOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.SequenceClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ASTConfig) and inputs.

The ASTForAudioClassification forward method, overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

```python
from transformers import AutoFeatureExtractor, ASTForAudioClassification
from datasets import load_dataset
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_ids = torch.argmax(logits, dim=-1).item()
predicted_label = model.config.id2label[predicted_class_ids]
predicted_label
# 'Speech'

target_label = model.config.id2label[0]
inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
loss = model(**inputs).loss
round(loss.item(), 2)
# 0.17
```
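For quick experiments, the same checkpoint can also be used through the audio-classification pipeline. A minimal sketch (the file path is a placeholder, and the exact output format may vary slightly across transformers versions):

```python
from transformers import pipeline

classifier = pipeline("audio-classification", model="MIT/ast-finetuned-audioset-10-10-0.4593")

# Accepts a path to an audio file or a raw waveform sampled at 16 kHz.
predictions = classifier("path/to/audio.wav", top_k=3)
for prediction in predictions:
    print(prediction["label"], round(prediction["score"], 3))
```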
