Merge pull request #44 from huggingface/unify_resnet
Unify resnet [GradTTS & Unet.py]
commit eb90d3be13
resnet.py:

@@ -46,8 +46,8 @@ def conv_transpose_nd(dims, *args, **kwargs):
     raise ValueError(f"unsupported dimensions: {dims}")


-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+def Normalize(in_channels, num_groups=32, eps=1e-6):
+    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)


 def nonlinearity(x, swish=1.0):
@@ -166,8 +166,8 @@ class Downsample(nn.Module):
 #
 # class GlideUpsample(nn.Module):
 # """
-# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
-# use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
+# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param #
+# use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
 # 3D, then # upsampling occurs in the inner-two dimensions. #"""
 #
 # def __init__(self, channels, use_conv, dims=2, out_channels=None):
@@ -192,8 +192,8 @@ class Downsample(nn.Module):
 #
 # class LDMUpsample(nn.Module):
 # """
-# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
-# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
+# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # #
+# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
 # 3D, then # upsampling occurs in the inner-two dimensions. #"""
 #
 # def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
@@ -340,40 +340,118 @@ class ResBlock(TimestepBlock):
         return self.skip_connection(x) + h


-# unet.py
+# unet.py and unet_grad_tts.py
 class ResnetBlock(nn.Module):
-    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
+    def __init__(
+        self,
+        *,
+        in_channels,
+        out_channels=None,
+        conv_shortcut=False,
+        dropout=0.0,
+        temb_channels=512,
+        groups=32,
+        pre_norm=True,
+        eps=1e-6,
+        non_linearity="swish",
+        overwrite_for_grad_tts=False,
+    ):
         super().__init__()
+        self.pre_norm = pre_norm
         self.in_channels = in_channels
         out_channels = in_channels if out_channels is None else out_channels
         self.out_channels = out_channels
         self.use_conv_shortcut = conv_shortcut

-        self.norm1 = Normalize(in_channels)
+        if self.pre_norm:
+            self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+        else:
+            self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
+
         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-        self.norm2 = Normalize(out_channels)
+        self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        if non_linearity == "swish":
+            self.nonlinearity = nonlinearity
+        elif non_linearity == "mish":
+            self.nonlinearity = Mish()

         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
             else:
                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

-    def forward(self, x, temb):
+        self.is_overwritten = False
+        self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        if self.overwrite_for_grad_tts:
+            dim = in_channels
+            dim_out = out_channels
+            time_emb_dim = temb_channels
+            self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+            self.pre_norm = pre_norm
+
+            self.block1 = Block(dim, dim_out, groups=groups)
+            self.block2 = Block(dim_out, dim_out, groups=groups)
+            if dim != dim_out:
+                self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
+            else:
+                self.res_conv = torch.nn.Identity()
+
+    def set_weights_grad_tts(self):
+        self.conv1.weight.data = self.block1.block[0].weight.data
+        self.conv1.bias.data = self.block1.block[0].bias.data
+        self.norm1.weight.data = self.block1.block[1].weight.data
+        self.norm1.bias.data = self.block1.block[1].bias.data
+
+        self.conv2.weight.data = self.block2.block[0].weight.data
+        self.conv2.bias.data = self.block2.block[0].bias.data
+        self.norm2.weight.data = self.block2.block[1].weight.data
+        self.norm2.bias.data = self.block2.block[1].bias.data
+
+        self.temb_proj.weight.data = self.mlp[1].weight.data
+        self.temb_proj.bias.data = self.mlp[1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.res_conv.weight.data
+            self.nin_shortcut.bias.data = self.res_conv.bias.data
+
+    def forward(self, x, temb, mask=None):
+        if self.overwrite_for_grad_tts and not self.is_overwritten:
+            self.set_weights_grad_tts()
+            self.is_overwritten = True
+
         h = x
-        h = self.norm1(h)
-        h = nonlinearity(h)
+        h = h * mask if mask is not None else h
+        if self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
         h = self.conv1(h)

-        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+        if not self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+        h = h * mask if mask is not None else h
+
+        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

-        h = self.norm2(h)
-        h = nonlinearity(h)
+        h = h * mask if mask is not None else h
+        if self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
         h = self.dropout(h)
         h = self.conv2(h)

+        if not self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+        h = h * mask if mask is not None else h
+
+        x = x * mask if mask is not None else x
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 x = self.conv_shortcut(x)
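For context: the unified block keeps the DDPM/LDM behaviour by default (pre-norm, swish, 32 groups, eps=1e-6) and reproduces the GradTTS block when built with post-norm, Mish, 8 groups and overwrite_for_grad_tts=True, in which case the first forward call mirrors the GradTTS Block/MLP weights into conv1/norm1/conv2/norm2/temb_proj. A minimal usage sketch follows; the import path, channel sizes and tensor shapes are assumptions for illustration, not part of the diff.

import torch

# Assumed import path; the diff only shows the class, not where it is exported from.
from diffusers.models.resnet import ResnetBlock

# DDPM/LDM-style block: defaults (pre_norm=True, swish, 32 groups, eps=1e-6).
ddpm_block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.1, temb_channels=512)
h = ddpm_block(torch.randn(2, 64, 32, 32), torch.randn(2, 512))

# GradTTS-style block: post-norm, Mish, 8 groups, weights mirrored from the Block modules.
tts_block = ResnetBlock(
    in_channels=64,
    out_channels=64,
    temb_channels=64,
    groups=8,
    pre_norm=False,
    eps=1e-5,
    non_linearity="mish",
    overwrite_for_grad_tts=True,
)
x = torch.randn(2, 64, 80, 172)    # (batch, channels, mel bins, frames) -- illustrative shapes
mask = torch.ones(2, 1, 80, 172)   # GradTTS multiplies activations by a sequence mask
out = tts_block(x, torch.randn(2, 64), mask)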
@@ -383,58 +461,17 @@ class ResnetBlock(nn.Module):
         return x + h


-# unet_grad_tts.py
-class ResnetBlockGradTTS(torch.nn.Module):
-    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
-        super(ResnetBlockGradTTS, self).__init__()
-        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
-        self.block1 = Block(dim, dim_out, groups=groups)
-        self.block2 = Block(dim_out, dim_out, groups=groups)
-        if dim != dim_out:
-            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
-        else:
-            self.res_conv = torch.nn.Identity()
-
-    def forward(self, x, mask, time_emb):
-        h = self.block1(x, mask)
-        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
-        h = self.block2(h, mask)
-        output = h + self.res_conv(x * mask)
-        return output
-
-
-# unet_rl.py
-class ResidualTemporalBlock(nn.Module):
-    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
-        super().__init__()
-
-        self.blocks = nn.ModuleList(
-            [
-                Conv1dBlock(inp_channels, out_channels, kernel_size),
-                Conv1dBlock(out_channels, out_channels, kernel_size),
-            ]
-        )
-
-        self.time_mlp = nn.Sequential(
-            nn.Mish(),
-            nn.Linear(embed_dim, out_channels),
-            RearrangeDim(),
-            # Rearrange("batch t -> batch t 1"),
-        )
-
-        self.residual_conv = (
-            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
-        )
-
-    def forward(self, x, t):
-        """
-        x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
-        out_channels x horizon ]
-        """
-        out = self.blocks[0](x) + self.time_mlp(t)
-        out = self.blocks[1](out)
-        return out + self.residual_conv(x)
+# TODO(Patrick) - just there to convert the weights; can delete afterward
+class Block(torch.nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super(Block, self).__init__()
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
+        )
+
+    def forward(self, x, mask):
+        output = self.block(x * mask)
+        return output * mask


 # unet_score_estimation.py
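The set_weights_grad_tts indexing above (block[0], block[1]) relies on the layout of this conversion helper: index 0 is the Conv2d and index 1 the GroupNorm inside the Sequential. A small standalone sketch of that layout, with made-up channel sizes:

import torch

class Mish(torch.nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))

# Same Sequential layout as Block.block in the diff above.
block = torch.nn.Sequential(
    torch.nn.Conv2d(64, 64, 3, padding=1),   # block[0] -> copied into ResnetBlock.conv1
    torch.nn.GroupNorm(8, 64),               # block[1] -> copied into ResnetBlock.norm1
    Mish(),
)
print(block[0].weight.shape, block[1].weight.shape)  # torch.Size([64, 64, 3, 3]) torch.Size([64])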
@@ -570,6 +607,39 @@ class ResnetBlockDDPMpp(nn.Module):
         return (x + h) / np.sqrt(2.0)


+# unet_rl.py
+class ResidualTemporalBlock(nn.Module):
+    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
+        super().__init__()
+
+        self.blocks = nn.ModuleList(
+            [
+                Conv1dBlock(inp_channels, out_channels, kernel_size),
+                Conv1dBlock(out_channels, out_channels, kernel_size),
+            ]
+        )
+
+        self.time_mlp = nn.Sequential(
+            nn.Mish(),
+            nn.Linear(embed_dim, out_channels),
+            RearrangeDim(),
+            # Rearrange("batch t -> batch t 1"),
+        )
+
+        self.residual_conv = (
+            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+        )
+
+    def forward(self, x, t):
+        """
+        x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
+        out_channels x horizon ]
+        """
+        out = self.blocks[0](x) + self.time_mlp(t)
+        out = self.blocks[1](out)
+        return out + self.residual_conv(x)
+
+
 # HELPER Modules
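The docstring of the relocated ResidualTemporalBlock fixes its shape contract: x is [batch_size, inp_channels, horizon] and t is [batch_size, embed_dim]. A small sketch of the time-embedding broadcast that the RearrangeDim() / "batch t -> batch t 1" step performs; all sizes here are made up for illustration:

import torch

batch_size, inp_channels, out_channels, horizon, embed_dim = 4, 14, 32, 128, 64

x = torch.randn(batch_size, inp_channels, horizon)  # [batch_size, inp_channels, horizon]
t = torch.randn(batch_size, embed_dim)              # [batch_size, embed_dim]

# Mish -> Linear -> append a trailing axis, so the per-channel bias broadcasts over horizon.
time_mlp = torch.nn.Sequential(torch.nn.Mish(), torch.nn.Linear(embed_dim, out_channels))
time_bias = time_mlp(t)[:, :, None]
print(time_bias.shape)  # torch.Size([4, 32, 1]) -- added to the [4, 32, 128] Conv1dBlock output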
@@ -617,18 +687,6 @@ class Mish(torch.nn.Module):
         return x * torch.tanh(torch.nn.functional.softplus(x))


-class Block(torch.nn.Module):
-    def __init__(self, dim, dim_out, groups=8):
-        super(Block, self).__init__()
-        self.block = torch.nn.Sequential(
-            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
-        )
-
-    def forward(self, x, mask):
-        output = self.block(x * mask)
-        return output * mask
-
-
 class Conv1dBlock(nn.Module):
     """
     Conv1d --> GroupNorm --> Mish
unet_grad_tts.py:

@@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample
-from .resnet import ResnetBlockGradTTS as ResnetBlock
-from .resnet import Upsample
+from .resnet import Downsample, ResnetBlock, Upsample


 class Mish(torch.nn.Module):
@@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             self.downs.append(
                 torch.nn.ModuleList(
                     [
-                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
-                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
+                        ResnetBlock(
+                            in_channels=dim_in,
+                            out_channels=dim_out,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
+                        ResnetBlock(
+                            in_channels=dim_out,
+                            out_channels=dim_out,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
                         Residual(Rezero(LinearAttention(dim_out))),
                         Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                     ]
@@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             )

         mid_dim = dims[-1]
-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
+        self.mid_block1 = ResnetBlock(
+            in_channels=mid_dim,
+            out_channels=mid_dim,
+            temb_channels=dim,
+            groups=8,
+            pre_norm=False,
+            eps=1e-5,
+            non_linearity="mish",
+            overwrite_for_grad_tts=True,
+        )
         self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
+        self.mid_block2 = ResnetBlock(
+            in_channels=mid_dim,
+            out_channels=mid_dim,
+            temb_channels=dim,
+            groups=8,
+            pre_norm=False,
+            eps=1e-5,
+            non_linearity="mish",
+            overwrite_for_grad_tts=True,
+        )

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
                 torch.nn.ModuleList(
                     [
-                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
-                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
+                        ResnetBlock(
+                            in_channels=dim_out * 2,
+                            out_channels=dim_in,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
+                        ResnetBlock(
+                            in_channels=dim_in,
+                            out_channels=dim_in,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
                         Residual(Rezero(LinearAttention(dim_in))),
                         Upsample(dim_in, use_conv_transpose=True),
                     ]
@@ -135,8 +187,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         masks = [mask]
         for resnet1, resnet2, attn, downsample in self.downs:
             mask_down = masks[-1]
-            x = resnet1(x, mask_down, t)
-            x = resnet2(x, mask_down, t)
+            x = resnet1(x, t, mask_down)
+            x = resnet2(x, t, mask_down)
             x = attn(x)
             hiddens.append(x)
             x = downsample(x * mask_down)
@@ -144,15 +196,15 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):

         masks = masks[:-1]
         mask_mid = masks[-1]
-        x = self.mid_block1(x, mask_mid, t)
+        x = self.mid_block1(x, t, mask_mid)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, mask_mid, t)
+        x = self.mid_block2(x, t, mask_mid)

         for resnet1, resnet2, attn, upsample in self.ups:
             mask_up = masks.pop()
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = resnet1(x, mask_up, t)
-            x = resnet2(x, mask_up, t)
+            x = resnet1(x, t, mask_up)
+            x = resnet2(x, t, mask_up)
             x = attn(x)
             x = upsample(x * mask_up)
|