diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index d80ecd88..82079d4b 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -174,9 +174,7 @@ class Downsample(nn.Module):
         # return self.conv(x)
 
 
-# RESNETS
-
-# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
+# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
 class ResnetBlock(nn.Module):
     def __init__(
         self,
@@ -187,15 +185,20 @@ class ResnetBlock(nn.Module):
         dropout=0.0,
         temb_channels=512,
         groups=32,
+        groups_out=None,
         pre_norm=True,
         eps=1e-6,
         non_linearity="swish",
         time_embedding_norm="default",
+        kernel=None,
+        output_scale_factor=1.0,
+        use_nin_shortcut=None,
         up=False,
         down=False,
         overwrite_for_grad_tts=False,
         overwrite_for_ldm=False,
         overwrite_for_glide=False,
+        overwrite_for_score_vde=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm
@@ -206,6 +209,10 @@ class ResnetBlock(nn.Module):
         self.time_embedding_norm = time_embedding_norm
         self.up = up
         self.down = down
+        self.output_scale_factor = output_scale_factor
+
+        if groups_out is None:
+            groups_out = groups
 
         if self.pre_norm:
             self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
@@ -219,7 +226,7 @@ class ResnetBlock(nn.Module):
             elif time_embedding_norm == "scale_shift":
                 self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
 
-        self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+        self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps)
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
@@ -230,14 +237,29 @@ class ResnetBlock(nn.Module):
         elif non_linearity == "silu":
             self.nonlinearity = nn.SiLU()
 
-        if up:
-            self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
-            self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
-        elif down:
-            self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
-            self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+        # if up:
+        #     self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
+        #     self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
+        # elif down:
+        #     self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+        #     self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
 
-        if self.in_channels != self.out_channels:
+        self.upsample = self.downsample = None
+        if self.up and kernel == "fir":
+            fir_kernel = (1, 3, 3, 1)
+            self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
+        elif self.up and kernel is None:
+            self.upsample = Upsample(in_channels, use_conv=False, dims=2)
+        elif self.down and kernel == "fir":
+            fir_kernel = (1, 3, 3, 1)
+            self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
+        elif self.down and kernel is None:
+            self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+
+        self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
+
+        self.nin_shortcut = None
+        if self.use_nin_shortcut:
             self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
         # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
@@ -245,6 +267,7 @@ class ResnetBlock(nn.Module):
         self.overwrite_for_glide = overwrite_for_glide
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
         self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
+        self.overwrite_for_score_vde = overwrite_for_score_vde
         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -260,12 +283,10 @@ class ResnetBlock(nn.Module):
                 self.res_conv = torch.nn.Identity()
         elif self.overwrite_for_ldm:
             dims = 2
-            # eps = 1e-5
-            # non_linearity = "silu"
-            # overwrite_for_ldm
             channels = in_channels
             emb_channels = temb_channels
             use_scale_shift_norm = False
+            non_linearity = "silu"
 
             self.in_layers = nn.Sequential(
                 normalization(channels, swish=1.0),
@@ -289,6 +310,40 @@ class ResnetBlock(nn.Module):
                 self.skip_connection = nn.Identity()
             else:
                 self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+        elif self.overwrite_for_score_vde:
+            in_ch = in_channels
+            out_ch = out_channels
+
+            eps = 1e-6
+            num_groups = min(in_ch // 4, 32)
+            num_groups_out = min(out_ch // 4, 32)
+            temb_dim = temb_channels
+            # output_scale_factor = np.sqrt(2.0)
+            # non_linearity = "silu"
+            # use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
+
+            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
+            self.up = up
+            self.down = down
+            self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
+            if temb_dim is not None:
+                self.Dense_0 = nn.Linear(temb_dim, out_ch)
+                self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
+                nn.init.zeros_(self.Dense_0.bias)
+
+            self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
+            self.Dropout_0 = nn.Dropout(dropout)
+            self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1)
+            if in_ch != out_ch or up or down:
+                # 1x1 convolution with DDPM initialization.
+                self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
+
+            # self.skip_rescale = skip_rescale
+            self.in_ch = in_ch
+            self.out_ch = out_ch
+
+        # TODO(Patrick) - move to main init
+        self.is_overwritten = False
 
     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
@@ -328,6 +383,24 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.weight.data = self.skip_connection.weight.data
             self.nin_shortcut.bias.data = self.skip_connection.bias.data
 
+    def set_weights_score_vde(self):
+        self.conv1.weight.data = self.Conv_0.weight.data
+        self.conv1.bias.data = self.Conv_0.bias.data
+        self.norm1.weight.data = self.GroupNorm_0.weight.data
+        self.norm1.bias.data = self.GroupNorm_0.bias.data
+
+        self.conv2.weight.data = self.Conv_1.weight.data
+        self.conv2.bias.data = self.Conv_1.bias.data
+        self.norm2.weight.data = self.GroupNorm_1.weight.data
+        self.norm2.bias.data = self.GroupNorm_1.bias.data
+
+        self.temb_proj.weight.data = self.Dense_0.weight.data
+        self.temb_proj.bias.data = self.Dense_0.bias.data
+
+        if self.in_channels != self.out_channels or self.up or self.down:
+            self.nin_shortcut.weight.data = self.Conv_2.weight.data
+            self.nin_shortcut.bias.data = self.Conv_2.bias.data
+
     def forward(self, x, temb, mask=1.0):
         # TODO(Patrick) eventually this class should be split into multiple classes
         # too many if else statements
@@ -337,6 +410,9 @@ class ResnetBlock(nn.Module):
         elif self.overwrite_for_ldm and not self.is_overwritten:
             self.set_weights_ldm()
             self.is_overwritten = True
+        elif self.overwrite_for_score_vde and not self.is_overwritten:
+            self.set_weights_score_vde()
+            self.is_overwritten = True
 
         h = x
         h = h * mask
@@ -344,9 +420,12 @@ class ResnetBlock(nn.Module):
             h = self.norm1(h)
             h = self.nonlinearity(h)
 
-        if self.up or self.down:
-            x = self.x_upd(x)
-            h = self.h_upd(h)
+        if self.upsample is not None:
+            x = self.upsample(x)
+            h = self.upsample(h)
+        elif self.downsample is not None:
+            x = self.downsample(x)
+            h = self.downsample(h)
 
         h = self.conv1(h)
@@ -379,10 +458,10 @@ class ResnetBlock(nn.Module):
         h = h * mask
 
         x = x * mask
-        if self.in_channels != self.out_channels:
+        if self.nin_shortcut is not None:
             x = self.nin_shortcut(x)
 
-        return x + h
+        return (x + h) / self.output_scale_factor
 
 
 # TODO(Patrick) - just there to convert the weights; can delete afterward
diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py
index 7e368b87..f4c7f66c 100644
--- a/src/diffusers/models/unet_sde_score_estimation.py
+++ b/src/diffusers/models/unet_sde_score_estimation.py
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample, ResnetBlock, Upsample, downsample_2d, upfirdn2d, upsample_2d
 
 
 def _setup_kernel(k):
@@ -276,8 +276,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
             skip_rescale=skip_rescale,
             continuous=continuous,
         )
-        self.act = act = nn.SiLU()
-
+        self.act = nn.SiLU()
         self.nf = nf
         self.num_res_blocks = num_res_blocks
         self.attn_resolutions = attn_resolutions
@@ -333,19 +332,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
         elif progressive_input == "residual":
             pyramid_downsample = functools.partial(Down_sample, use_conv=True)
 
-        ResnetBlock = functools.partial(
-            ResnetBlockBigGANpp,
-            act=act,
-            dropout=dropout,
-            fir=fir,
-            fir_kernel=fir_kernel,
-            init_scale=init_scale,
-            skip_rescale=skip_rescale,
-            temb_dim=nf * 4,
-        )
-
-        # Downsampling block
-
         channels = num_channels
         if progressive_input != "none":
             input_pyramid_ch = channels
@@ -358,7 +344,18 @@ class NCSNpp(ModelMixin, ConfigMixin):
             # Residual blocks for this resolution
             for i_block in range(num_res_blocks):
                 out_ch = nf * ch_mult[i_level]
-                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+                modules.append(
+                    ResnetBlock(
+                        in_channels=in_ch,
+                        out_channels=out_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                    )
+                )
                 in_ch = out_ch
 
                 if all_resolutions[i_level] in attn_resolutions:
@@ -366,7 +363,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs_c.append(in_ch)
 
             if i_level != self.num_resolutions - 1:
-                modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                modules.append(
+                    ResnetBlock(
+                        in_channels=in_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                        down=True,
+                        kernel="fir",  # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
+                        use_nin_shortcut=True,
+                    )
+                )
 
                 if progressive_input == "input_skip":
                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@@ -380,16 +390,48 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs_c.append(in_ch)
 
         in_ch = hs_c[-1]
-        modules.append(ResnetBlock(in_ch=in_ch))
+        modules.append(
+            ResnetBlock(
+                in_channels=in_ch,
+                temb_channels=4 * nf,
+                output_scale_factor=np.sqrt(2.0),
+                non_linearity="silu",
+                groups=min(in_ch // 4, 32),
+                groups_out=min(out_ch // 4, 32),
+                overwrite_for_score_vde=True,
+            )
+        )
         modules.append(AttnBlock(channels=in_ch))
-        modules.append(ResnetBlock(in_ch=in_ch))
+        modules.append(
+            ResnetBlock(
+                in_channels=in_ch,
+                temb_channels=4 * nf,
+                output_scale_factor=np.sqrt(2.0),
+                non_linearity="silu",
+                groups=min(in_ch // 4, 32),
+                groups_out=min(out_ch // 4, 32),
+                overwrite_for_score_vde=True,
+            )
+        )
 
         pyramid_ch = 0
         # Upsampling block
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(num_res_blocks + 1):
                 out_ch = nf * ch_mult[i_level]
-                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                in_ch = in_ch + hs_c.pop()
+                modules.append(
+                    ResnetBlock(
+                        in_channels=in_ch,
+                        out_channels=out_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                    )
+                )
                 in_ch = out_ch
 
             if all_resolutions[i_level] in attn_resolutions:
@@ -421,7 +463,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     raise ValueError(f"{progressive} is not a valid name")
 
             if i_level != 0:
-                modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                modules.append(
+                    ResnetBlock(
+                        in_channels=in_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                        up=True,
+                        kernel="fir",  # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
+                        use_nin_shortcut=True,
+                    )
+                )
 
         assert not hs_c
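
Since the diff replaces the deleted `functools.partial(ResnetBlockBigGANpp, ...)` with fully spelled-out `ResnetBlock(...)` calls, the same keyword arguments now repeat at five NCSNpp call sites. A small factory could keep those call sites as terse as the old partial; this is a hypothetical sketch, not part of the diff, and `make_score_vde_block` and its `nf` parameter are illustrative names only:

```python
import numpy as np

from diffusers.models.resnet import ResnetBlock


def make_score_vde_block(in_ch, nf, out_ch=None, **kwargs):
    """Hypothetical stand-in for the deleted functools.partial(ResnetBlockBigGANpp, ...).

    Bundles the keyword arguments every NCSNpp call site in this diff repeats.
    When out_ch is omitted the block keeps its input width, and groups_out is
    derived from that width here (an assumption; the diff reuses the enclosing
    loop's out_ch variable in that case).
    """
    out_ch = in_ch if out_ch is None else out_ch
    return ResnetBlock(
        in_channels=in_ch,
        out_channels=out_ch,
        temb_channels=4 * nf,
        output_scale_factor=np.sqrt(2.0),  # pairs with (x + h) / output_scale_factor
        non_linearity="silu",
        groups=min(in_ch // 4, 32),
        groups_out=min(out_ch // 4, 32),
        overwrite_for_score_vde=True,
        **kwargs,  # e.g. down=True, kernel="fir", use_nin_shortcut=True
    )
```

A downsampling call site would then shrink to `make_score_vde_block(in_ch, nf, down=True, kernel="fir", use_nin_shortcut=True)`, matching the expanded form in the diff argument for argument.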
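The `overwrite_for_score_vde` path defers weight porting to the first forward call: `forward` sees the flag, runs `set_weights_score_vde()`, and flips `is_overwritten`. A minimal smoke test for that handshake might look like the sketch below; it assumes the `conv2d` and `variance_scaling` helpers referenced (but not shown) in resnet.py exist at this commit, and the tensor sizes are arbitrary:

```python
import numpy as np
import torch

from diffusers.models.resnet import ResnetBlock

# Hypothetical smoke test, not part of the diff.
# 32 -> 64 channels forces both the nin_shortcut and the score-VDE Conv_2.
block = ResnetBlock(
    in_channels=32,
    out_channels=64,
    temb_channels=128,
    output_scale_factor=np.sqrt(2.0),
    non_linearity="silu",
    groups=min(32 // 4, 32),
    groups_out=min(64 // 4, 32),
    overwrite_for_score_vde=True,
)

x = torch.randn(1, 32, 16, 16)
temb = torch.randn(1, 128)

# The first call triggers set_weights_score_vde() before computing the output.
out = block(x, temb)
assert out.shape == (1, 64, 16, 16)
assert block.is_overwritten

# The ported weights should now alias the original score-VDE parameters.
assert torch.equal(block.conv1.weight, block.Conv_0.weight)
assert torch.equal(block.norm2.weight, block.GroupNorm_1.weight)
assert torch.equal(block.nin_shortcut.weight, block.Conv_2.weight)
```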