Fix attention for Glide (#75)

This commit is contained in:
Anton Lozhkov 2022-07-04 14:55:56 +02:00 committed by GitHub
parent d9316bf8bc
commit 107986639d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 5 deletions

View File

@ -73,6 +73,8 @@ class AttentionBlock(nn.Module):
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
@ -80,9 +82,7 @@ class AttentionBlock(nn.Module):
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.overwrite_linear = overwrite_linear
if self.overwrite_linear:
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
@ -91,6 +91,8 @@ class AttentionBlock(nn.Module):
self.NIN_3 = NIN(channels, channels)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.is_overwritten = False
@ -120,9 +122,12 @@ class AttentionBlock(nn.Module):
self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = module.proj_out.weight.data
self.proj.bias.data = module.proj_out.bias.data
def forward(self, x, encoder_out=None):
if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
if not self.is_overwritten:
self.set_weights(self)
self.is_overwritten = True

View File

@ -133,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
overwrite_for_grad_tts=True,
)
# self.mid = UNetMidBlock2D
# self.mid = UNetMidBlock2D
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append(