Add MidBlock to Grad-TTS (#74)

Finish
Patrick von Platen 2022-07-04 15:06:00 +02:00 committed by GitHub
parent 107986639d
commit c352faeae3
2 changed files with 33 additions and 18 deletions

Changed file 1 of 2 (contains UNetGradTTSModel)

@@ -5,6 +5,7 @@ from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D
class Mish(torch.nn.Module):
@@ -111,6 +112,17 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
)
mid_dim = dims[-1]
self.mid = UNetMidBlock2D(
in_channels=mid_dim,
temb_channels=dim,
resnet_groups=8,
resnet_pre_norm=False,
resnet_eps=1e-5,
resnet_act_fn="mish",
attention_layer_type="linear",
)
self.mid_block1 = ResnetBlock2D(
in_channels=mid_dim,
out_channels=mid_dim,
@@ -132,8 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity="mish",
overwrite_for_grad_tts=True,
)
# self.mid = UNetMidBlock2D
self.mid.resnet_1 = self.mid_block1
self.mid.attn = self.mid_attn
self.mid.resnet_2 = self.mid_block2
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append(
@@ -198,9 +211,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
masks = masks[:-1]
mask_mid = masks[-1]
x = self.mid_block1(x, t, mask_mid)
x = self.mid_attn(x)
x = self.mid_block2(x, t, mask_mid)
x = self.mid(x, t, mask=mask_mid)
for resnet1, resnet2, attn, upsample in self.ups:
mask_up = masks.pop()
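
Taken together, the Grad-TTS changes build a UNetMidBlock2D whose resnet_1 / attn / resnet_2 children are then pointed at the already-constructed mid_block1, mid_attn and mid_block2, so the single self.mid(x, t, mask=mask_mid) call above runs exactly the same parameters as the old three-step path. A minimal, self-contained PyTorch toy of that aliasing pattern (Mid and Model below are made-up stand-ins, not diffusers classes):

import torch
from torch import nn

class Mid(nn.Module):
    # Toy stand-in for UNetMidBlock2D: resnet_1 -> attn -> resnet_2.
    def __init__(self, dim):
        super().__init__()
        self.resnet_1 = nn.Linear(dim, dim)
        self.attn = nn.Identity()
        self.resnet_2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.resnet_2(self.attn(self.resnet_1(x)))

class Model(nn.Module):
    # Toy stand-in for the mid section of UNetGradTTSModel.
    def __init__(self, dim=8):
        super().__init__()
        # sub-modules built first, as in the existing Grad-TTS __init__
        self.mid_block1 = nn.Linear(dim, dim)
        self.mid_attn = nn.Identity()
        self.mid_block2 = nn.Linear(dim, dim)
        # aggregated block; its children are replaced by the modules above,
        # so they alias the same parameters rather than copying them
        self.mid = Mid(dim)
        self.mid.resnet_1 = self.mid_block1
        self.mid.attn = self.mid_attn
        self.mid.resnet_2 = self.mid_block2

model = Model()
x = torch.randn(2, 8)
old_path = model.mid_block2(model.mid_attn(model.mid_block1(x)))  # previous three calls
new_path = model.mid(x)                                           # single mid-block call
assert torch.allclose(old_path, new_path)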

Changed file 2 of 2 (contains UNetMidBlock2D, imported above as .unet_new)

@@ -14,7 +14,7 @@
# limitations under the License.
from torch import nn
from .attention import AttentionBlock, SpatialTransformer
from .attention import AttentionBlock, LinearAttention, SpatialTransformer
from .resnet import ResnetBlock2D
@@ -23,11 +23,12 @@ class UNetMidBlock2D(nn.Module):
self,
in_channels: int,
temb_channels: int,
dropout: float,
dropout: float = 0.0,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attention_layer_type: str = "self",
attn_num_heads=1,
attn_num_head_channels=None,
@@ -49,6 +50,7 @@ class UNetMidBlock2D(nn.Module):
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
if attention_layer_type == "self":
@@ -61,15 +63,14 @@ class UNetMidBlock2D(nn.Module):
rescale_output_factor=output_scale_factor,
)
elif attention_layer_type == "spatial":
self.attn = (
SpatialTransformer(
in_channels,
attn_num_heads,
attn_num_head_channels,
depth=attn_depth,
context_dim=attn_encoder_channels,
),
self.attn = SpatialTransformer(
in_channels,
attn_num_heads,
attn_num_head_channels,
depth=attn_depth,
context_dim=attn_encoder_channels,
)
elif attention_layer_type == "linear":
self.attn = LinearAttention(in_channels)
self.resnet_2 = ResnetBlock2D(
in_channels=in_channels,
@@ -80,6 +81,7 @@ class UNetMidBlock2D(nn.Module):
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
# TODO(Patrick) - delete all of the following code
@@ -110,19 +112,20 @@ class UNetMidBlock2D(nn.Module):
eps=resnet_eps,
)
def forward(self, hidden_states, temb=None, encoder_states=None):
def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
if not self.is_overwritten and self.overwrite_unet:
self.resnet_1 = self.block_1
self.attn = self.attn_1
self.resnet_2 = self.block_2
self.is_overwritten = True
hidden_states = self.resnet_1(hidden_states, temb)
hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
if encoder_states is None:
hidden_states = self.attn(hidden_states)
else:
hidden_states = self.attn(hidden_states, encoder_states)
hidden_states = self.resnet_2(hidden_states, temb)
hidden_states = self.resnet_2(hidden_states, temb, mask=mask)
return hidden_states
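
The forward changes above are what let Grad-TTS route its sequence mask through the shared block: mask defaults to the Python float 1.0, so existing callers that never pass a mask are untouched, while Grad-TTS hands a tensor through to both resnets. A short sketch of why the 1.0 default is safe, assuming the resnet applies the mask multiplicatively as in the original Grad-TTS blocks; the (batch, 1, 1, frames) mask shape is illustrative only:

import torch

h = torch.randn(2, 4, 8, 10)    # e.g. (batch, channels, feats, frames)
mask = torch.ones(2, 1, 1, 10)
mask[..., 7:] = 0               # pretend the last three frames are padding

assert torch.equal(h * 1.0, h)  # default mask=1.0 is a no-op multiplier
masked = h * mask               # a tensor mask zeroes the padded frames
assert torch.equal(masked[..., 7:], torch.zeros_like(masked[..., 7:]))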