make style

This commit is contained in:
Patrick von Platen 2022-06-29 14:36:35 +00:00
parent c174bcf4bf
commit 046dc43075
2 changed files with 79 additions and 14 deletions

View File

@ -166,8 +166,8 @@ class Downsample(nn.Module):
#
# class GlideUpsample(nn.Module):
# """
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
# use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param #
use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
# 3D, then # upsampling occurs in the inner-two dimensions. #"""
#
# def __init__(self, channels, use_conv, dims=2, out_channels=None):
@ -192,8 +192,8 @@ class Downsample(nn.Module):
#
# class LDMUpsample(nn.Module):
# """
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # #
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
# 3D, then # upsampling occurs in the inner-two dimensions. #"""
#
# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
@ -342,7 +342,20 @@ class ResBlock(TimestepBlock):
# unet.py and unet_grad_tts.py
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, pre_norm=True, eps=1e-6, non_linearity="swish", overwrite_for_grad_tts=False):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
overwrite_for_grad_tts=False,
):
super().__init__()
self.pre_norm = pre_norm
self.in_channels = in_channels

View File

@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample
from .resnet import ResnetBlock
from .resnet import Upsample
from .resnet import Downsample, ResnetBlock, Upsample
class Mish(torch.nn.Module):
@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.downs.append(
torch.nn.ModuleList(
[
ResnetBlock(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
ResnetBlock(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
ResnetBlock(
in_channels=dim_in,
out_channels=dim_out,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
ResnetBlock(
in_channels=dim_out,
out_channels=dim_out,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
]
@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
)
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
self.mid_block1 = ResnetBlock(
in_channels=mid_dim,
out_channels=mid_dim,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
)
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
self.mid_block2 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
self.mid_block2 = ResnetBlock(
in_channels=mid_dim,
out_channels=mid_dim,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append(
torch.nn.ModuleList(
[
ResnetBlock(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
ResnetBlock(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
ResnetBlock(
in_channels=dim_out * 2,
out_channels=dim_in,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
ResnetBlock(
in_channels=dim_in,
out_channels=dim_in,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in, use_conv_transpose=True),
]