diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 41955313a..b4f03b138 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,5 +1,7 @@ import torch from packaging import version +from einops import repeat +import math from modules import devices from modules.sd_hijack_utils import CondFunc @@ -52,6 +54,54 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): return result +# Monkey patch to create timestep embed tensor on device, avoiding a block. +def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +# Monkey patch to SpatialTransformer removing unnecessary contiguous calls. +# Prevents a lot of unnecessary aten::copy_ calls +def spatial_transformer_forward(_, self, x: torch.Tensor, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = x.view(b, h, w, c).permute(0, 3, 1, 2) + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + class GELUHijack(torch.nn.GELU, torch.nn.Module): def __init__(self, *args, **kwargs): torch.nn.GELU.__init__(self, *args, **kwargs) @@ -72,6 +122,10 @@ def hijack_ddpm_edit(): unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) +CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)