parent
107986639d
commit
c352faeae3
|
@ -5,6 +5,7 @@ from ..modeling_utils import ModelMixin
|
||||||
from .attention import LinearAttention
|
from .attention import LinearAttention
|
||||||
from .embeddings import get_timestep_embedding
|
from .embeddings import get_timestep_embedding
|
||||||
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||||
|
from .unet_new import UNetMidBlock2D
|
||||||
|
|
||||||
|
|
||||||
class Mish(torch.nn.Module):
|
class Mish(torch.nn.Module):
|
||||||
|
@ -111,6 +112,17 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
mid_dim = dims[-1]
|
mid_dim = dims[-1]
|
||||||
|
|
||||||
|
self.mid = UNetMidBlock2D(
|
||||||
|
in_channels=mid_dim,
|
||||||
|
temb_channels=dim,
|
||||||
|
resnet_groups=8,
|
||||||
|
resnet_pre_norm=False,
|
||||||
|
resnet_eps=1e-5,
|
||||||
|
resnet_act_fn="mish",
|
||||||
|
attention_layer_type="linear",
|
||||||
|
)
|
||||||
|
|
||||||
self.mid_block1 = ResnetBlock2D(
|
self.mid_block1 = ResnetBlock2D(
|
||||||
in_channels=mid_dim,
|
in_channels=mid_dim,
|
||||||
out_channels=mid_dim,
|
out_channels=mid_dim,
|
||||||
|
@ -132,8 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
||||||
non_linearity="mish",
|
non_linearity="mish",
|
||||||
overwrite_for_grad_tts=True,
|
overwrite_for_grad_tts=True,
|
||||||
)
|
)
|
||||||
|
self.mid.resnet_1 = self.mid_block1
|
||||||
# self.mid = UNetMidBlock2D
|
self.mid.attn = self.mid_attn
|
||||||
|
self.mid.resnet_2 = self.mid_block2
|
||||||
|
|
||||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||||
self.ups.append(
|
self.ups.append(
|
||||||
|
@ -198,9 +211,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
||||||
|
|
||||||
masks = masks[:-1]
|
masks = masks[:-1]
|
||||||
mask_mid = masks[-1]
|
mask_mid = masks[-1]
|
||||||
x = self.mid_block1(x, t, mask_mid)
|
|
||||||
x = self.mid_attn(x)
|
x = self.mid(x, t, mask=mask_mid)
|
||||||
x = self.mid_block2(x, t, mask_mid)
|
|
||||||
|
|
||||||
for resnet1, resnet2, attn, upsample in self.ups:
|
for resnet1, resnet2, attn, upsample in self.ups:
|
||||||
mask_up = masks.pop()
|
mask_up = masks.pop()
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .attention import AttentionBlock, SpatialTransformer
|
from .attention import AttentionBlock, LinearAttention, SpatialTransformer
|
||||||
from .resnet import ResnetBlock2D
|
from .resnet import ResnetBlock2D
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,11 +23,12 @@ class UNetMidBlock2D(nn.Module):
|
||||||
self,
|
self,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
temb_channels: int,
|
temb_channels: int,
|
||||||
dropout: float,
|
dropout: float = 0.0,
|
||||||
resnet_eps: float = 1e-6,
|
resnet_eps: float = 1e-6,
|
||||||
resnet_time_scale_shift: str = "default",
|
resnet_time_scale_shift: str = "default",
|
||||||
resnet_act_fn: str = "swish",
|
resnet_act_fn: str = "swish",
|
||||||
resnet_groups: int = 32,
|
resnet_groups: int = 32,
|
||||||
|
resnet_pre_norm: bool = True,
|
||||||
attention_layer_type: str = "self",
|
attention_layer_type: str = "self",
|
||||||
attn_num_heads=1,
|
attn_num_heads=1,
|
||||||
attn_num_head_channels=None,
|
attn_num_head_channels=None,
|
||||||
|
@ -49,6 +50,7 @@ class UNetMidBlock2D(nn.Module):
|
||||||
time_embedding_norm=resnet_time_scale_shift,
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
non_linearity=resnet_act_fn,
|
non_linearity=resnet_act_fn,
|
||||||
output_scale_factor=output_scale_factor,
|
output_scale_factor=output_scale_factor,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
)
|
)
|
||||||
|
|
||||||
if attention_layer_type == "self":
|
if attention_layer_type == "self":
|
||||||
|
@ -61,15 +63,14 @@ class UNetMidBlock2D(nn.Module):
|
||||||
rescale_output_factor=output_scale_factor,
|
rescale_output_factor=output_scale_factor,
|
||||||
)
|
)
|
||||||
elif attention_layer_type == "spatial":
|
elif attention_layer_type == "spatial":
|
||||||
self.attn = (
|
self.attn = SpatialTransformer(
|
||||||
SpatialTransformer(
|
attn_num_heads,
|
||||||
in_channels,
|
attn_num_head_channels,
|
||||||
attn_num_heads,
|
depth=attn_depth,
|
||||||
attn_num_head_channels,
|
context_dim=attn_encoder_channels,
|
||||||
depth=attn_depth,
|
|
||||||
context_dim=attn_encoder_channels,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
elif attention_layer_type == "linear":
|
||||||
|
self.attn = LinearAttention(in_channels)
|
||||||
|
|
||||||
self.resnet_2 = ResnetBlock2D(
|
self.resnet_2 = ResnetBlock2D(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
|
@ -80,6 +81,7 @@ class UNetMidBlock2D(nn.Module):
|
||||||
time_embedding_norm=resnet_time_scale_shift,
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
non_linearity=resnet_act_fn,
|
non_linearity=resnet_act_fn,
|
||||||
output_scale_factor=output_scale_factor,
|
output_scale_factor=output_scale_factor,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(Patrick) - delete all of the following code
|
# TODO(Patrick) - delete all of the following code
|
||||||
|
@ -110,19 +112,20 @@ class UNetMidBlock2D(nn.Module):
|
||||||
eps=resnet_eps,
|
eps=resnet_eps,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states, temb=None, encoder_states=None):
|
def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
|
||||||
if not self.is_overwritten and self.overwrite_unet:
|
if not self.is_overwritten and self.overwrite_unet:
|
||||||
self.resnet_1 = self.block_1
|
self.resnet_1 = self.block_1
|
||||||
self.attn = self.attn_1
|
self.attn = self.attn_1
|
||||||
self.resnet_2 = self.block_2
|
self.resnet_2 = self.block_2
|
||||||
self.is_overwritten = True
|
self.is_overwritten = True
|
||||||
|
|
||||||
hidden_states = self.resnet_1(hidden_states, temb)
|
hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
|
||||||
|
|
||||||
if encoder_states is None:
|
if encoder_states is None:
|
||||||
hidden_states = self.attn(hidden_states)
|
hidden_states = self.attn(hidden_states)
|
||||||
else:
|
else:
|
||||||
hidden_states = self.attn(hidden_states, encoder_states)
|
hidden_states = self.attn(hidden_states, encoder_states)
|
||||||
|
|
||||||
hidden_states = self.resnet_2(hidden_states, temb)
|
hidden_states = self.resnet_2(hidden_states, temb, mask=mask)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
Loading…
Reference in New Issue