This commit is contained in:
Patrick von Platen 2022-06-29 14:35:18 +00:00
parent 466214d2d6
commit c174bcf4bf
2 changed files with 8 additions and 183 deletions

View File

@ -340,49 +340,7 @@ class ResBlock(TimestepBlock):
return self.skip_connection(x) + h
# unet.py
class OLD_ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
# unet.py and unet_grad_tts.py
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, pre_norm=True, eps=1e-6, non_linearity="swish", overwrite_for_grad_tts=False):
super().__init__()
@ -429,11 +387,6 @@ class ResnetBlock(nn.Module):
else:
self.res_conv = torch.nn.Identity()
# num_groups = 8
# self.pre_norm = False
# eps = 1e-5
# non_linearity = "mish"
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data
@ -453,11 +406,6 @@ class ResnetBlock(nn.Module):
self.nin_shortcut.bias.data = self.res_conv.bias.data
def forward(self, x, temb, mask=None):
if not self.pre_norm:
temp = mask
mask = temb
temb = temp
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
@ -500,130 +448,7 @@ class ResnetBlock(nn.Module):
return x + h
# unet_grad_tts.py
class ResnetBlockGradTTS(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8, eps=1e-6, overwrite=True, conv_shortcut=False, pre_norm=True):
super(ResnetBlockGradTTS, self).__init__()
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.pre_norm = pre_norm
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
self.overwrite = overwrite
if self.overwrite:
in_channels = dim
out_channels = dim_out
temb_channels = time_emb_dim
# To set via init
self.pre_norm = False
eps = 1e-5
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
dropout = 0.0
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.nonlinearity = Mish()
self.is_overwritten = False
def set_weights(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data
self.norm1.weight.data = self.block1.block[1].weight.data
self.norm1.bias.data = self.block1.block[1].bias.data
self.conv2.weight.data = self.block2.block[0].weight.data
self.conv2.bias.data = self.block2.block[0].bias.data
self.norm2.weight.data = self.block2.block[1].weight.data
self.norm2.bias.data = self.block2.block[1].bias.data
self.temb_proj.weight.data = self.mlp[1].weight.data
self.temb_proj.bias.data = self.mlp[1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.res_conv.weight.data
self.nin_shortcut.bias.data = self.res_conv.bias.data
def forward(self, x, mask, time_emb):
h = self.block1(x, mask)
h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
h = self.block2(h, mask)
output = h + self.res_conv(x * mask)
output = self.forward_2(x, time_emb, mask=mask)
return output
def forward_2(self, x, temb, mask=None):
if not self.is_overwritten:
self.set_weights()
self.is_overwritten = True
if mask is None:
mask = torch.ones_like(x)
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
# TODO(Patrick) - just there to convert the weights; can delete afterward
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()

View File

@ -135,8 +135,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
masks = [mask]
for resnet1, resnet2, attn, downsample in self.downs:
mask_down = masks[-1]
x = resnet1(x, mask_down, t)
x = resnet2(x, mask_down, t)
x = resnet1(x, t, mask_down)
x = resnet2(x, t, mask_down)
x = attn(x)
hiddens.append(x)
x = downsample(x * mask_down)
@ -144,15 +144,15 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
masks = masks[:-1]
mask_mid = masks[-1]
x = self.mid_block1(x, mask_mid, t)
x = self.mid_block1(x, t, mask_mid)
x = self.mid_attn(x)
x = self.mid_block2(x, mask_mid, t)
x = self.mid_block2(x, t, mask_mid)
for resnet1, resnet2, attn, upsample in self.ups:
mask_up = masks.pop()
x = torch.cat((x, hiddens.pop()), dim=1)
x = resnet1(x, mask_up, t)
x = resnet2(x, mask_up, t)
x = resnet1(x, t, mask_up)
x = resnet2(x, t, mask_up)
x = attn(x)
x = upsample(x * mask_up)