aMUSEd

aMUSEd was introduced in aMUSEd: An Open MUSE Reproduction by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen.

aMUSEd is a lightweight text-to-image model based on the MUSE architecture. It is particularly useful in applications that need a lightweight, fast model, such as generating many images at once.

aMUSEd is a VQ-VAE token-based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with MUSE, it uses the smaller CLIP-L/14 text encoder instead of T5-XXL. Because of its small parameter count and few-forward-pass generation process, aMUSEd can generate many images quickly, and this benefit is most noticeable at larger batch sizes.
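To illustrate why a masked image model needs only a handful of forward passes, here is a minimal sketch of the bookkeeping behind parallel decoding. It assumes the cosine masking schedule popularized by MaskGIT, mask_ratio = cos(π/2 · t/T); the exact formula inside AmusedScheduler may differ, and the 16x16 token grid is a hypothetical size chosen for the example. Each step corresponds to one transformer forward pass that unmasks many tokens at once.

```python
import math


def masked_tokens_per_step(num_tokens: int, num_steps: int) -> list:
    """Return how many image tokens are still masked after each decoding step.

    Assumption: a MaskGIT-style cosine schedule, mask_ratio = cos(pi/2 * t / T).
    This is an illustration, not AmusedScheduler's actual implementation.
    """
    remaining = []
    for step in range(1, num_steps + 1):
        ratio = math.cos(math.pi / 2 * step / num_steps)
        # One step = one forward pass: the most confident predictions are kept,
        # the rest stay masked and are predicted again on the next pass.
        remaining.append(math.floor(num_tokens * ratio))
    return remaining


# Hypothetical 16x16 token grid decoded in 12 steps (the pipelines' default
# num_inference_steps), i.e. 12 forward passes for all 256 tokens.
schedule = masked_tokens_per_step(16 * 16, 12)
print(schedule)
```

Few tokens are committed early, when the image is mostly unknown, and the pace speeds up as more context becomes available, which is what lets a small, fixed number of passes cover the whole grid.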

The abstract from the paper is:

We present aMUSEd, an open-source, lightweight masked image model (MIM) for text-to-image generation based on MUSE. With 10 percent of MUSE’s parameters, aMUSEd is focused on fast image generation. We believe MIM is under-explored compared to latent diffusion, the prevailing approach for text-to-image generation. Compared to latent diffusion, MIM requires fewer inference steps and is more interpretable. Additionally, MIM can be fine-tuned to learn additional styles with only a single image. We hope to encourage further exploration of MIM by demonstrating its effectiveness on large-scale text-to-image generation and releasing reproducible training code. We also release checkpoints for two models which directly produce images at 256x256 and 512x512 resolutions.

Model Params
amused-256 603M
amused-512 608M

AmusedPipeline

class diffusers.AmusedPipeline


( vqvae: VQModel tokenizer: CLIPTokenizer text_encoder: CLIPTextModelWithProjection transformer: UVit2DModel scheduler: AmusedScheduler )

__call__


( prompt: typing.Union[str, typing.List[str], NoneType] = None height: typing.Optional[int] = None width: typing.Optional[int] = None num_inference_steps: int = 12 guidance_scale: float = 10.0 negative_prompt: typing.Union[str, typing.List[str], NoneType] = None num_images_per_prompt: typing.Optional[int] = 1 generator: typing.Optional[torch._C.Generator] = None latents: typing.Optional[torch.IntTensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None encoder_hidden_states: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_encoder_hidden_states: typing.Optional[torch.Tensor] = None output_type = 'pil' return_dict: bool = True callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: int = 1 cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None micro_conditioning_aesthetic_score: int = 6 micro_conditioning_crop_coord: typing.Tuple[int, int] = (0, 0) temperature: typing.Union[int, typing.Tuple[int, int], typing.List[int]] = (2, 0) ) → ImagePipelineOutput or tuple

Returns

ImagePipelineOutput or tuple: if return_dict is True, an ImagePipelineOutput is returned; otherwise a tuple is returned whose first element is a list of the generated images.

The call function to the pipeline for generation.
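The temperature argument accepts a (start, end) pair and defaults to (2, 0). A minimal sketch of how such a pair might be expanded into one value per step, assuming linear interpolation across num_inference_steps; this is an illustration, not the scheduler's code, and single-number temperatures are deliberately left out. In MaskGIT-style samplers the temperature typically scales the randomness used when choosing which tokens to unmask, so annealing it towards 0 makes later steps more deterministic.

```python
def temperature_schedule(start: float, end: float, num_inference_steps: int) -> list:
    """Linearly interpolate a (start, end) temperature pair over the sampling steps.

    Assumption: a linspace-style schedule; illustrative only.
    """
    if num_inference_steps < 1:
        raise ValueError("num_inference_steps must be >= 1")
    if num_inference_steps == 1:
        return [float(start)]
    step = (end - start) / (num_inference_steps - 1)
    return [start + i * step for i in range(num_inference_steps)]


# Default temperature=(2, 0) over the default num_inference_steps=12.
temps = temperature_schedule(2, 0, 12)
print([round(t, 3) for t in temps])
```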

Examples:

import torch
from diffusers import AmusedPipeline

pipe = AmusedPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
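This example relies on the default guidance_scale=10.0. Assuming aMUSEd applies classifier-free guidance in the usual way, pushing prompt-conditioned predictions away from unconditioned (or negative_prompt) predictions, the per-token combination of logits looks like the sketch below. The function name and toy logits are hypothetical; a real pipeline does this on batched torch tensors over the full VQ codebook.

```python
def apply_cfg(cond_logits, uncond_logits, guidance_scale: float):
    """Classifier-free guidance on one token's logits over the VQ codebook.

    guided = uncond + guidance_scale * (cond - uncond)
    With guidance_scale == 1.0 this reduces to the conditional logits.
    """
    return [u + guidance_scale * (c - u) for c, u in zip(cond_logits, uncond_logits)]


# Toy logits over a hypothetical 3-entry codebook.
cond = [2.0, 0.5, -1.0]
uncond = [1.0, 1.0, 0.0]
guided = apply_cfg(cond, uncond, 10.0)  # the pipeline's default guidance_scale
print(guided)  # [11.0, -4.0, -10.0]
```

Larger values amplify whatever distinguishes the prompt-conditioned prediction, at the cost of diversity.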

enable_xformers_memory_efficient_attention


( attention_op: typing.Optional[typing.Callable] = None )

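Parameters

attention_op (Callable, optional) — Overrides the default None operator passed as the op argument to the memory_efficient_attention() function of xFormers.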

Enable memory-efficient attention from xFormers. When this option is enabled, you should observe lower GPU memory usage and a potential speed-up during inference. A speed-up during training is not guaranteed.

⚠️ When memory-efficient attention and sliced attention are both enabled, memory-efficient attention takes precedence.

Examples:

import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# Workaround: the VAE's attention shapes are not accepted by Flash Attention, so use the default operator there
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)

disable_xformers_memory_efficient_attention


( )

Disable memory efficient attention from xFormers.

AmusedImg2ImgPipeline

class diffusers.AmusedImg2ImgPipeline


( vqvae: VQModel tokenizer: CLIPTokenizer text_encoder: CLIPTextModelWithProjection transformer: UVit2DModel scheduler: AmusedScheduler )

__call__


( prompt: typing.Union[str, typing.List[str], NoneType] = None image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None strength: float = 0.5 num_inference_steps: int = 12 guidance_scale: float = 10.0 negative_prompt: typing.Union[str, typing.List[str], NoneType] = None num_images_per_prompt: typing.Optional[int] = 1 generator: typing.Optional[torch._C.Generator] = None prompt_embeds: typing.Optional[torch.Tensor] = None encoder_hidden_states: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_encoder_hidden_states: typing.Optional[torch.Tensor] = None output_type = 'pil' return_dict: bool = True callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: int = 1 cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None micro_conditioning_aesthetic_score: int = 6 micro_conditioning_crop_coord: typing.Tuple[int, int] = (0, 0) temperature: typing.Union[int, typing.Tuple[int, int], typing.List[int]] = (2, 0) ) → ImagePipelineOutput or tuple

Returns

ImagePipelineOutput or tuple: if return_dict is True, an ImagePipelineOutput is returned; otherwise a tuple is returned whose first element is a list of the generated images.

The call function to the pipeline for generation.

Examples:

import torch
from diffusers import AmusedImg2ImgPipeline
from diffusers.utils import load_image

pipe = AmusedImg2ImgPipeline.from_pretrained(
    "amused/amused-512", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "winter mountains"
input_image = (
    load_image(
        "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg"
    )
    .resize((512, 512))
    .convert("RGB")
)
image = pipe(prompt, input_image).images[0]
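This example leaves strength at its default of 0.5. In diffusers' image-to-image pipelines, strength usually controls how much of the sampling schedule is rerun on top of the encoded input image. A minimal sketch of that arithmetic, assuming aMUSEd follows the same convention of running only the last int(num_inference_steps * strength) steps; the helper name is hypothetical.

```python
def img2img_steps(num_inference_steps: int, strength: float) -> tuple:
    """Return (start_index, steps_run) for an image-to-image call.

    Assumption: only the last int(num_inference_steps * strength) steps of the
    schedule are executed, so strength=1.0 reruns the full schedule and small
    strengths stay close to the input image.
    """
    if not 0.0 <= strength <= 1.0:
        raise ValueError("strength must be in [0, 1]")
    steps_run = int(num_inference_steps * strength)
    start_index = num_inference_steps - steps_run
    return start_index, steps_run


print(img2img_steps(12, 0.5))  # AmusedImg2ImgPipeline defaults -> (6, 6)
```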

enable_xformers_memory_efficient_attention


( attention_op: typing.Optional[typing.Callable] = None )

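Parameters

attention_op (Callable, optional) — Overrides the default None operator passed as the op argument to the memory_efficient_attention() function of xFormers.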

Enable memory-efficient attention from xFormers. When this option is enabled, you should observe lower GPU memory usage and a potential speed-up during inference. A speed-up during training is not guaranteed.

⚠️ When memory-efficient attention and sliced attention are both enabled, memory-efficient attention takes precedence.

Examples:

import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# Workaround: the VAE's attention shapes are not accepted by Flash Attention, so use the default operator there
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)

disable_xformers_memory_efficient_attention


( )

Disable memory efficient attention from xFormers.

AmusedInpaintPipeline

class diffusers.AmusedInpaintPipeline


( vqvae: VQModel tokenizer: CLIPTokenizer text_encoder: CLIPTextModelWithProjection transformer: UVit2DModel scheduler: AmusedScheduler )

__call__


( prompt: typing.Union[str, typing.List[str], NoneType] = None image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None mask_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None strength: float = 1.0 num_inference_steps: int = 12 guidance_scale: float = 10.0 negative_prompt: typing.Union[str, typing.List[str], NoneType] = None num_images_per_prompt: typing.Optional[int] = 1 generator: typing.Optional[torch._C.Generator] = None prompt_embeds: typing.Optional[torch.Tensor] = None encoder_hidden_states: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None negative_encoder_hidden_states: typing.Optional[torch.Tensor] = None output_type = 'pil' return_dict: bool = True callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: int = 1 cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None micro_conditioning_aesthetic_score: int = 6 micro_conditioning_crop_coord: typing.Tuple[int, int] = (0, 0) temperature: typing.Union[int, typing.Tuple[int, int], typing.List[int]] = (2, 0) ) → ImagePipelineOutput or tuple

Returns

ImagePipelineOutput or tuple: if return_dict is True, an ImagePipelineOutput is returned; otherwise a tuple is returned whose first element is a list of the generated images.

The call function to the pipeline for generation.

Examples:

import torch
from diffusers import AmusedInpaintPipeline
from diffusers.utils import load_image

pipe = AmusedInpaintPipeline.from_pretrained(
    "amused/amused-512", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "fall mountains"
input_image = (
    load_image(
        "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg"
    )
    .resize((512, 512))
    .convert("RGB")
)
mask = (
    load_image(
        "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
    )
    .resize((512, 512))
    .convert("L")
)
pipe(prompt, input_image, mask).images[0].save("out.png")
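The inpainting example passes a grayscale ("L") mask alongside the input image. Conceptually, a masked image model repaints by re-masking the image tokens that fall under the white part of the mask and letting the transformer predict them again. The sketch below maps a pixel mask onto a token grid with nearest-neighbour sampling; the downsampling factor of 16, the threshold of 128, and the helper name are assumptions for illustration, and the actual pipeline resizes the mask with torch operations.

```python
def mask_to_token_grid(mask, factor: int = 16, threshold: int = 128):
    """Map a grayscale pixel mask (rows of 0-255 values) onto a token grid.

    Assumptions: white pixels (>= threshold) mark the region to repaint, and
    each token covers a factor x factor pixel patch; a token is re-masked if
    the pixel at its patch's top-left corner is white (nearest sampling).
    """
    height, width = len(mask), len(mask[0])
    if height % factor or width % factor:
        raise ValueError("mask size must be divisible by the downsampling factor")
    return [
        [mask[row * factor][col * factor] >= threshold for col in range(width // factor)]
        for row in range(height // factor)
    ]


# Toy 32x32 mask whose right half is white (to be repainted).
toy_mask = [[255 if x >= 16 else 0 for x in range(32)] for _ in range(32)]
token_mask = mask_to_token_grid(toy_mask)
print(token_mask)  # [[False, True], [False, True]]
```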

enable_xformers_memory_efficient_attention


( attention_op: typing.Optional[typing.Callable] = None )

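Parameters

attention_op (Callable, optional) — Overrides the default None operator passed as the op argument to the memory_efficient_attention() function of xFormers.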

Enable memory-efficient attention from xFormers. When this option is enabled, you should observe lower GPU memory usage and a potential speed-up during inference. A speed-up during training is not guaranteed.

⚠️ When memory-efficient attention and sliced attention are both enabled, memory-efficient attention takes precedence.

Examples:

import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# Workaround: the VAE's attention shapes are not accepted by Flash Attention, so use the default operator there
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)

disable_xformers_memory_efficient_attention


( )

Disable memory efficient attention from xFormers.
