Stable Diffusion Latent Upscaler (#2059)

* Modify UNet2DConditionModel

- allow skipping mid_block

- add a `norm_group_size` argument so that we can set the `num_groups` for group norm using `num_channels // norm_group_size`

- allow user to set dimension for the timestep embedding (`time_embed_dim`)

- the kernel_size for `conv_in` and `conv_out` is now configurable

- add random fourier feature layer (`GaussianFourierProjection`) for `time_proj`

- allow the user to add the time and class embeddings together before passing them through the projection layer - `time_embedding(t_emb + class_label)`

- added 2 arguments `attn1_types` and `attn2_types`

  * currently we have the argument `only_cross_attention`: when it's set to `True`, the `BasicTransformerBlock` gets two cross-attention layers, otherwise we get a self-attention followed by a cross-attention. In the k-upscaler we need blocks that contain either just one cross-attention, or self-attention -> cross-attention, so I added `attn1_types` and `attn2_types` to the unet's argument list to let the user specify the attention types for the two positions in each block. Note that I still kept the `only_cross_attention` argument for the unet for easy configuration, but it is converted to `attn1_type` and `attn2_type` when passed down to the down blocks

- the positions of the downsample and upsample layers are now configurable

- in the k-upscaler unet there is only one skip connection per up/down block (instead of one per layer as in the stable diffusion unet); added `skip_freq = "block"` to support this use case

- if the user passes an `attention_mask` to the unet, it will prepare the mask and pass a flag to the cross attention processor to skip the `prepare_attention_mask` step inside the cross attention block (see the configuration sketch below)
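
A minimal configuration sketch showing how these unet options fit together for the k-upscaler. Block types and option names mirror the conversion script further down in this commit; channel widths and conditioning dims here are illustrative and not guaranteed to match the released checkpoint:

```python
from diffusers import UNet2DConditionModel

# illustrative k-upscaler-style config; sizes marked "assumed" are examples only
unet = UNet2DConditionModel(
    in_channels=8,   # noisy latent + low-res image conditioning (assumed)
    out_channels=5,  # latent channels + 1 variance channel (assumed)
    down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
    mid_block_type=None,  # skip the mid block
    up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
    block_out_channels=(128, 256, 512, 512),  # assumed
    layers_per_block=2,                       # assumed
    act_fn="gelu",
    norm_num_groups=None,         # skip the post-processing norm/act
    cross_attention_dim=768,      # assumed text-embedding width
    attention_head_dim=64,
    time_cond_proj_dim=896,       # assumed mapping_cond_dim
    resnet_time_scale_shift="scale_shift",
    time_embedding_type="fourier",  # GaussianFourierProjection for time_proj
    timestep_post_act="gelu",
    conv_in_kernel=1,               # 1x1 conv_in / conv_out
    conv_out_kernel=1,
)
```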

add up/down blocks for k-upscaler

modify CrossAttention class

- make the `dropout` layer in `to_out` optional

- `use_conv_proj` - use conv instead of linear for all projection layers (i.e. `to_q`, `to_k`, `to_v`, `to_out`) whenever possible. Note that when it's used for cross attention, `to_k` and `to_v` have to be linear because the `encoder_hidden_states` is not 2d

- `cross_attention_norm` - add an optional layernorm on encoder_hidden_states

- `attention_dropout`: add an optional dropout on the attention scores (see the sketch below)
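
A quick usage sketch of the new `cross_attention_norm` option (other new options omitted; shapes are arbitrary):

```python
import torch
from diffusers.models.cross_attention import CrossAttention

attn = CrossAttention(
    query_dim=512,
    cross_attention_dim=768,
    heads=8,
    dim_head=64,
    bias=True,
    cross_attention_norm=True,  # LayerNorm on encoder_hidden_states before to_k / to_v
)
hidden_states = torch.randn(2, 1024, 512)        # (batch, query tokens, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text tokens, cross_attention_dim)
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)  # -> (2, 1024, 512)
```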

adapt BasicTransformerBlock

- add an ada group norm layer to condition the attention input on the timestep embedding

- allow skipping the FeedForward layer in between the attentions

- replaced the `only_cross_attention` argument with `attn1_type` and `attn2_type` for more flexible configuration

update timestep embedding: add a new act_fn `gelu` and an optional `act_2`

modified ResnetBlock2D

- refactored with the `AdaGroupNorm` class (the timestep scale-shift normalization)

- add `mid_channel` argument - allow the first conv to have a different output dimension from the second conv

- add an option to use `AdaGroupNorm` on the input instead of group norm

- add an option to add a dropout layer after each conv

- allow user to set the bias in conv_shortcut (needed for k-upscaler)

- add `gelu` as a non-linearity option (see the sketch below)
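
A usage sketch of the reworked block (arguments mirror how the new K blocks construct it; sizes are arbitrary):

```python
import torch
from diffusers.models.resnet import ResnetBlock2D

resnet = ResnetBlock2D(
    in_channels=128,
    out_channels=256,
    temb_channels=512,
    groups=128 // 32,     # num_channels // norm_group_size
    groups_out=256 // 32,
    non_linearity="gelu",
    time_embedding_norm="ada_group",  # AdaGroupNorm conditioned on the timestep embedding
    conv_shortcut_bias=False,         # no bias in the skip conv, as needed for the k-upscaler
)
x = torch.randn(1, 128, 64, 64)
temb = torch.randn(1, 512)
out = resnet(x, temb)  # -> (1, 256, 64, 64)
```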

add conversion script for the k-upscaler unet

add pipeline

* fix attention mask

* fix a typo

* fix a bug

* make sure model can be used with GPU

* make pipeline work with fp16

* fix an error in BasicTransformerBlock

* make style

* fix typo

* some more fixes

* uP

* up

* correct more

* some clean-up

* clean time proj

* up

* uP

* more changes

* remove the upcast_attention=True from unet config

* remove attn1_types, attn2_types etc

* fix

* revert incorrect changes up/down samplers

* make style

* remove outdated files

* Apply suggestions from code review

* attention refactor

* refactor cross attention

* Apply suggestions from code review

* update

* up

* update

* Apply suggestions from code review

* finish

* Update src/diffusers/models/cross_attention.py

* more fixes

* up

* up

* up

* finish

* more corrections of conversion state

* act_2 -> act_2_fn

* remove dropout_after_conv from ResnetBlock2D

* make style

* simplify KAttentionBlock

* add fast test for latent upscaler pipeline

* add slow test

* slow test fp16

* make style

* add doc string for pipeline_stable_diffusion_latent_upscale

* add api doc page for latent upscaler pipeline

* deprecate attention mask

* clean up embeddings

* simplify resnet

* up

* clean up resnet

* up

* correct more

* up

* up

* improve a bit more

* correct more

* more clean-ups

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add docstrings for new unet config

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* # Copied from

* encode the image if not latent

* remove force casting vae to fp32

* fix

* add comments about preconditioning parameters from k-diffusion paper

* attn1_type, attn2_type -> add_self_attention

* clean up get_down_block and get_up_block

* fix

* fixed a typo(?) in ada group norm

* update slice attention processor for cross attention

* update slice

* fix fast test

* update the checkpoint

* finish tests

* fix-copies

* fix-copy for modeling_text_unet.py

* make style

* make style

* fix f-string

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix import

* correct changes

* fix resnet

* make fix-copies

* correct euler scheduler

* add missing #copied from for preprocess

* revert

* fix

* fix copies

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/models/cross_attention.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* clean up conversion script

* KDownsample2d,KUpsample2d -> KDownsample2D,KUpsample2D

* more

* Update src/diffusers/models/unet_2d_condition.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* remove prepare_extra_step_kwargs

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix a typo in timestep embedding

* remove num_image_per_prompt

* fix fasttest

* make style + fix-copies

* fix

* fix xformer test

* fix style

* doc string

* make style

* fix-copies

* docstring for time_embedding_norm

* make style

* final finishes

* make fix-copies

* fix tests

---------

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
YiYi Xu 2023-02-06 22:11:57 -10:00 committed by GitHub
parent 3b66cc0fc1
commit 1051ca81a6
21 changed files with 2077 additions and 97 deletions


@ -145,6 +145,8 @@
title: Image-Variation
- local: api/pipelines/stable_diffusion/upscale
title: Super-Resolution
- local: api/pipelines/stable_diffusion/latent_upscale
title: Stable-Diffusion-Latent-Upscaler
- local: api/pipelines/stable_diffusion/pix2pix
title: InstructPix2Pix
title: Stable Diffusion


@ -0,0 +1,33 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Diffusion Latent Upscaler
## StableDiffusionLatentUpscalePipeline
The Stable Diffusion Latent Upscaler model was created by [Katherine Crowson](https://github.com/crowsonkb/k-diffusion) in collaboration with [Stability AI](https://stability.ai/). It can be used on top of any [`StableDiffusionUpscalePipeline`] checkpoint to enhance its output image resolution by a factor of 2.
A notebook that demonstrates the original implementation can be found here:
- [Stable Diffusion Upscaler Demo](https://colab.research.google.com/drive/1o1qYJcFeywzCIdkfKJy7cTpgZTCM2EI4)
Available Checkpoints are:
- *stabilityai/latent-upscaler*: [stabilityai/sd-x2-latent-upscaler](https://huggingface.co/stabilityai/sd-x2-latent-upscaler)
[[autodoc]] StableDiffusionLatentUpscalePipeline
- all
- __call__
- enable_sequential_cpu_offload
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
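
A short usage sketch (the base text-to-image checkpoint below is only an example; the upscaler checkpoint is the one listed above):

```python
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionLatentUpscalePipeline

pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipeline.to("cuda")

upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
)
upscaler.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.manual_seed(33)

# generate latents with the base pipeline, then upscale them by 2x
low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images

image = upscaler(
    prompt=prompt,
    image=low_res_latents,
    num_inference_steps=20,
    guidance_scale=0,
    generator=generator,
).images[0]
image.save("astronaut_1024.png")
```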


@ -31,6 +31,7 @@ For more details about how Stable Diffusion works and how it differs from the ba
| [StableDiffusionDepth2ImgPipeline](./depth2img) | **Experimental** *Depth-to-Image Text-Guided Generation * | | Coming soon
| [StableDiffusionImageVariationPipeline](./image_variation) | **Experimental** *Image Variation Generation * | | [🤗 Stable Diffusion Image Variations](https://huggingface.co/spaces/lambdalabs/stable-diffusion-image-variations)
| [StableDiffusionUpscalePipeline](./upscale) | **Experimental** *Text-Guided Image Super-Resolution * | | Coming soon
| [StableDiffusionLatentUpscalePipeline](./latent_upscale) | **Experimental** *Text-Guided Image Super-Resolution * | | Coming soon
| [StableDiffusionInstructPix2PixPipeline](./pix2pix) | **Experimental** *Text-Based Image Editing * | | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/spaces/timbrooks/instruct-pix2pix)


@ -0,0 +1,297 @@
import argparse
import torch
import huggingface_hub
import k_diffusion as K
from diffusers import UNet2DConditionModel
UPSCALER_REPO = "pcuenq/k-upscaler"
def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
rv = {
# norm1
f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
# conv1
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
# norm2
f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
# conv2
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
}
if resnet.conv_shortcut is not None:
rv.update(
{
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
}
)
return rv
def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
rv = {
# norm
f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
# to_q
f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
# to_k
f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
# to_v
f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
# to_out
f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
.squeeze(-1)
.squeeze(-1),
f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
}
return rv
def cross_attn_to_diffusers_checkpoint(
checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
):
weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)
rv = {
# norm2 (ada groupnorm)
f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
f"{attention_prefix}.norm_dec.mapper.weight"
],
f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
f"{attention_prefix}.norm_dec.mapper.bias"
],
# layernorm on encoder_hidden_state
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
f"{attention_prefix}.norm_enc.weight"
],
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
f"{attention_prefix}.norm_enc.bias"
],
# to_q
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
f"{attention_prefix}.q_proj.weight"
]
.squeeze(-1)
.squeeze(-1),
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
f"{attention_prefix}.q_proj.bias"
],
# to_k
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
# to_v
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
# to_out
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
f"{attention_prefix}.out_proj.weight"
]
.squeeze(-1)
.squeeze(-1),
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
f"{attention_prefix}.out_proj.bias"
],
}
return rv
def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
block_prefix = f"{block_prefix}.{block_idx}"
diffusers_checkpoint = {}
if not hasattr(block, "attentions"):
n = 1 # resnet only
elif not block.attentions[0].add_self_attention:
n = 2 # resnet -> cross-attention
else:
n = 3 # resnet -> self-attention -> cross-attention)
for resnet_idx, resnet in enumerate(block.resnets):
# diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
resnet_prefix = f"{block_prefix}.{idx}" if block_type == "up" else f"{block_prefix}.{idx}"
diffusers_checkpoint.update(
resnet_to_diffusers_checkpoint(
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
)
)
if hasattr(block, "attentions"):
for attention_idx, attention in enumerate(block.attentions):
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
self_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_prefix = f"{block_prefix}.{idx }"
cross_attention_index = 1 if not attention.add_self_attention else 2
idx = (
n * attention_idx + cross_attention_index
if block_type == "up"
else n * attention_idx + cross_attention_index + 1
)
cross_attention_prefix = f"{block_prefix}.{idx }"
diffusers_checkpoint.update(
cross_attn_to_diffusers_checkpoint(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
diffusers_attention_index=2,
attention_prefix=cross_attention_prefix,
)
)
if attention.add_self_attention is True:
diffusers_checkpoint.update(
self_attn_to_diffusers_checkpoint(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
attention_prefix=self_attention_prefix,
)
)
return diffusers_checkpoint
def unet_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
# pre-processing
diffusers_checkpoint.update(
{
"conv_in.weight": checkpoint["inner_model.proj_in.weight"],
"conv_in.bias": checkpoint["inner_model.proj_in.bias"],
}
)
# timestep and class embedding
diffusers_checkpoint.update(
{
"time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
"time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
"time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
"time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
"time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
"time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
}
)
# down_blocks
for down_block_idx, down_block in enumerate(model.down_blocks):
diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))
# up_blocks
for up_block_idx, up_block in enumerate(model.up_blocks):
diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))
# post-processing
diffusers_checkpoint.update(
{
"conv_out.weight": checkpoint["inner_model.proj_out.weight"],
"conv_out.bias": checkpoint["inner_model.proj_out.bias"],
}
)
return diffusers_checkpoint
def unet_model_from_original_config(original_config):
in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)
block_out_channels = original_config["channels"]
assert (
len(set(original_config["depths"])) == 1
), "UNet2DConditionModel currently do not support blocks with different number of layers"
layers_per_block = original_config["depths"][0]
class_labels_dim = original_config["mapping_cond_dim"]
cross_attention_dim = original_config["cross_cond_dim"]
attn1_types = []
attn2_types = []
for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
if s:
a1 = "self"
a2 = "cross" if c else None
elif c:
a1 = "cross"
a2 = None
else:
a1 = None
a2 = None
attn1_types.append(a1)
attn2_types.append(a2)
unet = UNet2DConditionModel(
in_channels=in_channels,
out_channels=out_channels,
down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
mid_block_type=None,
up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn="gelu",
norm_num_groups=None,
cross_attention_dim=cross_attention_dim,
attention_head_dim=64,
time_cond_proj_dim=class_labels_dim,
resnet_time_scale_shift="scale_shift",
time_embedding_type="fourier",
timestep_post_act="gelu",
conv_in_kernel=1,
conv_out_kernel=1,
)
return unet
def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
orig_weights_path = huggingface_hub.hf_hub_download(
UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
)
print(f"loading original model configuration from {orig_config_path}")
print(f"loading original model checkpoint from {orig_weights_path}")
print("converting to diffusers unet")
orig_config = K.config.load_config(open(orig_config_path))["model"]
model = unet_model_from_original_config(orig_config)
orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)
model.load_state_dict(converted_checkpoint, strict=True)
model.save_pretrained(args.dump_path)
print(f"saving converted unet model in {args.dump_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
main(args)


@ -115,6 +115,7 @@ else:
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionLatentUpscalePipeline,
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
StableDiffusionUpscalePipeline,


@ -480,3 +480,38 @@ class AdaLayerNormZero(nn.Module):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaGroupNorm(nn.Module):
"""
GroupNorm layer modified to incorporate timestep embeddings.
"""
def __init__(
self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
):
super().__init__()
self.num_groups = num_groups
self.eps = eps
self.act = None
if act_fn == "swish":
self.act = lambda x: F.silu(x)
elif act_fn == "mish":
self.act = nn.Mish()
elif act_fn == "silu":
self.act = nn.SiLU()
elif act_fn == "gelu":
self.act = nn.GELU()
self.linear = nn.Linear(embedding_dim, out_dim * 2)
def forward(self, x, emb):
if self.act:
emb = self.act(emb)
emb = self.linear(emb)
emb = emb[:, :, None, None]
scale, shift = emb.chunk(2, dim=1)
x = F.group_norm(x, self.num_groups, eps=self.eps)
x = x * (1 + scale) + shift
return x
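
An illustrative use of `AdaGroupNorm` (sizes are arbitrary):

```python
import torch
from diffusers.models.attention import AdaGroupNorm

norm = AdaGroupNorm(embedding_dim=512, out_dim=256, num_groups=8, act_fn="gelu")
x = torch.randn(2, 256, 32, 32)  # feature map with out_dim channels
emb = torch.randn(2, 512)        # timestep embedding
y = norm(x, emb)                 # group-normalized, then scaled/shifted from the embedding
```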


@ -17,7 +17,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import logging
from ..utils import deprecate, logging
from ..utils.import_utils import is_xformers_available
@ -56,6 +56,7 @@ class CrossAttention(nn.Module):
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
processor: Optional["AttnProcessor"] = None,
@ -65,6 +66,7 @@ class CrossAttention(nn.Module):
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.cross_attention_norm = cross_attention_norm
self.scale = dim_head**-0.5
@ -81,6 +83,9 @@ class CrossAttention(nn.Module):
else:
self.group_norm = None
if cross_attention_norm:
self.norm_cross = nn.LayerNorm(cross_attention_dim)
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
@ -224,7 +229,19 @@ class CrossAttention(nn.Module):
return attention_probs
def prepare_attention_mask(self, attention_mask, target_length):
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
if batch_size is None:
deprecate(
"batch_size=None",
"0.0.15",
message=(
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
" `prepare_attention_mask` when preparing the attention_mask."
),
)
batch_size = 1
head_size = self.heads
if attention_mask is None:
return attention_mask
@ -238,18 +255,29 @@ class CrossAttention(nn.Module):
attention_mask = torch.concat([attention_mask, padding], dim=2)
else:
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
return attention_mask
class CrossAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
def __call__(
self,
attn: CrossAttention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@ -305,7 +333,7 @@ class LoRACrossAttnProcessor(nn.Module):
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)
@ -337,7 +365,7 @@ class CrossAttnAddedKVProcessor:
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@ -379,11 +407,15 @@ class XFormersCrossAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@ -417,7 +449,7 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
@ -448,13 +480,17 @@ class SlicedAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
@ -500,7 +536,7 @@ class SlicedAttnAddedKVProcessor:
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)


@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import numpy as np
import torch
@ -152,15 +153,32 @@ class PatchEmbed(nn.Module):
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = None
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
if act_fn == "silu":
self.act = nn.SiLU()
elif act_fn == "mish":
self.act = nn.Mish()
elif act_fn == "gelu":
self.act = nn.GELU()
else:
raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
if out_dim is not None:
time_embed_dim_out = out_dim
@ -168,13 +186,29 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
def forward(self, sample):
if post_act_fn is None:
self.post_act = None
elif post_act_fn == "silu":
self.post_act = nn.SiLU()
elif post_act_fn == "mish":
self.post_act = nn.Mish()
elif post_act_fn == "gelu":
self.post_act = nn.GELU()
else:
raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
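
An illustrative use of the extended `TimestepEmbedding` (dimensions are arbitrary):

```python
import torch
from diffusers.models.embeddings import TimestepEmbedding

time_embedding = TimestepEmbedding(
    in_channels=256,
    time_embed_dim=1024,
    act_fn="gelu",
    post_act_fn="gelu",  # optional second activation
    cond_proj_dim=896,   # extra conditioning projected onto the timestep embedding
)
t_emb = torch.randn(2, 256)
cond = torch.randn(2, 896)
emb = time_embedding(t_emb, condition=cond)  # -> (2, 1024)
```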


@ -1,9 +1,12 @@
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .attention import AdaGroupNorm
class Upsample1D(nn.Module):
"""
@ -364,7 +367,70 @@ class FirDownsample2D(nn.Module):
return hidden_states
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
def __init__(self, pad_mode="reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, x):
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv2d(x, weight, stride=2)
class KUpsample2D(nn.Module):
def __init__(self, pad_mode="reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, x):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
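
A quick shape check for the two new layers (channel/spatial sizes are arbitrary):

```python
import torch
from diffusers.models.resnet import KDownsample2D, KUpsample2D

x = torch.randn(1, 64, 32, 32)
down = KDownsample2D()(x)  # -> (1, 64, 16, 16): fixed 4-tap blur kernel, stride-2 conv
up = KUpsample2D()(down)   # -> (1, 64, 32, 32): transposed conv with the doubled kernel
```
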
class ResnetBlock2D(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
groups_out (`int`, *optional*, default to None):
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
"ada_group" for a stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
`conv_shortcut` output.
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
If None, same as `out_channels`.
"""
def __init__(
self,
*,
@ -378,12 +444,14 @@ class ResnetBlock2D(nn.Module):
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
time_embedding_norm="default", # default, scale_shift, ada_group
kernel=None,
output_scale_factor=1.0,
use_in_shortcut=None,
up=False,
down=False,
conv_shortcut_bias: bool = True,
conv_2d_out_channels: Optional[int] = None,
):
super().__init__()
self.pre_norm = pre_norm
@ -392,40 +460,50 @@ class ResnetBlock2D(nn.Module):
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
if groups_out is None:
groups_out = groups
if self.time_embedding_norm == "ada_group":
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
time_emb_proj_out_channels = out_channels
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
time_emb_proj_out_channels = out_channels * 2
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
elif self.time_embedding_norm == "ada_group":
self.time_emb_proj = None
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None
if self.time_embedding_norm == "ada_group":
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
self.nonlinearity = nn.Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
self.upsample = self.downsample = None
if self.up:
@ -445,16 +523,22 @@ class ResnetBlock2D(nn.Module):
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.conv_shortcut = torch.nn.Conv2d(
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
)
def forward(self, input_tensor, temb):
hidden_states = input_tensor
if self.time_embedding_norm == "ada_group":
hidden_states = self.norm1(hidden_states, temb)
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
@ -470,12 +554,15 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv1(hidden_states)
if temb is not None:
if self.time_emb_proj is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
if self.time_embedding_norm == "ada_group":
hidden_states = self.norm2(hidden_states, temb)
else:
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":


@ -11,14 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import numpy as np
import torch
from torch import nn
from .attention import AttentionBlock
from .attention import AdaGroupNorm, AttentionBlock
from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
from .transformer_2d import Transformer2DModel
@ -168,6 +170,29 @@ def get_down_block(
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "KDownBlock2D":
return KDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif down_block_type == "KCrossAttnDownBlock2D":
return KCrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
add_self_attention=True if not add_downsample else False,
)
raise ValueError(f"{down_block_type} does not exist.")
@ -317,6 +342,29 @@ def get_up_block(
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "KUpBlock2D":
return KUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "KCrossAttnUpBlock2D":
return KCrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{up_block_type} does not exist.")
@ -1384,6 +1432,189 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
return hidden_states, output_states
class KDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 4,
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
resnet_group_size: int = 32,
add_downsample=False,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
groups = in_channels // resnet_group_size
groups_out = out_channels // resnet_group_size
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
temb_channels=temb_channels,
groups=groups,
groups_out=groups_out,
eps=resnet_eps,
non_linearity=resnet_act_fn,
time_embedding_norm="ada_group",
conv_shortcut_bias=False,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
# YiYi's comments- might be able to use FirDownsample2D, look into details later
self.downsamplers = nn.ModuleList([KDownsample2D()])
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None):
output_states = ()
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states, output_states
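
An illustrative forward pass through `KDownBlock2D` (sizes are arbitrary):

```python
import torch
from diffusers.models.unet_2d_blocks import KDownBlock2D

block = KDownBlock2D(in_channels=128, out_channels=256, temb_channels=512, num_layers=2, add_downsample=True)
x = torch.randn(1, 128, 64, 64)
temb = torch.randn(1, 512)
hidden_states, output_states = block(x, temb)  # hidden_states: (1, 256, 32, 32); one skip per resnet
```
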
class KCrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
cross_attention_dim: int,
dropout: float = 0.0,
num_layers: int = 4,
resnet_group_size: int = 32,
add_downsample=True,
attn_num_head_channels: int = 64,
add_self_attention: bool = False,
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
):
super().__init__()
resnets = []
attentions = []
self.has_cross_attention = True
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
groups = in_channels // resnet_group_size
groups_out = out_channels // resnet_group_size
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
temb_channels=temb_channels,
groups=groups,
groups_out=groups_out,
eps=resnet_eps,
non_linearity=resnet_act_fn,
time_embedding_norm="ada_group",
conv_shortcut_bias=False,
)
)
attentions.append(
KAttentionBlock(
out_channels,
out_channels // attn_num_head_channels,
attn_num_head_channels,
cross_attention_dim=cross_attention_dim,
temb_channels=temb_channels,
attention_bias=True,
add_self_attention=add_self_attention,
cross_attention_norm=True,
group_size=resnet_group_size,
)
)
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions)
if add_downsample:
self.downsamplers = nn.ModuleList([KDownsample2D()])
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
attention_mask,
cross_attention_kwargs,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
emb=temb,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
if self.downsamplers is None:
output_states += (None,)
else:
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states, output_states
class AttnUpBlock2D(nn.Module):
def __init__(
self,
@ -2198,3 +2429,321 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
hidden_states = upsampler(hidden_states, temb)
return hidden_states
class KUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 5,
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
resnet_group_size: Optional[int] = 32,
add_upsample=True,
):
super().__init__()
resnets = []
k_in_channels = 2 * out_channels
k_out_channels = in_channels
num_layers = num_layers - 1
for i in range(num_layers):
in_channels = k_in_channels if i == 0 else out_channels
groups = in_channels // resnet_group_size
groups_out = out_channels // resnet_group_size
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=k_out_channels if (i == num_layers - 1) else out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=groups,
groups_out=groups_out,
dropout=dropout,
non_linearity=resnet_act_fn,
time_embedding_norm="ada_group",
conv_shortcut_bias=False,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([KUpsample2D()])
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class KCrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 4,
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
resnet_group_size: int = 32,
attn_num_head_channels=1, # attention dim_head
cross_attention_dim: int = 768,
add_upsample: bool = True,
upcast_attention: bool = False,
):
super().__init__()
resnets = []
attentions = []
is_first_block = in_channels == out_channels == temb_channels
is_middle_block = in_channels != out_channels
add_self_attention = True if is_first_block else False
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
# in_channels, and out_channels for the block (k-unet)
k_in_channels = out_channels if is_first_block else 2 * out_channels
k_out_channels = in_channels
num_layers = num_layers - 1
for i in range(num_layers):
in_channels = k_in_channels if i == 0 else out_channels
groups = in_channels // resnet_group_size
groups_out = out_channels // resnet_group_size
if is_middle_block and (i == num_layers - 1):
conv_2d_out_channels = k_out_channels
else:
conv_2d_out_channels = None
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
conv_2d_out_channels=conv_2d_out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=groups,
groups_out=groups_out,
dropout=dropout,
non_linearity=resnet_act_fn,
time_embedding_norm="ada_group",
conv_shortcut_bias=False,
)
)
attentions.append(
KAttentionBlock(
k_out_channels if (i == num_layers - 1) else out_channels,
k_out_channels // attn_num_head_channels
if (i == num_layers - 1)
else out_channels // attn_num_head_channels,
attn_num_head_channels,
cross_attention_dim=cross_attention_dim,
temb_channels=temb_channels,
attention_bias=True,
add_self_attention=add_self_attention,
cross_attention_norm=True,
upcast_attention=upcast_attention,
)
)
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions)
if add_upsample:
self.upsamplers = nn.ModuleList([KUpsample2D()])
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
cross_attention_kwargs=None,
upsample_size=None,
attention_mask=None,
):
res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
attention_mask,
cross_attention_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
emb=temb,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
# can potentially later be renamed to `No-feed-forward` attention
class KAttentionBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
upcast_attention: bool = False,
temb_channels: int = 768, # for ada_group_norm
add_self_attention: bool = False,
cross_attention_norm: bool = False,
group_size: int = 32,
):
super().__init__()
self.add_self_attention = add_self_attention
# 1. Self-Attn
if add_self_attention:
self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
cross_attention_norm=None,
)
# 2. Cross-Attn
self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
cross_attention_norm=cross_attention_norm,
)
def _to_3d(self, hidden_states, height, weight):
return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
def _to_4d(self, hidden_states, height, weight):
return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
def forward(
self,
hidden_states,
encoder_hidden_states=None,
emb=None,
attention_mask=None,
cross_attention_kwargs=None,
):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 1. Self-Attention
if self.add_self_attention:
norm_hidden_states = self.norm1(hidden_states, emb)
height, weight = norm_hidden_states.shape[2:]
norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
**cross_attention_kwargs,
)
attn_output = self._to_4d(attn_output, height, weight)
hidden_states = attn_output + hidden_states
# 2. Cross-Attention/None
norm_hidden_states = self.norm2(hidden_states, emb)
height, weight = norm_hidden_states.shape[2:]
norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
attn_output = self._to_4d(attn_output, height, weight)
hidden_states = attn_output + hidden_states
return hidden_states
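
An illustrative forward pass through `KAttentionBlock` (sizes are arbitrary):

```python
import torch
from diffusers.models.unet_2d_blocks import KAttentionBlock

block = KAttentionBlock(
    dim=256,
    num_attention_heads=4,
    attention_head_dim=64,
    cross_attention_dim=768,
    temb_channels=512,
    attention_bias=True,
    add_self_attention=True,
    cross_attention_norm=True,
)
x = torch.randn(1, 256, 16, 16)  # 4D feature map
temb = torch.randn(1, 512)       # timestep embedding for the AdaGroupNorm layers
text = torch.randn(1, 77, 768)   # encoder hidden states
y = block(x, encoder_hidden_states=text, emb=temb)  # -> (1, 256, 16, 16)
```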


@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .cross_attention import AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
CrossAttnDownBlock2D,
@ -70,9 +70,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`.
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
mid block layer if `None`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
The tuple of upsample blocks to use.
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
Whether to include self-attention in the basic transformer blocks, see
[`~models.attention.BasicTransformerBlock`].
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
@ -80,6 +84,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, it will skip the normalization and activation layers in post-processing
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
@ -90,6 +95,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, defaults to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
timestep_post_act (`str`, *optional*, defaults to `None`):
The second activation function to use in the timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of the `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
"""
_supports_gradient_checkpointing = True
@ -109,7 +122,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: str = "UNetMidBlock2DCrossAttn",
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
@ -117,7 +130,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
@ -127,20 +140,48 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
time_embedding_type: str = "positional", # fourier, positional
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
if time_embedding_type == "fourier":
time_embed_dim = block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`."
)
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
post_act_fn=timestep_post_act,
cond_proj_dim=time_cond_proj_dim,
)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
@ -153,7 +194,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
@ -218,6 +258,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif mid_block_type is None:
self.mid_block = None
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
@ -228,6 +270,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@ -266,9 +309,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
if norm_num_groups is not None:
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
else:
self.conv_norm_out = None
self.conv_act = None
conv_out_padding = (conv_out_kernel - 1) // 2
self.conv_out = nn.Conv2d(
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
@property
def attn_processors(self) -> Dict[str, AttnProcessor]:
@ -399,6 +452,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
@ -466,7 +520,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
@ -498,6 +553,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
@ -533,6 +589,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
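
The two less obvious pieces of this hunk are the `(kernel_size - 1) // 2` padding (which keeps the spatial size for any odd kernel and collapses to 0 for the k-upscaler's 1x1 convs) and the random Fourier-feature path for `time_proj`. A minimal torch-only sketch of both ideas, with illustrative names rather than the diffusers internals:

```py
import math
import torch
import torch.nn as nn


def same_padding(kernel_size: int) -> int:
    # (k - 1) // 2 keeps the spatial size for stride-1, odd kernels,
    # and becomes 0 for the 1x1 convs the k-upscaler config uses.
    return (kernel_size - 1) // 2


class FourierTimestepProj(nn.Module):
    """Random Fourier features: t -> [sin(2*pi*W*t), cos(2*pi*W*t)] (illustrative, not the diffusers class)."""

    def __init__(self, half_dim: int):
        super().__init__()
        # fixed, non-trainable Gaussian frequencies
        self.register_buffer("weight", torch.randn(half_dim))

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        proj = 2 * math.pi * t[:, None] * self.weight[None, :]
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)


block_out = 32
conv_in = nn.Conv2d(8, block_out, kernel_size=1, padding=same_padding(1))
time_proj = FourierTimestepProj(block_out)           # output width = 2 * block_out
t_emb = time_proj(torch.rand(4))                     # shape (4, 64)
print(conv_in(torch.randn(4, 8, 16, 16)).shape, t_emb.shape)
```

With `time_embedding_type="fourier"`, the projection outputs `2 * block_out_channels[0]` features, which is why the hunk sets `time_embed_dim = block_out_channels[0] * 2` and feeds that whole vector into `TimestepEmbedding`.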

View File

@ -52,6 +52,7 @@ else:
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionLatentUpscalePipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)

View File

@ -632,6 +632,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if output_type == "latent":
image = latents
has_nsfw_concept = None
elif output_type == "pil":
# 8. Post-processing
image = self.decode_latents(latents)
@ -639,8 +643,13 @@ class AltDiffusionPipeline(DiffusionPipeline):
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
else:
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
if not return_dict:
return (image, has_nsfw_concept)

View File

@ -44,6 +44,7 @@ if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline
from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .safety_checker import StableDiffusionSafetyChecker

View File

@ -629,6 +629,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if output_type == "latent":
image = latents
has_nsfw_concept = None
elif output_type == "pil":
# 8. Post-processing
image = self.decode_latents(latents)
@ -636,8 +640,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
else:
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
if not return_dict:
return (image, has_nsfw_concept)
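
This `output_type="latent"` branch exists so the base pipeline can hand its latents directly to the new upscaler without decoding and re-encoding through the VAE. A minimal usage sketch of that handoff (same checkpoints as the example docstring in the new pipeline file below):

```py
import torch
from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse"
# `output_type="latent"` skips VAE decoding and the safety checker and returns the raw latents
low_res_latents = pipe(prompt, output_type="latent").images
upscaled = upscaler(prompt=prompt, image=low_res_latents, num_inference_steps=20, guidance_scale=0).images[0]
upscaled.save("astronaut_2x.png")
```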

View File

@ -0,0 +1,520 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
import PIL
from transformers import CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
r"""
Pipeline to upscale the resolution of Stable Diffusion output images by a factor of 2.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`EulerDiscreteScheduler`].
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: EulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
text_encoder and vae have their state dicts saved to CPU and then are moved to a
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_length=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_out = self.text_encoder(
text_input_ids.to(device),
output_hidden_states=True,
)
text_embeddings = text_encoder_out.hidden_states[-1]
text_pooler_out = text_encoder_out.pooler_output
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_length=True,
return_tensors="pt",
)
uncond_encoder_out = self.text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
uncond_embeddings = uncond_encoder_out.hidden_states[-1]
uncond_pooler_out = uncond_encoder_out.pooler_output
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])
return text_embeddings, text_pooler_out
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def check_inputs(self, prompt, image, noise_level, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
)
# verify batch size of prompt and image are same if image is a list or tensor
if isinstance(image, list) or isinstance(image, torch.Tensor):
if isinstance(prompt, str):
batch_size = 1
else:
batch_size = len(prompt)
if isinstance(image, list):
image_batch_size = len(image)
else:
image_batch_size = image.shape[0] if image.ndim == 4 else 1
if batch_size != image_batch_size:
raise ValueError(
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
" Please make sure that passed `prompt` matches the batch size of `image`."
)
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 0,
negative_prompt: Optional[Union[str, List[str]]] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image upscaling.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
`Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be
either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`. It
will be considered a `latent` if `image.shape[1]` is `4`; otherwise, it will be considered to be an
image representation and encoded using this pipeline's `vae` encoder.
num_inference_steps (`int`, *optional*, defaults to 75):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 9.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
```py
>>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
>>> import torch
>>> pipeline = StableDiffusionPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
... )
>>> pipeline.to("cuda")
>>> model_id = "stabilityai/sd-x2-latent-upscaler"
>>> upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
>>> upscaler.to("cuda")
>>> prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic"
>>> generator = torch.manual_seed(33)
>>> low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images
>>> with torch.no_grad():
... image = pipeline.decode_latents(low_res_latents)
>>> image = pipeline.numpy_to_pil(image)[0]
>>> image.save("../images/a1.png")
>>> upscaled_image = upscaler(
... prompt=prompt,
... image=low_res_latents,
... num_inference_steps=20,
... guidance_scale=0,
... generator=generator,
... ).images[0]
>>> upscaled_image.save("../images/a2.png")
```
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
[`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a
tuple, the first element is a list with the generated images.
"""
# 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if guidance_scale == 0:
prompt = [""] * batch_size
# 3. Encode input prompt
text_embeddings, text_pooler_out = self._encode_prompt(
prompt, device, do_classifier_free_guidance, negative_prompt
)
# 4. Preprocess image
image = preprocess(image)
image = image.to(dtype=text_embeddings.dtype, device=device)
if image.shape[1] == 3:
# encode image if not in latent-space yet
image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
batch_multiplier = 2 if do_classifier_free_guidance else 1
image = image[None, :] if image.ndim == 3 else image
image = torch.cat([image] * batch_multiplier)
# 5. Add noise to image (the noise level is set to 0 here)
# (see the note from the original author below):
# "This step theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default."
noise_level = torch.tensor([0.0], dtype=torch.float32, device=device)
noise_level = torch.cat([noise_level] * image.shape[0])
inv_noise_level = (noise_level**2 + 1) ** (-0.5)
image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
image_cond = image_cond.to(text_embeddings.dtype)
noise_level_embed = torch.cat(
[
torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
],
dim=1,
)
timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1)
# 6. Prepare latent variables
height, width = image.shape[2:]
num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents(
batch_size,
num_channels_latents,
height * 2, # 2x upscale
width * 2,
text_embeddings.dtype,
device,
generator,
latents,
)
# 7. Check that sizes of image and latents match
num_channels_image = image.shape[1]
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
# 9. Denoising loop
num_warmup_steps = 0
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
sigma = self.scheduler.sigmas[i]
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
scaled_model_input = torch.cat([scaled_model_input, image_cond], dim=1)
# preconditioning parameter based on Karras et al. (2022) (table 1)
timestep = torch.log(sigma) * 0.25
noise_pred = self.unet(
scaled_model_input,
timestep,
encoder_hidden_states=text_embeddings,
timestep_cond=timestep_condition,
).sample
# in original repo, the output contains a variance channel that's not used
noise_pred = noise_pred[:, :-1]
# apply preconditioning, based on table 1 in Karras et al. (2022)
inv_sigma = 1 / (sigma**2 + 1)
noise_pred = inv_sigma * latent_model_input + self.scheduler.scale_model_input(sigma, t) * noise_pred
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Post-processing
image = self.decode_latents(latents)
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
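
Two parts of `__call__` above that are easy to miss are the Karras-style preconditioning around the unet call and the way the pooled CLIP embedding is packed into `timestep_cond`. A plain-torch sketch of both, with illustrative names and `sigma_data = 1` assumed (which is what the `1 / (sigma**2 + 1)` and `scale_model_input` expressions correspond to):

```py
import torch


def karras_preconditioning(x, model_output, sigma):
    # Table 1 of Karras et al. (2022) with sigma_data = 1:
    c_skip = 1 / (sigma**2 + 1)               # `inv_sigma` in the denoising loop
    c_out = sigma / (sigma**2 + 1) ** 0.5     # what `scale_model_input(sigma, t)` evaluates to
    c_noise = torch.log(sigma) * 0.25         # the `timestep` fed to the unet
    denoised = c_skip * x + c_out * model_output
    return denoised, c_noise


def build_timestep_condition(text_pooler_out):
    # 64 ones + 64 zeros encode "noise level 0" (noise augmentation is disabled),
    # then the pooled CLIP embedding is appended, so the unet's `time_cond_proj_dim`
    # must equal 128 + the pooled embedding width.
    b = text_pooler_out.shape[0]
    noise_level_embed = torch.cat(
        [torch.ones(b, 64), torch.zeros(b, 64)], dim=1
    ).to(text_pooler_out.dtype)
    return torch.cat([noise_level_embed, text_pooler_out], dim=1)


sigma = torch.tensor(2.5)
x = torch.randn(1, 4, 8, 8)
denoised, c_noise = karras_preconditioning(x, torch.randn_like(x), sigma)
cond = build_timestep_condition(torch.randn(1, 768))          # 768 is just an example width
print(denoised.shape, round(c_noise.item(), 3), cond.shape)   # (1, 4, 8, 8)  0.229  (1, 896)
```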

View File

@ -9,7 +9,7 @@ from ...models import ModelMixin
from ...models.attention import CrossAttention
from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor
from ...models.dual_transformer_2d import DualTransformer2DModel
from ...models.embeddings import TimestepEmbedding, Timesteps
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import logging
@ -151,9 +151,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`.
The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`. The mid
block is skipped if `None`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`):
The tuple of upsample blocks to use.
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
Whether to include self-attention in the basic transformer blocks, see
[`~models.attention.BasicTransformerBlock`].
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
@ -161,6 +165,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers are skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
@ -171,6 +176,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, defaults to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
timestep_post_act (`str`, *optional*, defaults to `None`):
The second activation function to use in the timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of the `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
"""
_supports_gradient_checkpointing = True
@ -190,7 +203,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
"CrossAttnDownBlockFlat",
"DownBlockFlat",
),
mid_block_type: str = "UNetMidBlockFlatCrossAttn",
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
up_block_types: Tuple[str] = (
"UpBlockFlat",
"CrossAttnUpBlockFlat",
@ -203,7 +216,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
@ -213,20 +226,48 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
time_embedding_type: str = "positional", # fourier, positional
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = LinearMultiDim(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = LinearMultiDim(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
if time_embedding_type == "fourier":
time_embed_dim = block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`."
)
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
post_act_fn=timestep_post_act,
cond_proj_dim=time_cond_proj_dim,
)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
@ -239,7 +280,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
@ -304,6 +344,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif mid_block_type is None:
self.mid_block = None
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
@ -314,6 +356,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@ -352,9 +395,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
if norm_num_groups is not None:
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = nn.SiLU()
self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
else:
self.conv_norm_out = None
self.conv_act = None
conv_out_padding = (conv_out_kernel - 1) // 2
self.conv_out = LinearMultiDim(
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
@property
def attn_processors(self) -> Dict[str, AttnProcessor]:
@ -485,6 +538,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
@ -552,7 +606,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
@ -584,6 +639,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
@ -619,6 +675,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)

View File

@ -65,11 +65,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
prediction_type (`str`, default `"epsilon"`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
interpolation_type (`str`, default `"linear"`, optional):
interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
[`"linear"`, `"log_linear"`].
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@ -84,6 +86,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
interpolation_type: str = "linear",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@ -130,7 +133,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
self.is_scale_input_called = True
return sample
@ -148,7 +153,17 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp()
else:
raise ValueError(
f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
" 'linear' or 'log_linear'"
)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
if str(device).startswith("mps"):
@ -230,7 +245,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
if self.config.prediction_type == "original_sample":
pred_original_sample = model_output
elif self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
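
The new `interpolation_type` branch is what lets this scheduler reproduce the k-diffusion sigma schedule the latent upscaler was trained with. A standalone numpy/torch sketch of the two branches, using a toy `alphas_cumprod` rather than the scheduler's real one:

```py
import numpy as np
import torch

num_train_timesteps, num_inference_steps = 1000, 10
alphas_cumprod = np.linspace(0.9999, 0.01, num_train_timesteps)    # toy schedule
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5            # increasing with the index
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

# "linear": look up sigma values at the (fractional) timestep positions
linear_sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

# "log_linear": evenly spaced in log-sigma between sigma_max and sigma_min,
# which is the spacing the k-diffusion latent upscaler expects
log_linear_sigmas = torch.linspace(
    np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1
).exp()

print(np.round(linear_sigmas, 3))
print(np.round(log_linear_sigmas.numpy(), 3))
```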

View File

@ -169,6 +169,21 @@ class StableDiffusionInstructPix2PixPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionLatentUpscalePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@ -252,7 +252,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)

View File

@ -0,0 +1,219 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import numpy as np
import torch
from diffusers import (
AutoencoderKL,
EulerDiscreteScheduler,
StableDiffusionLatentUpscalePipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class StableDiffusionLatentUpscalePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionLatentUpscalePipeline
test_cpu_offload = True
@property
def dummy_image(self):
batch_size = 1
num_channels = 4
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
return image
def get_dummy_components(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
act_fn="gelu",
attention_head_dim=8,
norm_num_groups=None,
block_out_channels=[32, 32, 64, 64],
time_cond_proj_dim=160,
conv_in_kernel=1,
conv_out_kernel=1,
cross_attention_dim=32,
down_block_types=(
"KDownBlock2D",
"KCrossAttnDownBlock2D",
"KCrossAttnDownBlock2D",
"KCrossAttnDownBlock2D",
),
in_channels=8,
mid_block_type=None,
only_cross_attention=False,
out_channels=5,
resnet_time_scale_shift="scale_shift",
time_embedding_type="fourier",
timestep_post_act="gelu",
up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
)
vae = AutoencoderKL(
block_out_channels=[32, 32, 64, 64],
in_channels=3,
out_channels=3,
down_block_types=[
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
scheduler = EulerDiscreteScheduler(prediction_type="original_sample")
text_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="quick_gelu",
projection_dim=512,
)
text_encoder = CLIPTextModel(text_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": model.eval(),
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": self.dummy_image.cpu(),
"generator": generator,
"num_inference_steps": 2,
"output_type": "numpy",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
self.assertEqual(image.shape, (1, 256, 256, 3))
expected_slice = np.array(
[0.47222412, 0.41921633, 0.44717434, 0.46874192, 0.42588258, 0.46150726, 0.4677534, 0.45583832, 0.48579055]
)
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(relax_max_difference=False)
@require_torch_gpu
@slow
class StableDiffusionLatentUpscalePipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_latent_upscaler_fp16(self):
generator = torch.manual_seed(33)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe.to("cuda")
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
)
upscaler.to("cuda")
prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic"
low_res_latents = pipe(prompt, generator=generator, output_type="latent").images
image = upscaler(
prompt=prompt,
image=low_res_latents,
num_inference_steps=20,
guidance_scale=0,
generator=generator,
output_type="np",
).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/latent-upscaler/astronaut_1024.npy"
)
assert np.abs((expected_image - image).max()) < 1e-3
def test_latent_upscaler_fp16_image(self):
generator = torch.manual_seed(33)
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
)
upscaler.to("cuda")
prompt = "the temple of fire by Ross Tran and Gerardo Dottori, oil on canvas"
low_res_img = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/latent-upscaler/fire_temple_512.png"
)
image = upscaler(
prompt=prompt,
image=low_res_img,
num_inference_steps=20,
guidance_scale=0,
generator=generator,
output_type="np",
).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/latent-upscaler/fire_temple_1024.npy"
)
assert np.abs((expected_image - image).max()) < 1e-3