Stable Audio

Stable Audio was proposed in Stable Audio Open by Zach Evans et al. It takes a text prompt as input and predicts the corresponding sound or music sample.

Stable Audio Open generates variable-length (up to 47s) stereo audio at 44.1kHz from text prompts. It comprises three components: an autoencoder that compresses waveforms into a manageable sequence length, a T5-based text embedding for text conditioning, and a transformer-based diffusion (DiT) model that operates in the latent space of the autoencoder.
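As a quick orientation, the hedged snippet below (a minimal sketch, assuming the stabilityai/stable-audio-open-1.0 checkpoint is available) loads the pipeline and inspects these three components; the attribute names correspond to the pipeline constructor documented further below.

import torch
from diffusers import StableAudioPipeline

# load the full text-to-audio pipeline
pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=torch.float16)

# the three components described above
print(type(pipe.vae))           # Oobleck autoencoder that compresses waveforms into latents
print(type(pipe.text_encoder))  # T5 encoder used for text conditioning
print(type(pipe.transformer))   # diffusion transformer (DiT) operating in the latent space

# the autoencoder also exposes the output sampling rate used when writing audio to disk
print(pipe.vae.sampling_rate)   # 44100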

Stable Audio is trained on a corpus of around 48k audio recordings, of which around 47k are from Freesound and the rest are from the Free Music Archive (FMA). All audio files are licensed under CC0, CC BY, or CC Sampling+. This data is used to train the autoencoder and the DiT.

The abstract of the paper is the following:

Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.

This pipeline was contributed by Yoach Lacombe. The original codebase can be found at Stability-AI/stable-audio-tools.

Tips

When constructing a prompt, keep in mind:

- Descriptive prompt inputs work best: use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context-specific where possible (for example, "melodic techno with a fast beat and synths" works better than "techno").
- Using a negative prompt can significantly improve the quality of the generated audio (for example, "low quality, average quality").

During inference:

- The quality of the generated audio is controlled by the num_inference_steps argument; more steps give higher-quality audio at the expense of slower inference.
- The length of the generated audio is controlled by the audio_start_in_s and audio_end_in_s arguments (up to 47 seconds).
- Several candidate waveforms can be generated for a single prompt by setting num_waveforms_per_prompt to a value greater than 1; a minimal sketch combining these options follows below.
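The snippet below is a minimal, hedged sketch of these options (assuming a CUDA device; the prompt text and argument values are illustrative only):

import torch
import soundfile as sf
from diffusers import StableAudioPipeline

pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=torch.float16).to("cuda")

audio = pipe(
    "melodic techno with a fast beat and synths",    # descriptive, context-specific prompt
    negative_prompt="low quality, average quality",  # steers generation away from artifacts
    num_inference_steps=200,                         # more steps -> higher quality, slower inference
    audio_start_in_s=0.0,
    audio_end_in_s=20.0,                             # generate 20 seconds of audio (model supports up to 47s)
    generator=torch.Generator("cuda").manual_seed(0),
).audios

sf.write("techno.wav", audio[0].T.float().cpu().numpy(), pipe.vae.sampling_rate)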

Quantization

Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have a varying impact on the quality of the generated audio depending on the model.

Refer to the Quantization overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized StableAudioPipeline for inference with bitsandbytes.

import torch
import soundfile as sf
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, StableAudioDiTModel, StableAudioPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
    "stabilityai/stable-audio-open-1.0",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = StableAudioDiTModel.from_pretrained(
    "stabilityai/stable-audio-open-1.0",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = StableAudioPipeline.from_pretrained(
    "stabilityai/stable-audio-open-1.0",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "The sound of a hammer hitting a wooden surface." negative_prompt = "Low quality." audio = pipeline( prompt, negative_prompt=negative_prompt, num_inference_steps=200, audio_end_in_s=10.0, num_waveforms_per_prompt=3, generator=generator, ).audios

# audio waveforms are (channels, samples); transpose to (samples, channels) for soundfile
output = audio[0].T.float().cpu().numpy()
sf.write("hammer.wav", output, pipeline.vae.sampling_rate)
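Because num_waveforms_per_prompt=3 in the call above, audio holds three candidate waveforms. A small follow-up sketch (output file names are arbitrary) writes each of them to disk:

for i, waveform in enumerate(audio):
    # each waveform is (channels, samples); soundfile expects (samples, channels), hence the transpose
    sf.write(f"hammer_{i}.wav", waveform.T.float().cpu().numpy(), pipeline.vae.sampling_rate)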

StableAudioPipeline

class diffusers.StableAudioPipeline


( vae: AutoencoderOobleck text_encoder: T5EncoderModel projection_model: StableAudioProjectionModel tokenizer: typing.Union[transformers.models.t5.tokenization_t5.T5Tokenizer, transformers.models.t5.tokenization_t5_fast.T5TokenizerFast] transformer: StableAudioDiTModel scheduler: EDMDPMSolverMultistepScheduler )

Parameters

- vae (AutoencoderOobleck) — Variational autoencoder that encodes waveforms into latents and decodes latents back into waveforms.
- text_encoder (T5EncoderModel) — Frozen T5 text encoder used to condition generation on the text prompt.
- projection_model (StableAudioProjectionModel) — Model that projects the text-encoder hidden states and the audio start/end times into the conditioning space expected by the transformer.
- tokenizer (T5Tokenizer or T5TokenizerFast) — Tokenizer for the text encoder.
- transformer (StableAudioDiTModel) — Diffusion transformer (DiT) that denoises the encoded audio latents.
- scheduler (EDMDPMSolverMultistepScheduler) — Scheduler used together with the transformer to denoise the encoded audio latents.

Pipeline for text-to-audio generation using StableAudio.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).
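For reference, the pipeline can also be assembled from individually loaded components, mirroring the constructor parameters above. The sketch below is an assumption-laden illustration, not an officially documented workflow: it presumes the usual Diffusers repository layout for stabilityai/stable-audio-open-1.0 (subfolders vae, text_encoder, projection_model, tokenizer, transformer, scheduler) and that StableAudioProjectionModel is importable from the top-level diffusers namespace.

import torch
from diffusers import (
    AutoencoderOobleck,
    EDMDPMSolverMultistepScheduler,
    StableAudioDiTModel,
    StableAudioPipeline,
    StableAudioProjectionModel,  # import path assumed
)
from transformers import AutoTokenizer, T5EncoderModel

repo_id = "stabilityai/stable-audio-open-1.0"

# subfolder names assumed to follow the standard Diffusers layout
vae = AutoencoderOobleck.from_pretrained(repo_id, subfolder="vae", torch_dtype=torch.float16)
text_encoder = T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder", torch_dtype=torch.float16)
projection_model = StableAudioProjectionModel.from_pretrained(repo_id, subfolder="projection_model", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
transformer = StableAudioDiTModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=torch.float16)
scheduler = EDMDPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")

pipe = StableAudioPipeline(
    vae=vae,
    text_encoder=text_encoder,
    projection_model=projection_model,
    tokenizer=tokenizer,
    transformer=transformer,
    scheduler=scheduler,
)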

__call__


( prompt: typing.Union[str, typing.List[str]] = None
audio_end_in_s: typing.Optional[float] = None
audio_start_in_s: typing.Optional[float] = 0.0
num_inference_steps: int = 100
guidance_scale: float = 7.0
negative_prompt: typing.Union[str, typing.List[str], NoneType] = None
num_waveforms_per_prompt: typing.Optional[int] = 1
eta: float = 0.0
generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None
latents: typing.Optional[torch.Tensor] = None
initial_audio_waveforms: typing.Optional[torch.Tensor] = None
initial_audio_sampling_rate: typing.Optional[torch.Tensor] = None
prompt_embeds: typing.Optional[torch.Tensor] = None
negative_prompt_embeds: typing.Optional[torch.Tensor] = None
attention_mask: typing.Optional[torch.LongTensor] = None
negative_attention_mask: typing.Optional[torch.LongTensor] = None
return_dict: bool = True
callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None
callback_steps: typing.Optional[int] = 1
output_type: typing.Optional[str] = 'pt' ) → StableDiffusionPipelineOutput or tuple

Parameters

- prompt (str or List[str], optional) — The prompt or prompts to guide audio generation.
- audio_end_in_s (float, optional) — Audio end index in seconds.
- audio_start_in_s (float, optional, defaults to 0.0) — Audio start index in seconds.
- num_inference_steps (int, optional, defaults to 100) — The number of denoising steps. More steps usually give higher-quality audio at the expense of slower inference.
- guidance_scale (float, optional, defaults to 7.0) — Classifier-free guidance scale; higher values encourage generations more closely linked to the prompt.
- negative_prompt (str or List[str], optional) — The prompt or prompts to guide what to exclude from the generated audio.
- num_waveforms_per_prompt (int, optional, defaults to 1) — The number of waveforms to generate per prompt.
- eta (float, optional, defaults to 0.0) — Corresponds to the η parameter from the DDIM paper; only applies to schedulers that accept it.
- generator (torch.Generator or List[torch.Generator], optional) — A generator to make generation deterministic.
- latents (torch.Tensor, optional) — Pre-generated noisy latents to use as inputs for generation.
- initial_audio_waveforms (torch.Tensor, optional) — Optional audio waveforms used as the starting audio for generation.
- initial_audio_sampling_rate (optional) — Sampling rate of initial_audio_waveforms.
- prompt_embeds (torch.Tensor, optional) — Pre-computed text embeddings; can be used instead of prompt.
- negative_prompt_embeds (torch.Tensor, optional) — Pre-computed negative text embeddings; can be used instead of negative_prompt.
- attention_mask (torch.LongTensor, optional) — Attention mask for prompt_embeds.
- negative_attention_mask (torch.LongTensor, optional) — Attention mask for negative_prompt_embeds.
- return_dict (bool, optional, defaults to True) — Whether to return a pipeline output object instead of a plain tuple.
- callback (Callable, optional) — A function called every callback_steps steps during inference with arguments (step, timestep, latents).
- callback_steps (int, optional, defaults to 1) — The frequency at which callback is called.
- output_type (str, optional, defaults to "pt") — The output format of the generated audio; "pt" returns a PyTorch tensor.

If return_dict is True, StableDiffusionPipelineOutput is returned, otherwise a tuple is returned where the first element is a list with the generated audio.

The call function to the pipeline for generation.

Examples:

import torch
import soundfile as sf
from diffusers import StableAudioPipeline

repo_id = "stabilityai/stable-audio-open-1.0" pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) pipe = pipe.to("cuda")

prompt = "The sound of a hammer hitting a wooden surface." negative_prompt = "Low quality."

generator = torch.Generator("cuda").manual_seed(0)

audio = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=200,
    audio_end_in_s=10.0,
    num_waveforms_per_prompt=3,
    generator=generator,
).audios

# audio waveforms are (channels, samples); transpose to (samples, channels) for soundfile
output = audio[0].T.float().cpu().numpy()
sf.write("hammer.wav", output, pipe.vae.sampling_rate)

disable_vae_slicing

Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
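A short usage sketch (reusing the pipe, prompt, negative_prompt, and generator variables from the example above):

pipe.enable_vae_slicing()   # decode the audio latents slice by slice to lower peak memory
audio = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=200,
    audio_end_in_s=10.0,
    generator=generator,
).audios
pipe.disable_vae_slicing()  # revert to decoding in a single step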
