make style
This commit is contained in:
parent
c174bcf4bf
commit
046dc43075
@@ -166,8 +166,8 @@ class Downsample(nn.Module):
 #
 # class GlideUpsample(nn.Module):
 #     """
-#     An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
-#     use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
+#     An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param #
+#     use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
 #     3D, then # upsampling occurs in the inner-two dimensions. #"""
 #
 #     def __init__(self, channels, use_conv, dims=2, out_channels=None):
@@ -192,8 +192,8 @@ class Downsample(nn.Module):
 #
 # class LDMUpsample(nn.Module):
 #     """
-#     An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
-#     use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
+#     An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # #
+#     use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
 #     3D, then # upsampling occurs in the inner-two dimensions. #"""
 #
 #     def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
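For reference, both commented-out docstrings describe the same layer. Below is a minimal sketch of what such an upsampling layer looks like, assuming nearest-neighbour interpolation; the class name UpsampleSketch is hypothetical and this is an illustration, not the repository's actual Upsample implementation. The padding parameter mirrors the extra argument in the LDM variant's __init__.

import torch.nn as nn
import torch.nn.functional as F


class UpsampleSketch(nn.Module):
    """Upsampling layer with an optional convolution (illustrative only).

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            conv_cls = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[dims]
            self.conv = conv_cls(self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D signals are upsampled in the inner-two dimensions only.
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x) if self.use_conv else x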
@@ -342,7 +342,20 @@ class ResBlock(TimestepBlock):
 
 # unet.py and unet_grad_tts.py
 class ResnetBlock(nn.Module):
-    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, pre_norm=True, eps=1e-6, non_linearity="swish", overwrite_for_grad_tts=False):
+    def __init__(
+        self,
+        *,
+        in_channels,
+        out_channels=None,
+        conv_shortcut=False,
+        dropout=0.0,
+        temb_channels=512,
+        groups=32,
+        pre_norm=True,
+        eps=1e-6,
+        non_linearity="swish",
+        overwrite_for_grad_tts=False,
+    ):
         super().__init__()
         self.pre_norm = pre_norm
         self.in_channels = in_channels
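The rewrap above is mechanical: black explodes a definition that exceeds the configured line length into one parameter per line and leaves a magic trailing comma so it stays exploded on later runs. A minimal sketch of reproducing the effect, assuming black is installed; the line length of 119 is an assumption about this repository's configuration.

import black

src = (
    'def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, '
    'temb_channels=512, groups=32, pre_norm=True, eps=1e-6, non_linearity="swish", '
    'overwrite_for_grad_tts=False): ...\n'
)

# black wraps the signature one parameter per line once the whole line no
# longer fits, matching the diff hunk above.
print(black.format_str(src, mode=black.Mode(line_length=119)))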
@@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample
-from .resnet import ResnetBlock
-from .resnet import Upsample
+from .resnet import Downsample, ResnetBlock, Upsample
 
 
 class Mish(torch.nn.Module):
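The import consolidation in this hunk is not something black does; merging several from-imports of the same module into one line is isort's default behaviour, presumably run as part of the repository's make style target. A minimal sketch, assuming isort 5 or later:

import isort

merged = isort.code(
    "from .resnet import Downsample\n"
    "from .resnet import ResnetBlock\n"
    "from .resnet import Upsample\n"
)
print(merged)  # from .resnet import Downsample, ResnetBlock, Upsample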
@@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             self.downs.append(
                 torch.nn.ModuleList(
                     [
-                        ResnetBlock(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
-                        ResnetBlock(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
+                        ResnetBlock(
+                            in_channels=dim_in,
+                            out_channels=dim_out,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
+                        ResnetBlock(
+                            in_channels=dim_out,
+                            out_channels=dim_out,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
                         Residual(Rezero(LinearAttention(dim_out))),
                         Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                     ]
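The Residual and Rezero wrappers around LinearAttention pass through this hunk untouched. For context, in the upstream Grad-TTS code they are small modules along the following lines; a sketch of that pattern, not necessarily this file's exact definitions:

import torch


class Residual(torch.nn.Module):
    # y = f(x) + x: wraps a sub-module with a skip connection.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class Rezero(torch.nn.Module):
    # ReZero gating: scales the sub-module's output by a learnable scalar
    # initialised to zero, so the wrapped branch starts out as an identity
    # when combined with Residual.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g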
@@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             )
 
         mid_dim = dims[-1]
-        self.mid_block1 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
+        self.mid_block1 = ResnetBlock(
+            in_channels=mid_dim,
+            out_channels=mid_dim,
+            temb_channels=dim,
+            groups=8,
+            pre_norm=False,
+            eps=1e-5,
+            non_linearity="mish",
+            overwrite_for_grad_tts=True,
+        )
         self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
-        self.mid_block2 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
+        self.mid_block2 = ResnetBlock(
+            in_channels=mid_dim,
+            out_channels=mid_dim,
+            temb_channels=dim,
+            groups=8,
+            pre_norm=False,
+            eps=1e-5,
+            non_linearity="mish",
+            overwrite_for_grad_tts=True,
+        )
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
                 torch.nn.ModuleList(
                     [
-                        ResnetBlock(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
-                        ResnetBlock(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
+                        ResnetBlock(
+                            in_channels=dim_out * 2,
+                            out_channels=dim_in,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
+                        ResnetBlock(
+                            in_channels=dim_in,
+                            out_channels=dim_in,
+                            temb_channels=dim,
+                            groups=8,
+                            pre_norm=False,
+                            eps=1e-5,
+                            non_linearity="mish",
+                            overwrite_for_grad_tts=True,
+                        ),
                         Residual(Rezero(LinearAttention(dim_in))),
                         Upsample(dim_in, use_conv_transpose=True),
                     ]
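One structural detail the restyle makes easier to see: on the up path each ResnetBlock takes in_channels=dim_out * 2 because the previous feature map is concatenated with a skip connection from the down path. A plain-Python sketch of the channel pairing, with a hypothetical example value for dim:

dim = 64                                 # hypothetical; the real value comes from the model config
dims = [dim, dim, dim * 2, dim * 4]
in_out = list(zip(dims[:-1], dims[1:]))  # [(64, 64), (64, 128), (128, 256)]

# Down path: each stage maps dim_in -> dim_out, with no downsample on the last stage.
for ind, (dim_in, dim_out) in enumerate(in_out):
    is_last = ind >= len(in_out) - 1
    print(f"down stage {ind}: {dim_in} -> {dim_out}" + (" (no downsample)" if is_last else ""))

mid_dim = dims[-1]                       # matches mid_block1/mid_block2 above
print(f"mid: {mid_dim} -> {mid_dim}")

# The skip concatenation doubles the channels entering each up stage.
for dim_in, dim_out in reversed(in_out[1:]):
    print(f"up stage: {dim_out * 2} -> {dim_in}")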