Chroma

Chroma supports LoRA loading and can run on Apple Silicon via the MPS backend.

Chroma is a text-to-image generation model based on Flux.

Original model checkpoints for Chroma can be found at https://huggingface.co/lodestones/Chroma1-HD.

Chroma can use all the same optimizations as Flux.
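For example, the sketch below (not part of the original guide) applies two common Flux-style optimizations, compiling the transformer with torch.compile and slicing VAE decoding. Both steps are optional and depend on your hardware; treat this as an assumption-laden starting point rather than a prescribed recipe.

```python
import torch
from diffusers import ChromaPipeline

# Assumes a GPU with enough memory to hold the full pipeline in bfloat16.
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16).to("cuda")

# Optional: compile the transformer for faster repeated inference (the first call is slower).
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)

# Optional: decode latents slice-by-slice to reduce peak memory during VAE decoding.
pipe.enable_vae_slicing()

image = pipe("a photo of a red fox in the snow", num_inference_steps=40, guidance_scale=3.0).images[0]
image.save("chroma-compiled.png")
```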

Inference

```python
import torch
from diffusers import ChromaPipeline

pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = [
    "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
]
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]

image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=torch.Generator("cpu").manual_seed(433),
    num_inference_steps=40,
    guidance_scale=3.0,
    num_images_per_prompt=1,
).images[0]
image.save("chroma.png")
```
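As noted at the top of this page, Chroma also supports LoRA through the standard Diffusers LoRA API. The sketch below is illustrative only: the repository id, weight name, and adapter scale are hypothetical placeholders, not a published LoRA.

```python
import torch
from diffusers import ChromaPipeline

pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# Hypothetical LoRA repository and filename; replace with a real Chroma LoRA checkpoint.
pipe.load_lora_weights(
    "your-username/your-chroma-lora",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="style",
)
pipe.set_adapters(["style"], adapter_weights=[0.9])  # optional: scale the adapter's influence

image = pipe("a watercolor painting of a lighthouse at dusk", num_inference_steps=40, guidance_scale=3.0).images[0]
image.save("chroma-lora.png")
```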

Loading from a single file

To use updated model checkpoints that are not in the Diffusers format, use the ChromaTransformer2DModel class to load the model from a single file in its original format. This is also useful for loading community-published finetunes or quantized versions of the model (a quantized-loading sketch follows the example below).

The following example demonstrates how to run Chroma from a single file.

```python
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline

model_id = "lodestones/Chroma1-HD"
dtype = torch.bfloat16

transformer = ChromaTransformer2DModel.from_single_file(
    "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype
)

pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()

prompt = [
    "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
]
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]

image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=torch.Generator("cpu").manual_seed(433),
    num_inference_steps=40,
    guidance_scale=3.0,
).images[0]

image.save("chroma-single-file.png")
```

ChromaPipeline

class diffusers.ChromaPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKL text_encoder: T5EncoderModel tokenizer: T5TokenizerFast transformer: ChromaTransformer2DModel image_encoder: CLIPVisionModelWithProjection = None feature_extractor: CLIPImageProcessor = None )


The Chroma pipeline for text-to-image generation.

Reference: https://huggingface.co/lodestones/Chroma1-HD/

__call__


( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None height: typing.Optional[int] = None width: typing.Optional[int] = None num_inference_steps: int = 35 sigmas: typing.Optional[typing.List[float]] = None guidance_scale: float = 5.0 num_images_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None ip_adapter_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], NoneType] = None ip_adapter_image_embeds: typing.Optional[typing.List[torch.Tensor]] = None negative_ip_adapter_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], NoneType] = None negative_ip_adapter_image_embeds: typing.Optional[typing.List[torch.Tensor]] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True joint_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~pipelines.chroma.ChromaPipelineOutput or tuple


Returns

~pipelines.chroma.ChromaPipelineOutput or tuple

~pipelines.chroma.ChromaPipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```python
import torch
from diffusers import ChromaPipeline, ChromaTransformer2DModel

model_id = "lodestones/Chroma1-HD"
ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
pipe = ChromaPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = [
    "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
]
negative_prompt = [
    "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
]
image = pipe(prompt, negative_prompt=negative_prompt).images[0]
image.save("chroma.png")
```

disable_vae_slicing

Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If enable_vae_tiling was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.
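A minimal usage sketch of the VAE helpers described above: slicing decodes a batch one image at a time, tiling splits large latents into tiles, and the corresponding disable_* calls restore single-step decoding. Enabling both mainly pays off for large batches or high resolutions.

```python
import torch
from diffusers import ChromaPipeline

pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# Cap VAE memory: decode one image at a time and tile large latents.
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

images = pipe(
    ["a foggy pine forest at sunrise"] * 4,  # batched prompts benefit most from slicing
    num_inference_steps=40,
    guidance_scale=3.0,
).images

# Revert to single-step decoding once memory pressure is no longer a concern.
pipe.disable_vae_slicing()
pipe.disable_vae_tiling()
```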

encode_prompt


( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str]] = None device: typing.Optional[torch.device] = None num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None do_classifier_free_guidance: bool = True max_sequence_length: int = 512 lora_scale: typing.Optional[float] = None )


ChromaImg2ImgPipeline

class diffusers.ChromaImg2ImgPipeline


( scheduler: FlowMatchEulerDiscreteScheduler vae: AutoencoderKL text_encoder: T5EncoderModel tokenizer: T5TokenizerFast transformer: ChromaTransformer2DModel image_encoder: CLIPVisionModelWithProjection = None feature_extractor: CLIPImageProcessor = None )


The Chroma pipeline for image-to-image generation.

Reference: https://huggingface.co/lodestones/Chroma1-HD/

__call__


( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None height: typing.Optional[int] = None width: typing.Optional[int] = None num_inference_steps: int = 35 sigmas: typing.Optional[typing.List[float]] = None guidance_scale: float = 5.0 strength: float = 0.9 num_images_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None ip_adapter_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], NoneType] = None ip_adapter_image_embeds: typing.Optional[typing.List[torch.Tensor]] = None negative_ip_adapter_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], NoneType] = None negative_ip_adapter_image_embeds: typing.Optional[typing.List[torch.Tensor]] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True joint_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~pipelines.chroma.ChromaPipelineOutput or tuple


Returns

~pipelines.chroma.ChromaPipelineOutput or tuple

~pipelines.chroma.ChromaPipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

```python
import torch
from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
from diffusers.utils import load_image

model_id = "lodestones/Chroma1-HD"
ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
pipe = ChromaImg2ImgPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)
prompt = "a scenic fantasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
image.save("chroma-img2img.png")
```

disable_vae_slicing

Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.

disable_vae_tiling

Disable tiled VAE decoding. If enable_vae_tiling was previously enabled, this method will go back to computing decoding in one step.

enable_vae_slicing

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

enable_vae_tiling

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.

encode_prompt


( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str]] = None device: typing.Optional[torch.device] = None num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None do_classifier_free_guidance: bool = True max_sequence_length: int = 512 lora_scale: typing.Optional[float] = None )

