From 94566e6dd8b018726b215f70e818589ac9815830 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 Jul 2022 11:52:22 +0200 Subject: [PATCH] update mid block (#70) * update mid block * finish mid block --- src/diffusers/models/attention.py | 268 ++++++++++-------- src/diffusers/models/unet.py | 18 +- src/diffusers/models/unet_glide.py | 26 +- src/diffusers/models/unet_grad_tts.py | 2 + src/diffusers/models/unet_ldm.py | 35 ++- src/diffusers/models/unet_new.py | 128 +++++++++ .../models/unet_sde_score_estimation.py | 31 +- 7 files changed, 361 insertions(+), 147 deletions(-) create mode 100644 src/diffusers/models/unet_new.py diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 1d7e85e3..7e8af27e 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1,6 +1,8 @@ import math +from inspect import isfunction import torch +import torch.nn.functional as F from torch import nn @@ -43,18 +45,16 @@ class AttentionBlock(nn.Module): self, channels, num_heads=1, - num_head_channels=-1, + num_head_channels=None, num_groups=32, - use_checkpoint=False, encoder_channels=None, - use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete? overwrite_qkv=False, overwrite_linear=False, rescale_output_factor=1.0, ): super().__init__() self.channels = channels - if num_head_channels == -1: + if num_head_channels is None: self.num_heads = num_heads else: assert ( @@ -62,7 +62,6 @@ class AttentionBlock(nn.Module): ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True) self.qkv = nn.Conv1d(channels, channels * 3, 1) self.n_heads = self.num_heads @@ -160,115 +159,135 @@ class AttentionBlock(nn.Module): return result -# unet_score_estimation.py -# class AttnBlockpp(nn.Module): -# """Channel-wise self-attention block. Modified from DDPM.""" -# -# def __init__( -# self, -# channels, -# skip_rescale=False, -# init_scale=0.0, -# num_heads=1, -# num_head_channels=-1, -# use_checkpoint=False, -# encoder_channels=None, -# use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete? 
-# overwrite_qkv=False, -# overwrite_from_grad_tts=False, -# ): -# super().__init__() -# num_groups = min(channels // 4, 32) -# self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6) -# self.NIN_0 = NIN(channels, channels) -# self.NIN_1 = NIN(channels, channels) -# self.NIN_2 = NIN(channels, channels) -# self.NIN_3 = NIN(channels, channels, init_scale=init_scale) -# self.skip_rescale = skip_rescale -# -# self.channels = channels -# if num_head_channels == -1: -# self.num_heads = num_heads -# else: -# assert ( -# channels % num_head_channels == 0 -# ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" -# self.num_heads = channels // num_head_channels -# -# self.use_checkpoint = use_checkpoint -# self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6) -# self.qkv = nn.Conv1d(channels, channels * 3, 1) -# self.n_heads = self.num_heads -# -# self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) -# -# self.is_weight_set = False -# -# def set_weights(self): -# self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None] -# self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0) -# -# self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None] -# self.proj_out.bias.data = self.NIN_3.b.data -# -# self.norm.weight.data = self.GroupNorm_0.weight.data -# self.norm.bias.data = self.GroupNorm_0.bias.data -# -# def forward(self, x): -# if not self.is_weight_set: -# self.set_weights() -# self.is_weight_set = True -# -# B, C, H, W = x.shape -# h = self.GroupNorm_0(x) -# q = self.NIN_0(h) -# k = self.NIN_1(h) -# v = self.NIN_2(h) -# -# w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) -# w = torch.reshape(w, (B, H, W, H * W)) -# w = F.softmax(w, dim=-1) -# w = torch.reshape(w, (B, H, W, H, W)) -# h = torch.einsum("bhwij,bcij->bchw", w, v) -# h = self.NIN_3(h) -# -# if not self.skip_rescale: -# result = x + h -# else: -# result = (x + h) / np.sqrt(2.0) -# -# result = self.forward_2(x) -# -# return result -# -# def forward_2(self, x, encoder_out=None): -# b, c, *spatial = x.shape -# hid_states = self.norm(x).view(b, c, -1) -# -# qkv = self.qkv(hid_states) -# bs, width, length = qkv.shape -# assert width % (3 * self.n_heads) == 0 -# ch = width // (3 * self.n_heads) -# q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) -# -# if encoder_out is not None: -# encoder_kv = self.encoder_kv(encoder_out) -# assert encoder_kv.shape[1] == self.n_heads * ch * 2 -# ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) -# k = torch.cat([ek, k], dim=-1) -# v = torch.cat([ev, v], dim=-1) -# -# scale = 1 / math.sqrt(math.sqrt(ch)) -# weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards -# weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) -# -# a = torch.einsum("bts,bcs->bct", weight, v) -# h = a.reshape(bs, -1, length) -# -# h = self.proj_out(h) -# h = h.reshape(b, c, *spatial) -# -# return (x + h) / np.sqrt(2.0) +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. 
Finally, reshape to image + """ + + def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth) + ] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) + for block in self.transformer_blocks: + x = block(x, context=context) + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + x = self.proj_out(x) + return x + x_in + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, x, context=None, mask=None): + batch_size, sequence_length, dim = x.shape + + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q = self.reshape_heads_to_batch_dim(q) + k = self.reshape_heads_to_batch_dim(k) + v = self.reshape_heads_to_batch_dim(v) + + sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale + + if exists(mask): + mask = mask.reshape(batch_size, -1) + max_neg_value = 
-torch.finfo(sim.dtype).max + mask = mask[:, None, :].repeat(h, 1, 1) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = torch.einsum("b i j, b j d -> b i d", attn, v) + out = self.reshape_batch_dim_to_heads(out) + return self.to_out(out) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) # TODO(Patrick) - this can and should be removed @@ -287,3 +306,24 @@ class NIN(nn.Module): super().__init__() self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True) self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index aebf9b61..563f4f20 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -23,6 +23,7 @@ from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample2D, ResnetBlock2D, Upsample2D +from .unet_new import UNetMidBlock2D def nonlinearity(x): @@ -105,13 +106,8 @@ class UNetModel(ModelMixin, ConfigMixin): self.down.append(down) # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock2D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) - self.mid.block_2 = ResnetBlock2D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + self.mid = UNetMidBlock2D( + in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, overwrite_qkv=True, overwrite_unet=True ) # upsampling @@ -171,10 +167,10 @@ class UNetModel(ModelMixin, ConfigMixin): hs.append(self.down[i_level].downsample(hs[-1])) # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h = self.mid(hs[-1], temb) + # h = self.mid.block_1(h, temb) + # h = self.mid.attn_1(h) + # h = self.mid.block_2(h, temb) # upsampling for i_level in reversed(range(self.num_resolutions)): diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 960d8416..7dca03b6 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -7,6 +7,7 @@ from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample2D, ResnetBlock2D, Upsample2D +from .unet_new import UNetMidBlock2D def convert_module_to_f16(l): @@ -193,7 +194,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin): layers.append( AttentionBlock( ch, - use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=num_head_channels, encoder_channels=transformer_dim, @@ -226,6 +226,20 @@ 
class GlideUNetModel(ModelMixin, ConfigMixin): ds *= 2 self._feature_size += ch + self.mid = UNetMidBlock2D( + in_channels=ch, + dropout=dropout, + temb_channels=time_embed_dim, + resnet_eps=1e-5, + resnet_act_fn="silu", + resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default", + attn_num_heads=num_heads, + attn_num_head_channels=num_head_channels, + attn_encoder_channels=transformer_dim, + ) + + # TODO(Patrick) - delete after weight conversion + # init to be able to overwrite `self.mid` self.middle_block = TimestepEmbedSequential( ResnetBlock2D( in_channels=ch, @@ -238,7 +252,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin): ), AttentionBlock( ch, - use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=num_head_channels, encoder_channels=transformer_dim, @@ -253,6 +266,10 @@ class GlideUNetModel(ModelMixin, ConfigMixin): overwrite_for_glide=True, ), ) + self.mid.resnet_1 = self.middle_block[0] + self.mid.attn = self.middle_block[1] + self.mid.resnet_2 = self.middle_block[2] + self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -276,7 +293,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin): layers.append( AttentionBlock( ch, - use_checkpoint=use_checkpoint, num_heads=num_heads_upsample, num_head_channels=num_head_channels, encoder_channels=transformer_dim, @@ -343,7 +359,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): for module in self.input_blocks: h = module(h, emb) hs.append(h) - h = self.middle_block(h, emb) + h = self.mid(h, emb) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb) @@ -438,7 +454,7 @@ class GlideTextToImageUNetModel(GlideUNetModel): for module in self.input_blocks: h = module(h, emb, transformer_out) hs.append(h) - h = self.middle_block(h, emb, transformer_out) + h = self.mid(h, emb, transformer_out) for module in self.output_blocks: other = hs.pop() h = torch.cat([h, other], dim=1) diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 32d36399..3ec51b09 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -133,6 +133,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): overwrite_for_grad_tts=True, ) +# self.mid = UNetMidBlock2D + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): self.ups.append( torch.nn.ModuleList( diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 1589b75b..a46a4c18 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -11,6 +11,7 @@ from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample2D, ResnetBlock2D, Upsample2D +from .unet_new import UNetMidBlock2D # from .resnet import ResBlock @@ -239,14 +240,12 @@ class UNetLDMModel(ModelMixin, ConfigMixin): conv_resample=conv_resample, dims=dims, num_classes=num_classes, - use_checkpoint=use_checkpoint, use_fp16=use_fp16, num_heads=num_heads, num_head_channels=num_head_channels, num_heads_upsample=num_heads_upsample, use_scale_shift_norm=use_scale_shift_norm, resblock_updown=resblock_updown, - use_new_attention_order=use_new_attention_order, use_spatial_transformer=use_spatial_transformer, transformer_depth=transformer_depth, context_dim=context_dim, @@ -283,7 +282,6 @@ class UNetLDMModel(ModelMixin, ConfigMixin): self.channel_mult = channel_mult self.conv_resample = conv_resample self.num_classes = num_classes - self.use_checkpoint = 
use_checkpoint self.dtype_ = torch.float16 if use_fp16 else torch.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels @@ -333,10 +331,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin): layers.append( AttentionBlock( ch, - use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer( @@ -366,6 +362,25 @@ class UNetLDMModel(ModelMixin, ConfigMixin): if legacy: # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + + if dim_head < 0: + dim_head = None + self.mid = UNetMidBlock2D( + in_channels=ch, + dropout=dropout, + temb_channels=time_embed_dim, + resnet_eps=1e-5, + resnet_act_fn="silu", + resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default", + attention_layer_type="self" if not use_spatial_transformer else "spatial", + attn_num_heads=num_heads, + attn_num_head_channels=dim_head, + attn_depth=transformer_depth, + attn_encoder_channels=context_dim, + ) + + # TODO(Patrick) - delete after weight conversion + # init to be able to overwrite `self.mid` self.middle_block = TimestepEmbedSequential( ResnetBlock2D( in_channels=ch, @@ -378,10 +393,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ), AttentionBlock( ch, - use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim), @@ -395,6 +408,10 @@ class UNetLDMModel(ModelMixin, ConfigMixin): overwrite_for_ldm=True, ), ) + self.mid.resnet_1 = self.middle_block[0] + self.mid.attn = self.middle_block[1] + self.mid.resnet_2 = self.middle_block[2] + self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -425,10 +442,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin): layers.append( AttentionBlock( ch, - use_checkpoint=use_checkpoint, num_heads=num_heads_upsample, num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer( @@ -493,7 +508,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): for module in self.input_blocks: h = module(h, emb, context) hs.append(h) - h = self.middle_block(h, emb, context) + h = self.mid(h, emb, context) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb, context) diff --git a/src/diffusers/models/unet_new.py b/src/diffusers/models/unet_new.py new file mode 100644 index 00000000..066adb6a --- /dev/null +++ b/src/diffusers/models/unet_new.py @@ -0,0 +1,128 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. 
+from torch import nn + +from .attention import AttentionBlock, SpatialTransformer +from .resnet import ResnetBlock2D + + +class UNetMidBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attention_layer_type: str = "self", + attn_num_heads=1, + attn_num_head_channels=None, + attn_encoder_channels=None, + attn_dim_head=None, + attn_depth=None, + output_scale_factor=1.0, + overwrite_qkv=False, + overwrite_unet=False, + ): + super().__init__() + + self.resnet_1 = ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + + if attention_layer_type == "self": + self.attn = AttentionBlock( + in_channels, + num_heads=attn_num_heads, + num_head_channels=attn_num_head_channels, + encoder_channels=attn_encoder_channels, + overwrite_qkv=overwrite_qkv, + rescale_output_factor=output_scale_factor, + ) + elif attention_layer_type == "spatial": + self.attn = ( + SpatialTransformer( + in_channels, + attn_num_heads, + attn_num_head_channels, + depth=attn_depth, + context_dim=attn_encoder_channels, + ), + ) + + self.resnet_2 = ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + + # TODO(Patrick) - delete all of the following code + self.is_overwritten = False + self.overwrite_unet = overwrite_unet + if self.overwrite_unet: + block_in = in_channels + self.temb_ch = temb_channels + self.block_1 = ResnetBlock2D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + eps=resnet_eps, + ) + self.attn_1 = AttentionBlock( + block_in, + num_heads=attn_num_heads, + num_head_channels=attn_num_head_channels, + encoder_channels=attn_encoder_channels, + overwrite_qkv=True, + ) + self.block_2 = ResnetBlock2D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + eps=resnet_eps, + ) + + def forward(self, hidden_states, temb=None, encoder_states=None): + if not self.is_overwritten and self.overwrite_unet: + self.resnet_1 = self.block_1 + self.attn = self.attn_1 + self.resnet_2 = self.block_2 + self.is_overwritten = True + + hidden_states = self.resnet_1(hidden_states, temb) + + if encoder_states is None: + hidden_states = self.attn(hidden_states) + else: + hidden_states = self.attn(hidden_states, encoder_states) + + hidden_states = self.resnet_2(hidden_states, temb) + return hidden_states diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index bc54199d..cdf6c611 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -27,6 +27,7 @@ from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import GaussianFourierProjection, get_timestep_embedding from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D +from .unet_new import UNetMidBlock2D class Combine(nn.Module): @@ -214,6 +215,16 @@ class NCSNpp(ModelMixin, ConfigMixin): 
hs_c.append(in_ch) + # mid + self.mid = UNetMidBlock2D( + in_channels=in_ch, + temb_channels=4 * nf, + output_scale_factor=math.sqrt(2.0), + resnet_act_fn="silu", + resnet_groups=min(in_ch // 4, 32), + dropout=dropout, + ) + in_ch = hs_c[-1] modules.append( ResnetBlock2D( @@ -238,6 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin): overwrite_for_score_vde=True, ) ) + self.mid.resnet_1 = modules[len(modules) - 3] + self.mid.attn = modules[len(modules) - 2] + self.mid.resnet_2 = modules[len(modules) - 1] pyramid_ch = 0 # Upsampling block @@ -378,13 +392,16 @@ class NCSNpp(ModelMixin, ConfigMixin): hs.append(h) - h = hs[-1] - h = modules[m_idx](h, temb) - m_idx += 1 - h = modules[m_idx](h) - m_idx += 1 - h = modules[m_idx](h, temb) - m_idx += 1 + # h = hs[-1] + # h = modules[m_idx](h, temb) + # m_idx += 1 + # h = modules[m_idx](h) + # m_idx += 1 + # h = modules[m_idx](h, temb) + # m_idx += 1 + + h = self.mid(h, temb) + m_idx += 3 pyramid = None
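
A few illustrative sketches of the pieces this patch introduces follow; none of the code below is part of the diff itself.

The new CrossAttention folds its attention heads into the batch dimension before the einsum and unfolds them afterwards (reshape_heads_to_batch_dim / reshape_batch_dim_to_heads). Here is a minimal, self-contained version of that folding in plain PyTorch; the function name and shapes are purely illustrative:

import torch


def fold_heads_attention(q, k, v, heads):
    # q: (b, tq, d), k/v: (b, tk, d), with d divisible by `heads`
    b, tq, d = q.shape
    dim_head = d // heads
    scale = dim_head ** -0.5

    def to_batch(t):
        # (b, seq, d) -> (b * heads, seq, dim_head)
        return t.reshape(b, t.shape[1], heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, t.shape[1], dim_head)

    q, k, v = to_batch(q), to_batch(k), to_batch(v)
    sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale            # (b*heads, tq, tk)
    out = torch.einsum("b i j, b j d -> b i d", sim.softmax(dim=-1), v)  # (b*heads, tq, dim_head)
    # (b * heads, tq, dim_head) -> (b, tq, d)
    return out.reshape(b, heads, tq, dim_head).permute(0, 2, 1, 3).reshape(b, tq, d)


q = torch.randn(2, 16, 64)   # queries from image tokens
kv = torch.randn(2, 77, 64)  # keys/values from a text context (77 is just an example length)
print(fold_heads_attention(q, kv, kv, heads=8).shape)  # torch.Size([2, 16, 64])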
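
SpatialTransformer treats the feature map as a token sequence: it flattens (b, c, h, w) to (b, h*w, c) for the transformer blocks and restores the spatial layout at the end. The reshape round-trip in isolation, to make the layout explicit:

import torch

b, c, h, w = 2, 32, 8, 8
x = torch.randn(b, c, h, w)

tokens = x.permute(0, 2, 3, 1).reshape(b, h * w, c)        # (b, h*w, c), one token per spatial position
restored = tokens.reshape(b, h, w, c).permute(0, 3, 1, 2)  # back to (b, c, h, w)

print(torch.equal(x, restored))  # True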
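
The new UNetMidBlock2D is meant to be constructed once and called with the bottleneck feature map plus the time embedding, replacing the hand-rolled block_1 / attn_1 / block_2 (or middle_block) sequences in the individual UNets. A usage sketch against this branch only: the unet_new module path and the blocks it wraps are introduced or reshuffled by this very patch, so treat it as provisional rather than a stable API.

import torch

from diffusers.models.unet_new import UNetMidBlock2D  # module added by this patch

mid = UNetMidBlock2D(in_channels=256, temb_channels=1024, dropout=0.0)

sample = torch.randn(1, 256, 16, 16)  # bottleneck feature map
temb = torch.randn(1, 1024)           # time embedding
out = mid(sample, temb)               # resnet_1 -> attn -> resnet_2
assert out.shape == sample.shape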
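
Finally, the self.mid.resnet_1 = self.middle_block[0] style assignments are plain module aliasing: the old containers keep their original parameter names so existing checkpoints still load, while the new mid block forwards through the very same modules until the weight conversion lands. A toy illustration of why that works (hypothetical layers, not the UNet code):

import torch
from torch import nn

old = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))  # stands in for middle_block

new = nn.Module()       # stands in for the new mid block
new.resnet_1 = old[0]   # aliases, not copies: both names point at the same Parameters
new.resnet_2 = old[2]

with torch.no_grad():
    new.resnet_1.weight.zero_()  # mutate through the new view ...
print(torch.equal(old[0].weight, new.resnet_1.weight))  # True: ... and the old view sees it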