From 1051ca81a60073702320b20eb633b178c3dd1c9b Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 6 Feb 2023 22:11:57 -1000 Subject: [PATCH] Stable Diffusion Latent Upscaler (#2059) * Modify UNet2DConditionModel - allow skipping mid_block - adding a norm_group_size argument so that we can set the `num_groups` for group norm using `num_channels//norm_group_size` - allow user to set dimension for the timestep embedding (`time_embed_dim`) - the kernel_size for `conv_in` and `conv_out` is now configurable - add random fourier feature layer (`GaussianFourierProjection`) for `time_proj` - allow user to add the time and class embeddings before passing through the projection layer together - `time_embedding(t_emb + class_label))` - added 2 arguments `attn1_types` and `attn2_types` * currently we have argument `only_cross_attention`: when it's set to `True`, we will have a to the `BasicTransformerBlock` block with 2 cross-attention , otherwise we get a self-attention followed by a cross-attention; in k-upscaler, we need to have blocks that include just one cross-attention, or self-attention -> cross-attention; so I added `attn1_types` and `attn2_types` to the unet's argument list to allow user specify the attention types for the 2 positions in each block; note that I stil kept the `only_cross_attention` argument for unet for easy configuration, but it will be converted to `attn1_type` and `attn2_type` when passing down to the down blocks - the position of downsample layer and upsample layer is now configurable - in k-upscaler unet, there is only one skip connection per each up/down block (instead of each layer in stable diffusion unet), added `skip_freq = "block"` to support this use case - if user passes attention_mask to unet, it will prepare the mask and pass a flag to cross attention processer to skip the `prepare_attention_mask` step inside cross attention block add up/down blocks for k-upscaler modify CrossAttention class - make the `dropout` layer in `to_out` optional - `use_conv_proj` - use conv instead of linear for all projection layers (i.e. `to_q`, `to_k`, `to_v`, `to_out`) whenever possible. note that when it's used to do cross attention, to_k, to_v has to be linear because the `encoder_hidden_states` is not 2d - `cross_attention_norm` - add an optional layernorm on encoder_hidden_states - `attention_dropout`: add an optional dropout on attention score adapt BasicTransformerBlock - add an ada groupnorm layer to conditioning attention input with timestep embedding - allow skipping the FeedForward layer in between the attentions - replaced the only_cross_attention argument with attn1_type and attn2_type for more flexible configuration update timestep embedding: add new act_fn gelu and an optional act_2 modified ResnetBlock2D - refactored with AdaGroupNorm class (the timestep scale shift normalization) - add `mid_channel` argument - allow the first conv to have a different output dimension from the second conv - add option to use input AdaGroupNorm on the input instead of groupnorm - add options to add a dropout layer after each conv - allow user to set the bias in conv_shortcut (needed for k-upscaler) - add gelu adding conversion script for k-upscaler unet add pipeline * fix attention mask * fix a typo * fix a bug * make sure model can be used with GPU * make pipeline work with fp16 * fix an error in BasicTransfomerBlock * make style * fix typo * some more fixes * uP * up * correct more * some clean-up * clean time proj * up * uP * more changes * remove the upcast_attention=True from unet config * remove attn1_types, attn2_types etc * fix * revert incorrect changes up/down samplers * make style * remove outdated files * Apply suggestions from code review * attention refactor * refactor cross attention * Apply suggestions from code review * update * up * update * Apply suggestions from code review * finish * Update src/diffusers/models/cross_attention.py * more fixes * up * up * up * finish * more corrections of conversion state * act_2 -> act_2_fn * remove dropout_after_conv from ResnetBlock2D * make style * simplify KAttentionBlock * add fast test for latent upscaler pipeline * add slow test * slow test fp16 * make style * add doc string for pipeline_stable_diffusion_latent_upscale * add api doc page for latent upscaler pipeline * deprecate attention mask * clean up embeddings * simplify resnet * up * clean up resnet * up * correct more * up * up * improve a bit more * correct more * more clean-ups * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Patrick von Platen * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Patrick von Platen * add docstrings for new unet config * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Patrick von Platen * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Patrick von Platen * # Copied from * encode the image if not latent * remove force casting vae to fp32 * fix * add comments about preconditioning parameters from k-diffusion paper * attn1_type, attn2_type -> add_self_attention * clean up get_down_block and get_up_block * fix * fixed a typo(?) in ada group norm * update slice attention processer for cross attention * update slice * fix fast test * update the checkpoint * finish tests * fix-copies * fix-copy for modeling_text_unet.py * make style * make style * fix f-string * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Patrick von Platen * fix import * correct changes * fix resnet * make fix-copies * correct euler scheduler * add missing #copied from for preprocess * revert * fix * fix copies * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Pedro Cuenca * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Pedro Cuenca * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Pedro Cuenca * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Pedro Cuenca * Update src/diffusers/models/cross_attention.py Co-authored-by: Pedro Cuenca * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Pedro Cuenca * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Pedro Cuenca * clean up conversion script * KDownsample2d,KUpsample2d -> KDownsample2D,KUpsample2D * more * Update src/diffusers/models/unet_2d_condition.py Co-authored-by: Pedro Cuenca * remove prepare_extra_step_kwargs * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Pedro Cuenca * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Patrick von Platen * fix a typo in timestep embedding * remove num_image_per_prompt * fix fasttest * make style + fix-copies * fix * fix xformer test * fix style * doc string * make style * fix-copies * docstring for time_embedding_norm * make style * final finishes * make fix-copies * fix tests --------- Co-authored-by: yiyixuxu Co-authored-by: Patrick von Platen Co-authored-by: Pedro Cuenca --- docs/source/en/_toctree.yml | 2 + .../stable_diffusion/latent_upscale.mdx | 33 ++ .../pipelines/stable_diffusion/overview.mdx | 1 + scripts/convert_k_upscaler_to_diffusers.py | 297 ++++++++++ src/diffusers/__init__.py | 1 + src/diffusers/models/attention.py | 35 ++ src/diffusers/models/cross_attention.py | 64 +- src/diffusers/models/embeddings.py | 40 +- src/diffusers/models/resnet.py | 117 +++- src/diffusers/models/unet_2d_blocks.py | 553 +++++++++++++++++- src/diffusers/models/unet_2d_condition.py | 103 +++- src/diffusers/pipelines/__init__.py | 1 + .../alt_diffusion/pipeline_alt_diffusion.py | 21 +- .../pipelines/stable_diffusion/__init__.py | 1 + .../pipeline_stable_diffusion.py | 21 +- ...ipeline_stable_diffusion_latent_upscale.py | 520 ++++++++++++++++ .../versatile_diffusion/modeling_text_unet.py | 103 +++- .../schedulers/scheduling_euler_discrete.py | 25 +- .../dummy_torch_and_transformers_objects.py | 15 + tests/models/test_models_unet_2d_condition.py | 2 +- .../test_stable_diffusion_latent_upscale.py | 219 +++++++ 21 files changed, 2077 insertions(+), 97 deletions(-) create mode 100644 docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx create mode 100644 scripts/convert_k_upscaler_to_diffusers.py create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py create mode 100644 tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 35317414..dd08e9f9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -145,6 +145,8 @@ title: Image-Variation - local: api/pipelines/stable_diffusion/upscale title: Super-Resolution + - local: api/pipelines/stable_diffusion/latent_upscale + title: Stable-Diffusion-Latent-Upscaler - local: api/pipelines/stable_diffusion/pix2pix title: InstructPix2Pix title: Stable Diffusion diff --git a/docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx b/docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx new file mode 100644 index 00000000..61fd2f79 --- /dev/null +++ b/docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx @@ -0,0 +1,33 @@ + + +# Stable Diffusion Latent Upscaler + +## StableDiffusionLatentUpscalePipeline + +The Stable Diffusion Latent Upscaler model was created by [Katherine Crowson](https://github.com/crowsonkb/k-diffusion) in collaboration with [Stability AI](https://stability.ai/). It can be used on top of any [`StableDiffusionUpscalePipeline`] checkpoint to enhance its output image resolution by a factor of 2. + +A notebook that demonstrates the original implementation can be found here: +- [Stable Diffusion Upscaler Demo](https://colab.research.google.com/drive/1o1qYJcFeywzCIdkfKJy7cTpgZTCM2EI4) + +Available Checkpoints are: +- *stabilityai/latent-upscaler*: [stabilityai/sd-x2-latent-upscaler](https://huggingface.co/stabilityai/sd-x2-latent-upscaler) + + +[[autodoc]] StableDiffusionLatentUpscalePipeline + - all + - __call__ + - enable_sequential_cpu_offload + - enable_attention_slicing + - disable_attention_slicing + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention \ No newline at end of file diff --git a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx index 8c417870..5d3fb77c 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx @@ -31,6 +31,7 @@ For more details about how Stable Diffusion works and how it differs from the ba | [StableDiffusionDepth2ImgPipeline](./depth2img) | **Experimental** – *Depth-to-Image Text-Guided Generation * | | Coming soon | [StableDiffusionImageVariationPipeline](./image_variation) | **Experimental** – *Image Variation Generation * | | [🤗 Stable Diffusion Image Variations](https://huggingface.co/spaces/lambdalabs/stable-diffusion-image-variations) | [StableDiffusionUpscalePipeline](./upscale) | **Experimental** – *Text-Guided Image Super-Resolution * | | Coming soon +| [StableDiffusionLatentUpscalePipeline](./latent_upscale) | **Experimental** – *Text-Guided Image Super-Resolution * | | Coming soon | [StableDiffusionInstructPix2PixPipeline](./pix2pix) | **Experimental** – *Text-Based Image Editing * | | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/spaces/timbrooks/instruct-pix2pix) diff --git a/scripts/convert_k_upscaler_to_diffusers.py b/scripts/convert_k_upscaler_to_diffusers.py new file mode 100644 index 00000000..457d9219 --- /dev/null +++ b/scripts/convert_k_upscaler_to_diffusers.py @@ -0,0 +1,297 @@ +import argparse + +import torch + +import huggingface_hub +import k_diffusion as K +from diffusers import UNet2DConditionModel + + +UPSCALER_REPO = "pcuenq/k-upscaler" + + +def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"], + f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"], + f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"], + } + ) + + return rv + + +def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0) + bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0) + rv = { + # norm + f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"], + f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"], + # to_q + f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1), + f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q, + # to_k + f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1), + f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k, + # to_v + f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1), + f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v, + # to_out + f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"] + .squeeze(-1) + .squeeze(-1), + f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"], + } + + return rv + + +def cross_attn_to_diffusers_checkpoint( + checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix +): + weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0) + bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0) + + rv = { + # norm2 (ada groupnorm) + f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[ + f"{attention_prefix}.norm_dec.mapper.weight" + ], + f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[ + f"{attention_prefix}.norm_dec.mapper.bias" + ], + # layernorm on encoder_hidden_state + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[ + f"{attention_prefix}.norm_enc.weight" + ], + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[ + f"{attention_prefix}.norm_enc.bias" + ], + # to_q + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[ + f"{attention_prefix}.q_proj.weight" + ] + .squeeze(-1) + .squeeze(-1), + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[ + f"{attention_prefix}.q_proj.bias" + ], + # to_k + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1), + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k, + # to_v + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1), + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v, + # to_out + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[ + f"{attention_prefix}.out_proj.weight" + ] + .squeeze(-1) + .squeeze(-1), + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[ + f"{attention_prefix}.out_proj.bias" + ], + } + + return rv + + +def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type): + block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks" + block_prefix = f"{block_prefix}.{block_idx}" + + diffusers_checkpoint = {} + + if not hasattr(block, "attentions"): + n = 1 # resnet only + elif not block.attentions[0].add_self_attention: + n = 2 # resnet -> cross-attention + else: + n = 3 # resnet -> self-attention -> cross-attention) + + for resnet_idx, resnet in enumerate(block.resnets): + # diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" + diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}" + idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1 + resnet_prefix = f"{block_prefix}.{idx}" if block_type == "up" else f"{block_prefix}.{idx}" + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + if hasattr(block, "attentions"): + for attention_idx, attention in enumerate(block.attentions): + diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}" + idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2 + self_attention_prefix = f"{block_prefix}.{idx}" + cross_attention_prefix = f"{block_prefix}.{idx }" + cross_attention_index = 1 if not attention.add_self_attention else 2 + idx = ( + n * attention_idx + cross_attention_index + if block_type == "up" + else n * attention_idx + cross_attention_index + 1 + ) + cross_attention_prefix = f"{block_prefix}.{idx }" + + diffusers_checkpoint.update( + cross_attn_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + diffusers_attention_index=2, + attention_prefix=cross_attention_prefix, + ) + ) + + if attention.add_self_attention is True: + diffusers_checkpoint.update( + self_attn_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=self_attention_prefix, + ) + ) + + return diffusers_checkpoint + + +def unet_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # pre-processing + diffusers_checkpoint.update( + { + "conv_in.weight": checkpoint["inner_model.proj_in.weight"], + "conv_in.bias": checkpoint["inner_model.proj_in.bias"], + } + ) + + # timestep and class embedding + diffusers_checkpoint.update( + { + "time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1), + "time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"], + "time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"], + "time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"], + "time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"], + "time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"], + } + ) + + # down_blocks + for down_block_idx, down_block in enumerate(model.down_blocks): + diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down")) + + # up_blocks + for up_block_idx, up_block in enumerate(model.up_blocks): + diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up")) + + # post-processing + diffusers_checkpoint.update( + { + "conv_out.weight": checkpoint["inner_model.proj_out.weight"], + "conv_out.bias": checkpoint["inner_model.proj_out.bias"], + } + ) + + return diffusers_checkpoint + + +def unet_model_from_original_config(original_config): + in_channels = original_config["input_channels"] + original_config["unet_cond_dim"] + out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0) + + block_out_channels = original_config["channels"] + + assert ( + len(set(original_config["depths"])) == 1 + ), "UNet2DConditionModel currently do not support blocks with different number of layers" + layers_per_block = original_config["depths"][0] + + class_labels_dim = original_config["mapping_cond_dim"] + cross_attention_dim = original_config["cross_cond_dim"] + + attn1_types = [] + attn2_types = [] + for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]): + if s: + a1 = "self" + a2 = "cross" if c else None + elif c: + a1 = "cross" + a2 = None + else: + a1 = None + a2 = None + attn1_types.append(a1) + attn2_types.append(a2) + + unet = UNet2DConditionModel( + in_channels=in_channels, + out_channels=out_channels, + down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"), + mid_block_type=None, + up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"), + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn="gelu", + norm_num_groups=None, + cross_attention_dim=cross_attention_dim, + attention_head_dim=64, + time_cond_proj_dim=class_labels_dim, + resnet_time_scale_shift="scale_shift", + time_embedding_type="fourier", + timestep_post_act="gelu", + conv_in_kernel=1, + conv_out_kernel=1, + ) + + return unet + + +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json") + orig_weights_path = huggingface_hub.hf_hub_download( + UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth" + ) + print(f"loading original model configuration from {orig_config_path}") + print(f"loading original model checkpoint from {orig_weights_path}") + + print("converting to diffusers unet") + orig_config = K.config.load_config(open(orig_config_path))["model"] + model = unet_model_from_original_config(orig_config) + + orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"] + converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint) + + model.load_state_dict(converted_checkpoint, strict=True) + model.save_pretrained(args.dump_path) + print(f"saving converted unet model in {args.dump_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f9803380..bc6057ea 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -115,6 +115,7 @@ else: StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionInstructPix2PixPipeline, + StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline, StableDiffusionPipelineSafe, StableDiffusionUpscalePipeline, diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b5acd6f4..3cdc7177 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -480,3 +480,38 @@ class AdaLayerNormZero(nn.Module): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaGroupNorm(nn.Module): + """ + GroupNorm layer modified to incorporate timestep embeddings. + """ + + def __init__( + self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 + ): + super().__init__() + self.num_groups = num_groups + self.eps = eps + self.act = None + if act_fn == "swish": + self.act = lambda x: F.silu(x) + elif act_fn == "mish": + self.act = nn.Mish() + elif act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "gelu": + self.act = nn.GELU() + + self.linear = nn.Linear(embedding_dim, out_dim * 2) + + def forward(self, x, emb): + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb[:, :, None, None] + scale, shift = emb.chunk(2, dim=1) + + x = F.group_norm(x, self.num_groups, eps=self.eps) + x = x * (1 + scale) + shift + return x diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 43e21d3b..0602f2ee 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F from torch import nn -from ..utils import logging +from ..utils import deprecate, logging from ..utils.import_utils import is_xformers_available @@ -56,6 +56,7 @@ class CrossAttention(nn.Module): bias=False, upcast_attention: bool = False, upcast_softmax: bool = False, + cross_attention_norm: bool = False, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, processor: Optional["AttnProcessor"] = None, @@ -65,6 +66,7 @@ class CrossAttention(nn.Module): cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax + self.cross_attention_norm = cross_attention_norm self.scale = dim_head**-0.5 @@ -81,6 +83,9 @@ class CrossAttention(nn.Module): else: self.group_norm = None + if cross_attention_norm: + self.norm_cross = nn.LayerNorm(cross_attention_dim) + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) @@ -224,7 +229,19 @@ class CrossAttention(nn.Module): return attention_probs - def prepare_attention_mask(self, attention_mask, target_length): + def prepare_attention_mask(self, attention_mask, target_length, batch_size=None): + if batch_size is None: + deprecate( + "batch_size=None", + "0.0.15", + message=( + "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect" + " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to" + " `prepare_attention_mask` when preparing the attention_mask." + ), + ) + batch_size = 1 + head_size = self.heads if attention_mask is None: return attention_mask @@ -238,18 +255,29 @@ class CrossAttention(nn.Module): attention_mask = torch.concat([attention_mask, padding], dim=2) else: attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if attention_mask.shape[0] < batch_size * head_size: attention_mask = attention_mask.repeat_interleave(head_size, dim=0) return attention_mask class CrossAttnProcessor: - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: CrossAttention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -305,7 +333,7 @@ class LoRACrossAttnProcessor(nn.Module): self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 ): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) @@ -337,7 +365,7 @@ class CrossAttnAddedKVProcessor: batch_size, sequence_length, _ = hidden_states.shape encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) @@ -379,11 +407,15 @@ class XFormersCrossAttnProcessor: def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -417,7 +449,7 @@ class LoRAXFormersCrossAttnProcessor(nn.Module): self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 ): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query).contiguous() @@ -448,13 +480,17 @@ class SlicedAttnProcessor: def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) dim = query.shape[-1] query = attn.head_to_batch_dim(query) - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) key = attn.head_to_batch_dim(key) @@ -500,7 +536,7 @@ class SlicedAttnAddedKVProcessor: batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index fc6cae43..28a67d7f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from typing import Optional import numpy as np import torch @@ -152,15 +153,32 @@ class PatchEmbed(nn.Module): class TimestepEmbedding(nn.Module): - def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): super().__init__() self.linear_1 = nn.Linear(in_channels, time_embed_dim) - self.act = None + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + if act_fn == "silu": self.act = nn.SiLU() elif act_fn == "mish": self.act = nn.Mish() + elif act_fn == "gelu": + self.act = nn.GELU() + else: + raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") if out_dim is not None: time_embed_dim_out = out_dim @@ -168,13 +186,29 @@ class TimestepEmbedding(nn.Module): time_embed_dim_out = time_embed_dim self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) - def forward(self, sample): + if post_act_fn is None: + self.post_act = None + elif post_act_fn == "silu": + self.post_act = nn.SiLU() + elif post_act_fn == "mish": + self.post_act = nn.Mish() + elif post_act_fn == "gelu": + self.post_act = nn.GELU() + else: + raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) sample = self.linear_1(sample) if self.act is not None: sample = self.act(sample) sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) return sample diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 7037da57..7c14a7c4 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -1,9 +1,12 @@ from functools import partial +from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F +from .attention import AdaGroupNorm + class Upsample1D(nn.Module): """ @@ -364,7 +367,70 @@ class FirDownsample2D(nn.Module): return hidden_states +# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead +class KDownsample2D(nn.Module): + def __init__(self, pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, x): + x = F.pad(x, (self.pad,) * 4, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + return F.conv2d(x, weight, stride=2) + + +class KUpsample2D(nn.Module): + def __init__(self, pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, x): + x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) + + class ResnetBlock2D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or + "ada_group" for a stronger conditioning with scale and shift. + kernal (`torch.FloatTensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. + """ + def __init__( self, *, @@ -378,12 +444,14 @@ class ResnetBlock2D(nn.Module): pre_norm=True, eps=1e-6, non_linearity="swish", - time_embedding_norm="default", + time_embedding_norm="default", # default, scale_shift, ada_group kernel=None, output_scale_factor=1.0, use_in_shortcut=None, up=False, down=False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, ): super().__init__() self.pre_norm = pre_norm @@ -392,40 +460,50 @@ class ResnetBlock2D(nn.Module): out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm self.up = up self.down = down self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm if groups_out is None: groups_out = groups - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + if self.time_embedding_norm == "ada_group": + self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + else: + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": - time_emb_proj_out_channels = out_channels + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": - time_emb_proj_out_channels = out_channels * 2 + self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) + elif self.time_embedding_norm == "ada_group": + self.time_emb_proj = None else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") - - self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) else: self.time_emb_proj = None - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + if self.time_embedding_norm == "ada_group": + self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + else: + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) if non_linearity == "swish": self.nonlinearity = lambda x: F.silu(x) elif non_linearity == "mish": - self.nonlinearity = Mish() + self.nonlinearity = nn.Mish() elif non_linearity == "silu": self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() self.upsample = self.downsample = None if self.up: @@ -445,16 +523,22 @@ class ResnetBlock2D(nn.Module): else: self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") - self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None if self.use_in_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.conv_shortcut = torch.nn.Conv2d( + in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias + ) def forward(self, input_tensor, temb): hidden_states = input_tensor - hidden_states = self.norm1(hidden_states) + if self.time_embedding_norm == "ada_group": + hidden_states = self.norm1(hidden_states, temb) + else: + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: @@ -470,13 +554,16 @@ class ResnetBlock2D(nn.Module): hidden_states = self.conv1(hidden_states) - if temb is not None: + if self.time_emb_proj is not None: temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb - hidden_states = self.norm2(hidden_states) + if self.time_embedding_norm == "ada_group": + hidden_states = self.norm2(hidden_states, temb) + else: + hidden_states = self.norm2(hidden_states) if temb is not None and self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 02d5a20a..0b6a767d 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -11,14 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import numpy as np import torch from torch import nn -from .attention import AttentionBlock +from .attention import AdaGroupNorm, AttentionBlock from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor from .dual_transformer_2d import DualTransformer2DModel -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D +from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel @@ -168,6 +170,29 @@ def get_down_block( attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + add_self_attention=True if not add_downsample else False, + ) raise ValueError(f"{down_block_type} does not exist.") @@ -317,6 +342,29 @@ def get_up_block( attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + raise ValueError(f"{up_block_type} does not exist.") @@ -1384,6 +1432,189 @@ class SimpleCrossAttnDownBlock2D(nn.Module): return hidden_states, output_states +class KDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + add_downsample=False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + # YiYi's comments- might be able to use FirDownsample2D, look into details later + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class KCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + cross_attention_dim: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_group_size: int = 32, + add_downsample=True, + attn_num_head_channels: int = 64, + add_self_attention: bool = False, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + out_channels, + out_channels // attn_num_head_channels, + attn_num_head_channels, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm=True, + group_size=resnet_group_size, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + class AttnUpBlock2D(nn.Module): def __init__( self, @@ -2198,3 +2429,321 @@ class SimpleCrossAttnUpBlock2D(nn.Module): hidden_states = upsampler(hidden_states, temb) return hidden_states + + +class KUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 5, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: Optional[int] = 32, + add_upsample=True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=k_out_channels if (i == num_layers - 1) else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + attn_num_head_channels=1, # attention dim_head + cross_attention_dim: int = 768, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + # in_channels, and out_channels for the block (k-unet) + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + if is_middle_block and (i == num_layers - 1): + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if (i == num_layers - 1) else out_channels, + k_out_channels // attn_num_head_channels + if (i == num_layers - 1) + else out_channels // attn_num_head_channels, + attn_num_head_channels, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm=True, + upcast_attention=upcast_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + ): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + upcast_attention: bool = False, + temb_channels: int = 768, # for ada_group_norm + add_self_attention: bool = False, + cross_attention_norm: bool = False, + group_size: int = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + + # 1. Self-Attn + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + + # 2. Cross-Attn + self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states, height, weight): + return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) + + def _to_4d(self, hidden_states, height, weight): + return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + emb=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + # 1. Self-Attention + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention/None + norm_hidden_states = self.norm2(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + return hidden_states diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index c524dbf2..ba2c09b2 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging from .cross_attention import AttnProcessor -from .embeddings import TimestepEmbedding, Timesteps +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -70,9 +70,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`. + The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the + mid block layer if `None`. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. @@ -80,6 +84,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. @@ -90,6 +95,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) num_class_embeds (`int`, *optional*, defaults to None): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, default to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + timestep_post_act (`str, *optional*, default to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, default to `None`): + The dimension of `cond_proj` layer in timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): the Kernel size of `conv_out` layer. """ _supports_gradient_checkpointing = True @@ -109,7 +122,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) "CrossAttnDownBlock2D", "DownBlock2D", ), - mid_block_type: str = "UNetMidBlock2DCrossAttn", + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), @@ -117,7 +130,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", - norm_num_groups: int = 32, + norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: Union[int, Tuple[int]] = 8, @@ -127,20 +140,48 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", + time_embedding_type: str = "positional", # fourier, positional + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, ): super().__init__() self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 # input - self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] + if time_embedding_type == "fourier": + time_embed_dim = block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = block_out_channels[0] * 4 - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) # class embedding if class_embed_type is None and num_class_embeds is not None: @@ -153,7 +194,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) self.class_embedding = None self.down_blocks = nn.ModuleList([]) - self.mid_block = None self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): @@ -218,6 +258,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, ) + elif mid_block_type is None: + self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") @@ -228,6 +270,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) reversed_block_out_channels = list(reversed(block_out_channels)) reversed_attention_head_dim = list(reversed(attention_head_dim)) only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 @@ -266,9 +309,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) prev_output_channel = output_channel # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) @property def attn_processors(self) -> Dict[str, AttnProcessor]: @@ -399,6 +452,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, @@ -466,7 +520,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) + + emb = self.time_embedding(t_emb, timestep_cond) if self.class_embedding is not None: if class_labels is None: @@ -498,13 +553,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) down_block_res_samples += res_samples # 4. mid - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) # 5. up for i, upsample_block in enumerate(self.up_blocks): @@ -533,8 +589,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b69363f5..dfb2fd83 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -52,6 +52,7 @@ else: StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionInstructPix2PixPipeline, + StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline, StableDiffusionUpscalePipeline, ) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 007b2e9a..fb3ad40d 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -632,15 +632,24 @@ class AltDiffusionPipeline(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - # 10. Convert to PIL - if output_type == "pil": + # 10. Convert to PIL image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 71b87467..9ce3662c 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -44,6 +44,7 @@ if is_transformers_available() and is_torch_available(): from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline + from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 116b13fe..dd878dab 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -629,15 +629,24 @@ class StableDiffusionPipeline(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - # 10. Convert to PIL - if output_type == "pil": + # 10. Convert to PIL image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py new file mode 100644 index 00000000..d025f104 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -0,0 +1,520 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F + +import PIL +from transformers import CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import is_accelerate_available, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionLatentUpscalePipeline(DiffusionPipeline): + r""" + Pipeline to upscale the resolution of Stable Diffusion output images by a factor of 2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`EulerDiscreteScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: EulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_length=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_encoder_out = self.text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + text_embeddings = text_encoder_out.hidden_states[-1] + text_pooler_out = text_encoder_out.pooler_output + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_length=True, + return_tensors="pt", + ) + + uncond_encoder_out = self.text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + uncond_embeddings = uncond_encoder_out.hidden_states[-1] + uncond_pooler_out = uncond_encoder_out.pooler_output + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out]) + + return text_embeddings, text_pooler_out + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs(self, prompt, image, noise_level, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor + if isinstance(image, list) or isinstance(image, torch.Tensor): + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] if image.ndim == 4 else 1 + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 0, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image upscaling. + image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + `Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be + either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`. It + will be considered a `latent` if `image.shape[1]` is `4`; otherwise, it will be considered to be an + image representation and encoded using this pipeline's `vae` encoder. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + ```py + >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline + >>> import torch + + + >>> pipeline = StableDiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 + ... ) + >>> pipeline.to("cuda") + + >>> model_id = "stabilityai/sd-x2-latent-upscaler" + >>> upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> upscaler.to("cuda") + + >>> prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic" + >>> generator = torch.manual_seed(33) + + >>> low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images + + >>> with torch.no_grad(): + ... image = pipeline.decode_latents(low_res_latents) + >>> image = pipeline.numpy_to_pil(image)[0] + + >>> image.save("../images/a1.png") + + >>> upscaled_image = upscaler( + ... prompt=prompt, + ... image=low_res_latents, + ... num_inference_steps=20, + ... guidance_scale=0, + ... generator=generator, + ... ).images[0] + + >>> upscaled_image.save("../images/a2.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs + self.check_inputs(prompt, image, noise_level, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if guidance_scale == 0: + prompt = [""] * batch_size + + # 3. Encode input prompt + text_embeddings, text_pooler_out = self._encode_prompt( + prompt, device, do_classifier_free_guidance, negative_prompt + ) + + # 4. Preprocess image + image = preprocess(image) + image = image.to(dtype=text_embeddings.dtype, device=device) + if image.shape[1] == 3: + # encode image if not in latent-space yet + image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = image[None, :] if image.ndim == 3 else image + image = torch.cat([image] * batch_multiplier) + + # 5. Add noise to image (set to be 0): + # (see below notes from the author): + # "the This step theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default." + noise_level = torch.tensor([0.0], dtype=torch.float32, device=device) + noise_level = torch.cat([noise_level] * image.shape[0]) + inv_noise_level = (noise_level**2 + 1) ** (-0.5) + + image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None] + image_cond = image_cond.to(text_embeddings.dtype) + + noise_level_embed = torch.cat( + [ + torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device), + torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device), + ], + dim=1, + ) + + timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1) + + # 6. Prepare latent variables + height, width = image.shape[2:] + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height * 2, # 2x upscale + width * 2, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 9. Denoising loop + num_warmup_steps = 0 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + sigma = self.scheduler.sigmas[i] + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + scaled_model_input = torch.cat([scaled_model_input, image_cond], dim=1) + # preconditioning parameter based on Karras et al. (2022) (table 1) + timestep = torch.log(sigma) * 0.25 + + noise_pred = self.unet( + scaled_model_input, + timestep, + encoder_hidden_states=text_embeddings, + timestep_cond=timestep_condition, + ).sample + + # in original repo, the output contains a variance channel that's not used + noise_pred = noise_pred[:, :-1] + + # apply preconditioning, based on table 1 in Karras et al. (2022) + inv_sigma = 1 / (sigma**2 + 1) + noise_pred = inv_sigma * latent_model_input + self.scheduler.scale_model_input(sigma, t) * noise_pred + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 10. Post-processing + image = self.decode_latents(latents) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 3511cbfd..806875a9 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -9,7 +9,7 @@ from ...models import ModelMixin from ...models.attention import CrossAttention from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor from ...models.dual_transformer_2d import DualTransformer2DModel -from ...models.embeddings import TimestepEmbedding, Timesteps +from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel from ...models.unet_2d_condition import UNet2DConditionOutput from ...utils import logging @@ -151,9 +151,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): - The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`. + The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`, will skip + the mid block layer if `None`. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. @@ -161,6 +165,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. @@ -171,6 +176,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): num_class_embeds (`int`, *optional*, defaults to None): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, default to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + timestep_post_act (`str, *optional*, default to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, default to `None`): + The dimension of `cond_proj` layer in timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): the Kernel size of `conv_out` layer. """ _supports_gradient_checkpointing = True @@ -190,7 +203,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): "CrossAttnDownBlockFlat", "DownBlockFlat", ), - mid_block_type: str = "UNetMidBlockFlatCrossAttn", + mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn", up_block_types: Tuple[str] = ( "UpBlockFlat", "CrossAttnUpBlockFlat", @@ -203,7 +216,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", - norm_num_groups: int = 32, + norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: Union[int, Tuple[int]] = 8, @@ -213,20 +226,48 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", + time_embedding_type: str = "positional", # fourier, positional + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, ): super().__init__() self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 # input - self.conv_in = LinearMultiDim(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = LinearMultiDim( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] + if time_embedding_type == "fourier": + time_embed_dim = block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = block_out_channels[0] * 4 - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) # class embedding if class_embed_type is None and num_class_embeds is not None: @@ -239,7 +280,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): self.class_embedding = None self.down_blocks = nn.ModuleList([]) - self.mid_block = None self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): @@ -304,6 +344,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, ) + elif mid_block_type is None: + self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") @@ -314,6 +356,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): reversed_block_out_channels = list(reversed(block_out_channels)) reversed_attention_head_dim = list(reversed(attention_head_dim)) only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 @@ -352,9 +395,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): prev_output_channel = output_channel # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) - self.conv_act = nn.SiLU() - self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = LinearMultiDim( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) @property def attn_processors(self) -> Dict[str, AttnProcessor]: @@ -485,6 +538,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, @@ -552,7 +606,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) + + emb = self.time_embedding(t_emb, timestep_cond) if self.class_embedding is not None: if class_labels is None: @@ -584,13 +639,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): down_block_res_samples += res_samples # 4. mid - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) # 5. up for i, upsample_block in enumerate(self.up_blocks): @@ -619,8 +675,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 02e5c2cd..1a7a46bc 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -65,11 +65,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - prediction_type (`str`, default `epsilon`, optional): + prediction_type (`str`, default `"epsilon"`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) - + interpolation_type (`str`, default `"linear"`, optional): + interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of + [`"linear"`, `"log_linear"`]. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -84,6 +86,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + interpolation_type: str = "linear", ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -130,7 +133,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True return sample @@ -148,7 +153,17 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.interpolation_type == "linear": + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + elif self.config.interpolation_type == "log_linear": + sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp() + else: + raise ValueError( + f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" + " 'linear' or 'log_linear'" + ) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) if str(device).startswith("mps"): @@ -230,7 +245,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - if self.config.prediction_type == "epsilon": + if self.config.prediction_type == "original_sample": + pred_original_sample = model_output + elif self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma_hat * model_output elif self.config.prediction_type == "v_prediction": # * c_out + input * c_skip diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 14aee55e..33fc0c72 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -169,6 +169,21 @@ class StableDiffusionInstructPix2PixPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionLatentUpscalePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 4fe3ab51..feb2de6e 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -252,7 +252,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py new file mode 100644 index 00000000..e19594df --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + EulerDiscreteScheduler, + StableDiffusionLatentUpscalePipeline, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionLatentUpscalePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionLatentUpscalePipeline + test_cpu_offload = True + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 4 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + def get_dummy_components(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + act_fn="gelu", + attention_head_dim=8, + norm_num_groups=None, + block_out_channels=[32, 32, 64, 64], + time_cond_proj_dim=160, + conv_in_kernel=1, + conv_out_kernel=1, + cross_attention_dim=32, + down_block_types=( + "KDownBlock2D", + "KCrossAttnDownBlock2D", + "KCrossAttnDownBlock2D", + "KCrossAttnDownBlock2D", + ), + in_channels=8, + mid_block_type=None, + only_cross_attention=False, + out_channels=5, + resnet_time_scale_shift="scale_shift", + time_embedding_type="fourier", + timestep_post_act="gelu", + up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"), + ) + vae = AutoencoderKL( + block_out_channels=[32, 32, 64, 64], + in_channels=3, + out_channels=3, + down_block_types=[ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + scheduler = EulerDiscreteScheduler(prediction_type="original_sample") + text_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="quick_gelu", + projection_dim=512, + ) + text_encoder = CLIPTextModel(text_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": model.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": self.dummy_image.cpu(), + "generator": generator, + "num_inference_steps": 2, + "output_type": "numpy", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + self.assertEqual(image.shape, (1, 256, 256, 3)) + expected_slice = np.array( + [0.47222412, 0.41921633, 0.44717434, 0.46874192, 0.42588258, 0.46150726, 0.4677534, 0.45583832, 0.48579055] + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(relax_max_difference=False) + + +@require_torch_gpu +@slow +class StableDiffusionLatentUpscalePipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_latent_upscaler_fp16(self): + generator = torch.manual_seed(33) + + pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16) + pipe.to("cuda") + + upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained( + "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16 + ) + upscaler.to("cuda") + + prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic" + + low_res_latents = pipe(prompt, generator=generator, output_type="latent").images + + image = upscaler( + prompt=prompt, + image=low_res_latents, + num_inference_steps=20, + guidance_scale=0, + generator=generator, + output_type="np", + ).images[0] + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/latent-upscaler/astronaut_1024.npy" + ) + assert np.abs((expected_image - image).max()) < 1e-3 + + def test_latent_upscaler_fp16_image(self): + generator = torch.manual_seed(33) + + upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained( + "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16 + ) + upscaler.to("cuda") + + prompt = "the temple of fire by Ross Tran and Gerardo Dottori, oil on canvas" + + low_res_img = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/latent-upscaler/fire_temple_512.png" + ) + + image = upscaler( + prompt=prompt, + image=low_res_img, + num_inference_steps=20, + guidance_scale=0, + generator=generator, + output_type="np", + ).images[0] + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/latent-upscaler/fire_temple_1024.npy" + ) + assert np.abs((expected_image - image).max()) < 1e-3