From 107986639d7d07d893cfda2492011d3b81022d46 Mon Sep 17 00:00:00 2001
From: Anton Lozhkov
Date: Mon, 4 Jul 2022 14:55:56 +0200
Subject: [PATCH] Fix attention for Glide (#75)

---
 src/diffusers/models/attention.py     | 13 +++++++++----
 src/diffusers/models/unet_grad_tts.py |  2 +-
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 5f86993c..6daca7f0 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -73,6 +73,8 @@ class AttentionBlock(nn.Module):
         self.proj = zero_module(nn.Conv1d(channels, channels, 1))
 
         self.overwrite_qkv = overwrite_qkv
+        self.overwrite_linear = overwrite_linear
+
         if overwrite_qkv:
             in_channels = channels
             self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
@@ -80,9 +82,7 @@ class AttentionBlock(nn.Module):
             self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
-        self.overwrite_linear = overwrite_linear
-        if self.overwrite_linear:
+        elif self.overwrite_linear:
             num_groups = min(channels // 4, 32)
             self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
             self.NIN_0 = NIN(channels, channels)
@@ -91,6 +91,8 @@ class AttentionBlock(nn.Module):
             self.NIN_3 = NIN(channels, channels)
 
             self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
+        else:
+            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 
         self.is_overwritten = False
@@ -120,9 +122,12 @@ class AttentionBlock(nn.Module):
 
             self.norm.weight.data = self.GroupNorm_0.weight.data
             self.norm.bias.data = self.GroupNorm_0.bias.data
+        else:
+            self.proj.weight.data = module.proj_out.weight.data
+            self.proj.bias.data = module.proj_out.bias.data
 
     def forward(self, x, encoder_out=None):
-        if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
+        if not self.is_overwritten:
             self.set_weights(self)
             self.is_overwritten = True

diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py
index 3ec51b09..357d6784 100644
--- a/src/diffusers/models/unet_grad_tts.py
+++ b/src/diffusers/models/unet_grad_tts.py
@@ -133,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                     overwrite_for_grad_tts=True,
                 )
 
-# self.mid = UNetMidBlock2D
+        # self.mid = UNetMidBlock2D
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
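
Note (not part of the patch): the behavioral change is in AttentionBlock.forward, which now calls set_weights on the first forward pass unconditionally rather than only when overwrite_qkv or overwrite_linear is set, so the new default branch can copy the checkpoint's proj_out weights into self.proj for Glide. Below is a minimal standalone sketch of that lazy weight-overwrite pattern; LazyOverwriteBlock is an illustrative stand-in assuming only what the diff shows, not the actual diffusers class.

import torch
import torch.nn as nn


class LazyOverwriteBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Projection actually used at inference time.
        self.proj = nn.Conv1d(channels, channels, 1)
        # Weights loaded from a checkpoint under the legacy attribute name.
        self.proj_out = nn.Conv1d(channels, channels, 1)
        self.is_overwritten = False

    def set_weights(self, module):
        # Mirrors the patch's new default branch: copy the checkpoint's
        # proj_out parameters into the projection that forward() uses.
        self.proj.weight.data = module.proj_out.weight.data
        self.proj.bias.data = module.proj_out.bias.data

    def forward(self, x):
        # After the fix, this runs on the first call regardless of the
        # overwrite_* flags, so the default branch is no longer skipped.
        if not self.is_overwritten:
            self.set_weights(self)
            self.is_overwritten = True
        return self.proj(x)


# Usage sketch: the first forward pass triggers the weight copy.
block = LazyOverwriteBlock(channels=8)
out = block(torch.randn(1, 8, 16))
assert torch.equal(block.proj.weight, block.proj_out.weight)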