Add MidBlock to Grad-TTS (#74)

Finish
This commit is contained in:
Patrick von Platen 2022-07-04 15:06:00 +02:00 committed by GitHub
parent 107986639d
commit c352faeae3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 18 deletions

View File

@ -5,6 +5,7 @@ from ..modeling_utils import ModelMixin
from .attention import LinearAttention from .attention import LinearAttention
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D
class Mish(torch.nn.Module): class Mish(torch.nn.Module):
@ -111,6 +112,17 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
) )
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid = UNetMidBlock2D(
in_channels=mid_dim,
temb_channels=dim,
resnet_groups=8,
resnet_pre_norm=False,
resnet_eps=1e-5,
resnet_act_fn="mish",
attention_layer_type="linear",
)
self.mid_block1 = ResnetBlock2D( self.mid_block1 = ResnetBlock2D(
in_channels=mid_dim, in_channels=mid_dim,
out_channels=mid_dim, out_channels=mid_dim,
@ -132,8 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity="mish", non_linearity="mish",
overwrite_for_grad_tts=True, overwrite_for_grad_tts=True,
) )
self.mid.resnet_1 = self.mid_block1
# self.mid = UNetMidBlock2D self.mid.attn = self.mid_attn
self.mid.resnet_2 = self.mid_block2
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append( self.ups.append(
@ -198,9 +211,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
masks = masks[:-1] masks = masks[:-1]
mask_mid = masks[-1] mask_mid = masks[-1]
x = self.mid_block1(x, t, mask_mid)
x = self.mid_attn(x) x = self.mid(x, t, mask=mask_mid)
x = self.mid_block2(x, t, mask_mid)
for resnet1, resnet2, attn, upsample in self.ups: for resnet1, resnet2, attn, upsample in self.ups:
mask_up = masks.pop() mask_up = masks.pop()

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from torch import nn from torch import nn
from .attention import AttentionBlock, SpatialTransformer from .attention import AttentionBlock, LinearAttention, SpatialTransformer
from .resnet import ResnetBlock2D from .resnet import ResnetBlock2D
@ -23,11 +23,12 @@ class UNetMidBlock2D(nn.Module):
self, self,
in_channels: int, in_channels: int,
temb_channels: int, temb_channels: int,
dropout: float, dropout: float = 0.0,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attention_layer_type: str = "self", attention_layer_type: str = "self",
attn_num_heads=1, attn_num_heads=1,
attn_num_head_channels=None, attn_num_head_channels=None,
@ -49,6 +50,7 @@ class UNetMidBlock2D(nn.Module):
time_embedding_norm=resnet_time_scale_shift, time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
) )
if attention_layer_type == "self": if attention_layer_type == "self":
@ -61,15 +63,14 @@ class UNetMidBlock2D(nn.Module):
rescale_output_factor=output_scale_factor, rescale_output_factor=output_scale_factor,
) )
elif attention_layer_type == "spatial": elif attention_layer_type == "spatial":
self.attn = ( self.attn = SpatialTransformer(
SpatialTransformer( attn_num_heads,
in_channels, attn_num_head_channels,
attn_num_heads, depth=attn_depth,
attn_num_head_channels, context_dim=attn_encoder_channels,
depth=attn_depth,
context_dim=attn_encoder_channels,
),
) )
elif attention_layer_type == "linear":
self.attn = LinearAttention(in_channels)
self.resnet_2 = ResnetBlock2D( self.resnet_2 = ResnetBlock2D(
in_channels=in_channels, in_channels=in_channels,
@ -80,6 +81,7 @@ class UNetMidBlock2D(nn.Module):
time_embedding_norm=resnet_time_scale_shift, time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
) )
# TODO(Patrick) - delete all of the following code # TODO(Patrick) - delete all of the following code
@ -110,19 +112,20 @@ class UNetMidBlock2D(nn.Module):
eps=resnet_eps, eps=resnet_eps,
) )
def forward(self, hidden_states, temb=None, encoder_states=None): def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
if not self.is_overwritten and self.overwrite_unet: if not self.is_overwritten and self.overwrite_unet:
self.resnet_1 = self.block_1 self.resnet_1 = self.block_1
self.attn = self.attn_1 self.attn = self.attn_1
self.resnet_2 = self.block_2 self.resnet_2 = self.block_2
self.is_overwritten = True self.is_overwritten = True
hidden_states = self.resnet_1(hidden_states, temb) hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
if encoder_states is None: if encoder_states is None:
hidden_states = self.attn(hidden_states) hidden_states = self.attn(hidden_states)
else: else:
hidden_states = self.attn(hidden_states, encoder_states) hidden_states = self.attn(hidden_states, encoder_states)
hidden_states = self.resnet_2(hidden_states, temb) hidden_states = self.resnet_2(hidden_states, temb, mask=mask)
return hidden_states return hidden_states