AudioLDM 2

AudioLDM 2 was proposed in AudioLDM 2: Learning Holistic Audio Generation with Self-supervised Pretraining by Haohe Liu et al. AudioLDM 2 takes a text prompt as input and predicts the corresponding audio. It can generate text-conditional sound effects, human speech and music.

Inspired by Stable Diffusion, AudioLDM 2 is a text-to-audio latent diffusion model (LDM) that learns continuous audio representations from text embeddings. Two text encoder models are used to compute the text embeddings from a prompt input: the text-branch of CLAP and the encoder of Flan-T5. These text embeddings are then projected to a shared embedding space by an AudioLDM2ProjectionModel. A GPT2 language model (LM) is used to auto-regressively predict eight new embedding vectors, conditional on the projected CLAP and Flan-T5 embeddings. The generated embedding vectors and Flan-T5 text embeddings are used as cross-attention conditioning in the LDM. The UNet of AudioLDM 2 is unique in the sense that it takes two cross-attention embeddings, as opposed to one cross-attention conditioning, as in most other LDMs.
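To make this two-stream conditioning concrete, the pipeline components can be inspected after loading. The sketch below is illustrative and assumes the cvssp/audioldm2 checkpoint used elsewhere on this page; the printed class names follow the component types listed in the pipeline signature further down.

```python
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")

# the two text encoders: the CLAP text branch and the Flan-T5 encoder
print(type(pipe.text_encoder).__name__)    # ClapModel
print(type(pipe.text_encoder_2).__name__)  # T5EncoderModel

# the projection model, the GPT2 language model, and the dual cross-attention UNet
print(type(pipe.projection_model).__name__)  # AudioLDM2ProjectionModel
print(type(pipe.language_model).__name__)    # GPT2LMHeadModel
print(type(pipe.unet).__name__)              # AudioLDM2UNet2DConditionModel
```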

The abstract of the paper is the following:

Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called “language of audio” (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate state-of-the-art or competitive performance against previous approaches. Our code, pretrained model, and demo are available at this https URL.

This pipeline was contributed by sanchit-gandhi and Nguyễn Công Tú Anh. The original codebase can be found at haoheliu/audioldm2.

Tips

Choosing a checkpoint

AudioLDM2 comes in three official variants: two checkpoints for the general task of text-to-audio generation, and a third checkpoint trained exclusively on text-to-music generation. Two additional text-to-speech checkpoints are also listed below.

All checkpoints share the same model size for the text encoders and VAE. They differ in the size and depth of the UNet. See the table below for details on the available checkpoints:

| Checkpoint | Task | UNet Model Size | Total Model Size | Training Data / h |
|------------|------|-----------------|------------------|-------------------|
| audioldm2 | Text-to-audio | 350M | 1.1B | 1150k |
| audioldm2-large | Text-to-audio | 750M | 1.5B | 1150k |
| audioldm2-music | Text-to-music | 350M | 1.1B | 665k |
| audioldm2-gigaspeech | Text-to-speech | 350M | 1.1B | 10k |
| audioldm2-ljspeech | Text-to-speech | 350M | 1.1B | |
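As a sketch of selecting a checkpoint, swap the repository id passed to from_pretrained to change the task. Only cvssp/audioldm2 and anhnct/audioldm2_gigaspeech are referenced explicitly elsewhere on this page; the -large and -music repository names below are assumed to follow the same naming under the cvssp organisation.

```python
from diffusers import AudioLDM2Pipeline

# general text-to-audio (base and large UNet variants)
pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
pipe_large = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large")

# text-to-music
pipe_music = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-music")

# text-to-speech (community checkpoint used in the TTS example below)
pipe_tts = AudioLDM2Pipeline.from_pretrained("anhnct/audioldm2_gigaspeech")
```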

Constructing a prompt

Descriptive prompt inputs work best: use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context-specific where possible (for example, "water stream in a forest" instead of "stream"). Using a negative prompt such as "low quality" can significantly improve the quality of the generated waveform by steering the generation away from poor-quality audio.

Controlling inference

The quality of the predicted audio sample can be controlled by the num_inference_steps argument; higher steps give higher quality audio at the expense of slower inference. The length of the predicted audio sample can be controlled by varying the audio_length_in_s argument.

Evaluating generated waveforms

The quality of the generated waveforms can vary significantly based on the seed, so try several seeds until you find a satisfactory result. When generating multiple waveforms at once with num_waveforms_per_prompt > 1, the generated waveforms are scored against the text prompt with the CLAP model and returned ranked from best to worst, so the first waveform is the best match to the prompt.

The following example demonstrates how to construct good audio generation using the aforementioned tips.
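This is a sketch in which the prompt wording and generation settings are illustrative rather than prescriptive; it mirrors the pipeline call documented under AudioLDM2Pipeline below.

```python
import scipy
import torch
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16).to("cuda")

# descriptive, context-specific prompt plus a negative prompt to steer away from low-quality audio
prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
negative_prompt = "Low quality, average quality."

# fix the seed for reproducibility and generate several candidate waveforms;
# with num_waveforms_per_prompt > 1 the best-scoring waveform is returned first
generator = torch.Generator("cuda").manual_seed(0)
audio = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=200,   # more steps -> higher quality, slower inference
    audio_length_in_s=10.0,    # length of the generated clip in seconds
    num_waveforms_per_prompt=3,
    generator=generator,
).audios

scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
```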

Make sure to check out the Schedulers guide to learn how to explore the tradeoff between scheduler speed and quality, and see the reuse components across pipelines section to learn how to efficiently load the same components into multiple pipelines.
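For instance, a different scheduler can be swapped in by reusing the existing scheduler configuration. A minimal sketch, assuming DPMSolverMultistepScheduler as the replacement:

```python
from diffusers import AudioLDM2Pipeline, DPMSolverMultistepScheduler

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")

# replace the default scheduler while keeping its configuration
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```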

AudioLDM2Pipeline

class diffusers.AudioLDM2Pipeline

( vae: AutoencoderKL text_encoder: ClapModel text_encoder_2: typing.Union[transformers.models.t5.modeling_t5.T5EncoderModel, transformers.models.vits.modeling_vits.VitsModel] projection_model: AudioLDM2ProjectionModel language_model: GPT2LMHeadModel tokenizer: typing.Union[transformers.models.roberta.tokenization_roberta.RobertaTokenizer, transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast] tokenizer_2: typing.Union[transformers.models.t5.tokenization_t5.T5Tokenizer, transformers.models.t5.tokenization_t5_fast.T5TokenizerFast, transformers.models.vits.tokenization_vits.VitsTokenizer] feature_extractor: ClapFeatureExtractor unet: AudioLDM2UNet2DConditionModel scheduler: KarrasDiffusionSchedulers vocoder: SpeechT5HifiGan )

Pipeline for text-to-audio generation using AudioLDM2.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).
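As a brief illustration of the inherited DiffusionPipeline methods (saving, reloading, and device placement); the local directory path is a hypothetical choice for this example:

```python
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")

# save every pipeline component to a local directory and reload from it
pipe.save_pretrained("./audioldm2-local")  # hypothetical local path
pipe = AudioLDM2Pipeline.from_pretrained("./audioldm2-local")

# move the whole pipeline to a particular device
pipe = pipe.to("cuda")
```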

__call__

( prompt: typing.Union[str, typing.List[str]] = None transcription: typing.Union[str, typing.List[str]] = None audio_length_in_s: typing.Optional[float] = None num_inference_steps: int = 200 guidance_scale: float = 3.5 negative_prompt: typing.Union[str, typing.List[str], NoneType] = None num_waveforms_per_prompt: typing.Optional[int] = 1 eta: float = 0.0 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None generated_prompt_embeds: typing.Optional[torch.Tensor] = None negative_generated_prompt_embeds: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.LongTensor] = None negative_attention_mask: typing.Optional[torch.LongTensor] = None max_new_tokens: typing.Optional[int] = None return_dict: bool = True callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: typing.Optional[int] = 1 cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None output_type: typing.Optional[str] = 'np' ) → StableDiffusionPipelineOutput or tuple

Returns

If return_dict is True, StableDiffusionPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated audio.

The call function to the pipeline for generation.

Examples:

```python
import scipy
import torch
from diffusers import AudioLDM2Pipeline

repo_id = "cvssp/audioldm2"
pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# define the prompts
prompt = "The sound of a hammer hitting a wooden surface."
negative_prompt = "Low quality."

# set the seed for the generator
generator = torch.Generator("cuda").manual_seed(0)

# run the generation
audio = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=200,
    audio_length_in_s=10.0,
    num_waveforms_per_prompt=3,
    generator=generator,
).audios

# save the best audio sample (index 0) as a .wav file
scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
```

```python
# Using AudioLDM2 for text-to-speech
import scipy
import torch
from diffusers import AudioLDM2Pipeline

repo_id = "anhnct/audioldm2_gigaspeech"
pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# define the prompt and the transcript to be spoken
prompt = "A female reporter is speaking"
transcript = "wish you have a good day"

# set the seed for the generator
generator = torch.Generator("cuda").manual_seed(0)

# run the generation
audio = pipe(
    prompt,
    transcription=transcript,
    num_inference_steps=200,
    audio_length_in_s=10.0,
    num_waveforms_per_prompt=2,
    generator=generator,
    max_new_tokens=512,
).audios

# save the best audio sample (index 0) as a .wav file
scipy.io.wavfile.write("tts.wav", rate=16000, data=audio[0])
```

disable_vae_slicing

Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.

enable_model_cpu_offload

( gpu_id: typing.Optional[int] = None device: typing.Union[torch.device, str] = 'cuda' )

Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to enable_sequential_cpu_offload, this method moves one whole model at a time to the GPU when its forward method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with enable_sequential_cpu_offload, but performance is much better due to the iterative execution of the UNet.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
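A minimal sketch combining both memory optimisations documented here (model CPU offload and sliced VAE decoding); the prompt and generation settings are illustrative:

```python
import torch
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16)

# move each sub-model to the GPU only while its forward pass runs
pipe.enable_model_cpu_offload()

# decode the VAE latents in slices to lower peak memory
pipe.enable_vae_slicing()

audio = pipe(
    "A dog barking in the distance",
    num_inference_steps=200,
    audio_length_in_s=5.0,
).audios[0]
```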

encode_prompt

( prompt device num_waveforms_per_prompt do_classifier_free_guidance transcription = None negative_prompt = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None generated_prompt_embeds: typing.Optional[torch.Tensor] = None negative_generated_prompt_embeds: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.LongTensor] = None negative_attention_mask: typing.Optional[torch.LongTensor] = None max_new_tokens: typing.Optional[int] = None ) → prompt_embeds (torch.Tensor)

Returns

prompt_embeds (torch.Tensor): Text embeddings from the Flan-T5 model.

attention_mask (torch.LongTensor): Attention mask to be applied to the prompt_embeds.

generated_prompt_embeds (torch.Tensor): Text embeddings generated from the GPT-2 language model.

Encodes the prompt into text encoder hidden states.

Example:

```python
import scipy
import torch
from diffusers import AudioLDM2Pipeline

repo_id = "cvssp/audioldm2"
pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# get text embedding vectors
prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
    prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
    device="cuda",
    num_waveforms_per_prompt=1,  # required argument per the signature above
    do_classifier_free_guidance=True,
)

# pass the pre-computed text embeddings to the pipeline for text-conditional audio generation
audio = pipe(
    prompt_embeds=prompt_embeds,
    attention_mask=attention_mask,
    generated_prompt_embeds=generated_prompt_embeds,
    num_inference_steps=200,
    audio_length_in_s=10.0,
).audios[0]

# save the generated audio sample
scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
```

generate_language_model

( inputs_embeds: Tensor = None max_new_tokens: int = 8 **model_kwargs ) → inputs_embeds (torch.Tensor of shape (batch_size, sequence_length, hidden_size))

Returns

inputs_embeds (torch.Tensor of shape (batch_size, sequence_length, hidden_size))

The sequence of generated hidden-states.

Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.

AudioLDM2ProjectionModel

class diffusers.AudioLDM2ProjectionModel

( text_encoder_dim text_encoder_1_dim langauge_model_dim use_learned_position_embedding = None max_seq_length = None )

A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned embedding vectors at the start and end of each text embedding sequence, respectively. Each variable appended with _1 refers to the second text encoder; otherwise, it is from the first.

forward

( hidden_states: typing.Optional[torch.Tensor] = None hidden_states_1: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.LongTensor] = None attention_mask_1: typing.Optional[torch.LongTensor] = None )
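A minimal sketch of the projection model in isolation. The embedding dimensions (512 for the CLAP branch, 1024 for Flan-T5, 768 for GPT2) and the random inputs are illustrative assumptions rather than the exact checkpoint configuration; the constructor argument names follow the signature shown above.

```python
import torch
from diffusers import AudioLDM2ProjectionModel

# illustrative dimensions: CLAP-style (512), Flan-T5-style (1024), GPT2-style (768)
projection_model = AudioLDM2ProjectionModel(
    text_encoder_dim=512,
    text_encoder_1_dim=1024,
    langauge_model_dim=768,
)

# dummy embeddings: a pooled CLAP embedding (sequence length 1) and Flan-T5 token embeddings
clap_embeds = torch.randn(1, 1, 512)
t5_embeds = torch.randn(1, 12, 1024)
clap_mask = torch.ones(1, 1, dtype=torch.long)
t5_mask = torch.ones(1, 12, dtype=torch.long)

# project both sequences into the shared space; learned start/end embeddings are inserted
output = projection_model(
    hidden_states=clap_embeds,
    hidden_states_1=t5_embeds,
    attention_mask=clap_mask,
    attention_mask_1=t5_mask,
)
```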

AudioLDM2UNet2DConditionModel

class diffusers.AudioLDM2UNet2DConditionModel

( sample_size: typing.Optional[int] = None in_channels: int = 4 out_channels: int = 4 flip_sin_to_cos: bool = True freq_shift: int = 0 down_block_types: typing.Tuple[str] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D') mid_block_type: typing.Optional[str] = 'UNetMidBlock2DCrossAttn' up_block_types: typing.Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D') only_cross_attention: typing.Union[bool, typing.Tuple[bool]] = False block_out_channels: typing.Tuple[int] = (320, 640, 1280, 1280) layers_per_block: typing.Union[int, typing.Tuple[int]] = 2 downsample_padding: int = 1 mid_block_scale_factor: float = 1 act_fn: str = 'silu' norm_num_groups: typing.Optional[int] = 32 norm_eps: float = 1e-05 cross_attention_dim: typing.Union[int, typing.Tuple[int]] = 1280 transformer_layers_per_block: typing.Union[int, typing.Tuple[int]] = 1 attention_head_dim: typing.Union[int, typing.Tuple[int]] = 8 num_attention_heads: typing.Union[int, typing.Tuple[int], NoneType] = None use_linear_projection: bool = False class_embed_type: typing.Optional[str] = None num_class_embeds: typing.Optional[int] = None upcast_attention: bool = False resnet_time_scale_shift: str = 'default' time_embedding_type: str = 'positional' time_embedding_dim: typing.Optional[int] = None time_embedding_act_fn: typing.Optional[str] = None timestep_post_act: typing.Optional[str] = None time_cond_proj_dim: typing.Optional[int] = None conv_in_kernel: int = 3 conv_out_kernel: int = 3 projection_class_embeddings_input_dim: typing.Optional[int] = None class_embeddings_concat: bool = False )

A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. Compared to the vanilla UNet2DConditionModel, this variant optionally includes an additional self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up to two cross-attention embeddings, encoder_hidden_states and encoder_hidden_states_1.

This model inherits from ModelMixin. Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving).
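To examine the dual cross-attention setup directly, the UNet can be loaded on its own from a pretrained pipeline repository. This sketch assumes the standard "unet" subfolder layout used by diffusers pipelines:

```python
from diffusers import AudioLDM2UNet2DConditionModel

# load only the UNet component of the pretrained pipeline
unet = AudioLDM2UNet2DConditionModel.from_pretrained("cvssp/audioldm2", subfolder="unet")

# the cross-attention dimensions used for the two conditioning streams
print(unet.config.cross_attention_dim)
```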

forward

( sample: Tensor timestep: typing.Union[torch.Tensor, float, int] encoder_hidden_states: Tensor class_labels: typing.Optional[torch.Tensor] = None timestep_cond: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None encoder_attention_mask: typing.Optional[torch.Tensor] = None return_dict: bool = True encoder_hidden_states_1: typing.Optional[torch.Tensor] = None encoder_attention_mask_1: typing.Optional[torch.Tensor] = None ) → UNet2DConditionOutput or tuple

Returns

If return_dict is True, an UNet2DConditionOutput is returned, otherwise a tuple is returned where the first element is the sample tensor.

The AudioLDM2UNet2DConditionModel forward method.

AudioPipelineOutput

class diffusers.AudioPipelineOutput

( audios: ndarray )

Output class for audio pipelines.
