parent 107986639d
commit c352faeae3
@@ -5,6 +5,7 @@ from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 class Mish(torch.nn.Module):
@@ -111,6 +112,17 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         )

         mid_dim = dims[-1]
+
+        self.mid = UNetMidBlock2D(
+            in_channels=mid_dim,
+            temb_channels=dim,
+            resnet_groups=8,
+            resnet_pre_norm=False,
+            resnet_eps=1e-5,
+            resnet_act_fn="mish",
+            attention_layer_type="linear",
+        )
+
         self.mid_block1 = ResnetBlock2D(
             in_channels=mid_dim,
             out_channels=mid_dim,
@@ -132,8 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             non_linearity="mish",
             overwrite_for_grad_tts=True,
         )
-
-        # self.mid = UNetMidBlock2D
+        self.mid.resnet_1 = self.mid_block1
+        self.mid.attn = self.mid_attn
+        self.mid.resnet_2 = self.mid_block2

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
@@ -198,9 +211,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):

         masks = masks[:-1]
         mask_mid = masks[-1]
-        x = self.mid_block1(x, t, mask_mid)
-        x = self.mid_attn(x)
-        x = self.mid_block2(x, t, mask_mid)
+
+        x = self.mid(x, t, mask=mask_mid)

         for resnet1, resnet2, attn, upsample in self.ups:
             mask_up = masks.pop()
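Net effect of the UNetGradTTSModel changes above, as a minimal runnable sketch (editor's illustration; MidWrapper and ToyResnet are made-up stand-ins, not diffusers classes): grafting the externally built blocks onto the wrapper makes the single self.mid(x, t, mask=mask_mid) call equivalent to the three calls it replaces.

    import torch
    from torch import nn

    class MidWrapper(nn.Module):
        # Stand-in for UNetMidBlock2D after this commit: resnet_1/attn/resnet_2
        # are assigned from outside, and forward() threads the mask through
        # both resnet calls.
        def forward(self, hidden_states, temb=None, mask=1.0):
            hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
            hidden_states = self.attn(hidden_states)
            hidden_states = self.resnet_2(hidden_states, temb, mask=mask)
            return hidden_states

    class ToyResnet(nn.Module):
        # Mimics only the (hidden_states, temb, mask) call shape of ResnetBlock2D.
        def forward(self, x, temb=None, mask=1.0):
            return x * mask

    mid = MidWrapper()
    mid.resnet_1 = ToyResnet()  # mirrors: self.mid.resnet_1 = self.mid_block1
    mid.attn = nn.Identity()    # mirrors: self.mid.attn = self.mid_attn
    mid.resnet_2 = ToyResnet()  # mirrors: self.mid.resnet_2 = self.mid_block2

    x = torch.randn(1, 4, 8, 8)
    mask = torch.ones(1, 1, 8, 8)
    print(mid(x, temb=None, mask=mask).shape)  # torch.Size([1, 4, 8, 8])

The second diff below is against unet_new.py, the module the new import above pulls UNetMidBlock2D from.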
@@ -14,7 +14,7 @@
 # limitations under the License.
 from torch import nn

-from .attention import AttentionBlock, SpatialTransformer
+from .attention import AttentionBlock, LinearAttention, SpatialTransformer
 from .resnet import ResnetBlock2D


@@ -23,11 +23,12 @@ class UNetMidBlock2D(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
-        dropout: float,
+        dropout: float = 0.0,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
         attention_layer_type: str = "self",
         attn_num_heads=1,
         attn_num_head_channels=None,
@@ -49,6 +50,7 @@ class UNetMidBlock2D(nn.Module):
             time_embedding_norm=resnet_time_scale_shift,
             non_linearity=resnet_act_fn,
             output_scale_factor=output_scale_factor,
+            pre_norm=resnet_pre_norm,
         )

         if attention_layer_type == "self":
@@ -61,15 +63,14 @@ class UNetMidBlock2D(nn.Module):
                 rescale_output_factor=output_scale_factor,
             )
         elif attention_layer_type == "spatial":
-            self.attn = (
-                SpatialTransformer(
-                    in_channels,
-                    attn_num_heads,
-                    attn_num_head_channels,
-                    depth=attn_depth,
-                    context_dim=attn_encoder_channels,
-                ),
-            )
+            self.attn = SpatialTransformer(
+                in_channels,
+                attn_num_heads,
+                attn_num_head_channels,
+                depth=attn_depth,
+                context_dim=attn_encoder_channels,
+            )
+        elif attention_layer_type == "linear":
+            self.attn = LinearAttention(in_channels)

         self.resnet_2 = ResnetBlock2D(
             in_channels=in_channels,
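Worth noting on the hunk above: the removed form assigned a one-element tuple, because of the trailing comma inside the outer parentheses, so self.attn was neither callable nor registered as a submodule. A minimal repro (Holder is an illustrative class, plain torch):

    import torch
    from torch import nn

    class Holder(nn.Module):
        def __init__(self):
            super().__init__()
            # Old form: the trailing comma inside the parentheses builds a tuple.
            self.attn = (
                nn.Identity(),
            )

    holder = Holder()
    print(type(holder.attn))             # <class 'tuple'>
    print(dict(holder.named_modules()))  # the Identity is NOT registered
    # holder.attn(torch.zeros(1))        # TypeError: 'tuple' object is not callable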
@@ -80,6 +81,7 @@ class UNetMidBlock2D(nn.Module):
             time_embedding_norm=resnet_time_scale_shift,
             non_linearity=resnet_act_fn,
             output_scale_factor=output_scale_factor,
+            pre_norm=resnet_pre_norm,
         )

         # TODO(Patrick) - delete all of the following code
@@ -110,19 +112,20 @@ class UNetMidBlock2D(nn.Module):
             eps=resnet_eps,
         )

-    def forward(self, hidden_states, temb=None, encoder_states=None):
+    def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
         if not self.is_overwritten and self.overwrite_unet:
             self.resnet_1 = self.block_1
             self.attn = self.attn_1
             self.resnet_2 = self.block_2
             self.is_overwritten = True

-        hidden_states = self.resnet_1(hidden_states, temb)
+        hidden_states = self.resnet_1(hidden_states, temb, mask=mask)

         if encoder_states is None:
             hidden_states = self.attn(hidden_states)
         else:
             hidden_states = self.attn(hidden_states, encoder_states)

-        hidden_states = self.resnet_2(hidden_states, temb)
+        hidden_states = self.resnet_2(hidden_states, temb, mask=mask)

         return hidden_states
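On the new mask=1.0 default in forward: a scalar 1.0 multiplies through as a no-op, so existing callers that pass no mask are unaffected, while Grad-TTS can hand in its sequence mask and have padded positions zeroed wherever the resnets multiply by it. A small illustration (shapes are made up):

    import torch

    x = torch.randn(2, 4, 1, 6)            # toy activations: (batch, channels, 1, time)
    seq_mask = torch.ones(2, 1, 1, 6)
    seq_mask[1, ..., 4:] = 0.0              # pretend sample 1 is padded past step 4

    assert torch.equal(x * 1.0, x)          # the scalar default is a no-op
    masked = x * seq_mask                   # a tensor mask zeroes padded positions
    print(masked[1, ..., 4:].abs().sum())   # tensor(0.)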