UNet2DConditionModel (original) (raw)

The UNet model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in πŸ€— Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in πŸ€— Diffusers, depending on it’s number of dimensions and whether it is a conditional model or not. This is a 2D UNet conditional model.

There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.

class diffusers.UNet2DConditionModel

< source >

( sample_size: typing.Union[int, typing.Tuple[int, int], NoneType] = None in_channels: int = 4 out_channels: int = 4 center_input_sample: bool = False flip_sin_to_cos: bool = True freq_shift: int = 0 down_block_types: typing.Tuple[str] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D') mid_block_type: typing.Optional[str] = 'UNetMidBlock2DCrossAttn' up_block_types: typing.Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D') only_cross_attention: typing.Union[bool, typing.Tuple[bool]] = False block_out_channels: typing.Tuple[int] = (320, 640, 1280, 1280) layers_per_block: typing.Union[int, typing.Tuple[int]] = 2 downsample_padding: int = 1 mid_block_scale_factor: float = 1 dropout: float = 0.0 act_fn: str = 'silu' norm_num_groups: typing.Optional[int] = 32 norm_eps: float = 1e-05 cross_attention_dim: typing.Union[int, typing.Tuple[int]] = 1280 transformer_layers_per_block: typing.Union[int, typing.Tuple[int], typing.Tuple[typing.Tuple]] = 1 reverse_transformer_layers_per_block: typing.Optional[typing.Tuple[typing.Tuple[int]]] = None encoder_hid_dim: typing.Optional[int] = None encoder_hid_dim_type: typing.Optional[str] = None attention_head_dim: typing.Union[int, typing.Tuple[int]] = 8 num_attention_heads: typing.Union[int, typing.Tuple[int], NoneType] = None dual_cross_attention: bool = False use_linear_projection: bool = False class_embed_type: typing.Optional[str] = None addition_embed_type: typing.Optional[str] = None addition_time_embed_dim: typing.Optional[int] = None num_class_embeds: typing.Optional[int] = None upcast_attention: bool = False resnet_time_scale_shift: str = 'default' resnet_skip_time_act: bool = False resnet_out_scale_factor: float = 1.0 time_embedding_type: str = 'positional' time_embedding_dim: typing.Optional[int] = None time_embedding_act_fn: typing.Optional[str] = None timestep_post_act: typing.Optional[str] = None time_cond_proj_dim: typing.Optional[int] = None conv_in_kernel: int = 3 conv_out_kernel: int = 3 projection_class_embeddings_input_dim: typing.Optional[int] = None attention_type: str = 'default' class_embeddings_concat: bool = False mid_block_only_cross_attention: typing.Optional[bool] = None cross_attention_norm: typing.Optional[str] = None addition_embed_type_num_heads: int = 64 )

Parameters

A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output.

This model inherits from ModelMixin. Check the superclass documentation for it’s generic methods implemented for all models (such as downloading or saving).

Disables the FreeU mechanism.

enable_freeu

< source >

( s1: float s2: float b1: float b2: float )

Parameters

Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

The suffixes after the scaling factors represent the stage blocks where they are being applied.

Please refer to the official repository for combinations of values that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

forward

< source >

( sample: Tensor timestep: typing.Union[torch.Tensor, float, int] encoder_hidden_states: Tensor class_labels: typing.Optional[torch.Tensor] = None timestep_cond: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None added_cond_kwargs: typing.Optional[typing.Dict[str, torch.Tensor]] = None down_block_additional_residuals: typing.Optional[typing.Tuple[torch.Tensor]] = None mid_block_additional_residual: typing.Optional[torch.Tensor] = None down_intrablock_additional_residuals: typing.Optional[typing.Tuple[torch.Tensor]] = None encoder_attention_mask: typing.Optional[torch.Tensor] = None return_dict: bool = True ) β†’ UNet2DConditionOutput or tuple

Parameters

If return_dict is True, an UNet2DConditionOutput is returned, otherwise a tuple is returned where the first element is the sample tensor.

The UNet2DConditionModel forward method.

Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

This API is πŸ§ͺ experimental.

set_attention_slice

< source >

( slice_size: typing.Union[str, int, typing.List[int]] = 'auto' )

Parameters

Enable sliced attention computation.

When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed.

set_attn_processor

< source >

( processor: typing.Union[diffusers.models.attention_processor.AttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor2_0, diffusers.models.attention_processor.JointAttnProcessor2_0, diffusers.models.attention_processor.PAGJointAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGJointAttnProcessor2_0, diffusers.models.attention_processor.FusedJointAttnProcessor2_0, diffusers.models.attention_processor.AllegroAttnProcessor2_0, diffusers.models.attention_processor.AuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FusedAuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.CogVideoXAttnProcessor2_0, diffusers.models.attention_processor.FusedCogVideoXAttnProcessor2_0, diffusers.models.attention_processor.XFormersAttnAddedKVProcessor, diffusers.models.attention_processor.XFormersAttnProcessor, diffusers.models.attention_processor.XLAFlashAttnProcessor2_0, diffusers.models.attention_processor.AttnProcessorNPU, diffusers.models.attention_processor.AttnProcessor2_0, diffusers.models.attention_processor.MochiVaeAttnProcessor2_0, diffusers.models.attention_processor.MochiAttnProcessor2_0, diffusers.models.attention_processor.StableAudioAttnProcessor2_0, diffusers.models.attention_processor.HunyuanAttnProcessor2_0, diffusers.models.attention_processor.FusedHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.LuminaAttnProcessor2_0, diffusers.models.attention_processor.FusedAttnProcessor2_0, diffusers.models.attention_processor.CustomDiffusionXFormersAttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor2_0, diffusers.models.attention_processor.SlicedAttnProcessor, diffusers.models.attention_processor.SlicedAttnAddedKVProcessor, diffusers.models.attention_processor.SanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleLinearAttention, diffusers.models.attention_processor.SanaMultiscaleAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleAttentionProjection, diffusers.models.attention_processor.IPAdapterAttnProcessor, diffusers.models.attention_processor.IPAdapterAttnProcessor2_0, diffusers.models.attention_processor.IPAdapterXFormersAttnProcessor, diffusers.models.attention_processor.SD3IPAdapterJointAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.LoRAAttnProcessor, diffusers.models.attention_processor.LoRAAttnProcessor2_0, diffusers.models.attention_processor.LoRAXFormersAttnProcessor, diffusers.models.attention_processor.LoRAAttnAddedKVProcessor, typing.Dict[str, typing.Union[diffusers.models.attention_processor.AttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor2_0, diffusers.models.attention_processor.JointAttnProcessor2_0, diffusers.models.attention_processor.PAGJointAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGJointAttnProcessor2_0, diffusers.models.attention_processor.FusedJointAttnProcessor2_0, diffusers.models.attention_processor.AllegroAttnProcessor2_0, diffusers.models.attention_processor.AuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FusedAuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.CogVideoXAttnProcessor2_0, diffusers.models.attention_processor.FusedCogVideoXAttnProcessor2_0, diffusers.models.attention_processor.XFormersAttnAddedKVProcessor, diffusers.models.attention_processor.XFormersAttnProcessor, diffusers.models.attention_processor.XLAFlashAttnProcessor2_0, diffusers.models.attention_processor.AttnProcessorNPU, diffusers.models.attention_processor.AttnProcessor2_0, diffusers.models.attention_processor.MochiVaeAttnProcessor2_0, diffusers.models.attention_processor.MochiAttnProcessor2_0, diffusers.models.attention_processor.StableAudioAttnProcessor2_0, diffusers.models.attention_processor.HunyuanAttnProcessor2_0, diffusers.models.attention_processor.FusedHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.LuminaAttnProcessor2_0, diffusers.models.attention_processor.FusedAttnProcessor2_0, diffusers.models.attention_processor.CustomDiffusionXFormersAttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor2_0, diffusers.models.attention_processor.SlicedAttnProcessor, diffusers.models.attention_processor.SlicedAttnAddedKVProcessor, diffusers.models.attention_processor.SanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleLinearAttention, diffusers.models.attention_processor.SanaMultiscaleAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleAttentionProjection, diffusers.models.attention_processor.IPAdapterAttnProcessor, diffusers.models.attention_processor.IPAdapterAttnProcessor2_0, diffusers.models.attention_processor.IPAdapterXFormersAttnProcessor, diffusers.models.attention_processor.SD3IPAdapterJointAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.LoRAAttnProcessor, diffusers.models.attention_processor.LoRAAttnProcessor2_0, diffusers.models.attention_processor.LoRAXFormersAttnProcessor, diffusers.models.attention_processor.LoRAAttnAddedKVProcessor]]] )

Parameters

Sets the attention processor to use to compute attention.

Disables custom attention processors and sets the default attention implementation.

Disables the fused QKV projection if enabled.

This API is πŸ§ͺ experimental.

class diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput

< source >

( sample: Tensor = None )

Parameters

The output of UNet2DConditionModel.

class diffusers.FlaxUNet2DConditionModel

< source >

( sample_size: int = 32 in_channels: int = 4 out_channels: int = 4 down_block_types: typing.Tuple[str, ...] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D') up_block_types: typing.Tuple[str, ...] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D') mid_block_type: typing.Optional[str] = 'UNetMidBlock2DCrossAttn' only_cross_attention: typing.Union[bool, typing.Tuple[bool]] = False block_out_channels: typing.Tuple[int, ...] = (320, 640, 1280, 1280) layers_per_block: int = 2 attention_head_dim: typing.Union[int, typing.Tuple[int, ...]] = 8 num_attention_heads: typing.Union[int, typing.Tuple[int, ...], NoneType] = None cross_attention_dim: int = 1280 dropout: float = 0.0 use_linear_projection: bool = False dtype: dtype = <class 'jax.numpy.float32'> flip_sin_to_cos: bool = True freq_shift: int = 0 use_memory_efficient_attention: bool = False split_head_dim: bool = False transformer_layers_per_block: typing.Union[int, typing.Tuple[int, ...]] = 1 addition_embed_type: typing.Optional[str] = None addition_time_embed_dim: typing.Optional[int] = None addition_embed_type_num_heads: int = 64 projection_class_embeddings_input_dim: typing.Optional[int] = None parent: typing.Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object at 0x7f0cca976590> name: typing.Optional[str] = None )

Parameters

A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output.

This model inherits from FlaxModelMixin. Check the superclass documentation for it’s generic methods implemented for all models (such as downloading or saving).

This model is also a Flax Linen flax.linen.Modulesubclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its general usage and behavior.

Inherent JAX features such as the following are supported:

class diffusers.models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput

< source >

( sample: Array )

Parameters

The output of FlaxUNet2DConditionModel.

Returns a new object replacing the specified fields with new values.