Remove bogus file

Patrick von Platen 2022-06-29 14:29:35 +00:00
parent 4e125f72ab
commit 466214d2d6
1 changed file with 7 additions and 19 deletions

@@ -5,8 +5,7 @@ from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample
-from .resnet import ResnetBlock as ResnetBlockNew
-from .resnet import ResnetBlockGradTTS as ResnetBlock
+from .resnet import ResnetBlock
 from .resnet import Upsample
@@ -82,20 +81,13 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.ups = torch.nn.ModuleList([])
         num_resolutions = len(in_out)
-        # num_groups = 8
-        # self.pre_norm = False
-        # eps = 1e-5
-        # non_linearity = "mish"
         for ind, (dim_in, dim_out) in enumerate(in_out):
             is_last = ind >= (num_resolutions - 1)
             self.downs.append(
                 torch.nn.ModuleList(
                     [
-                        # ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
-                        # ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
-                        ResnetBlockNew(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
-                        ResnetBlockNew(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
+                        ResnetBlock(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
+                        ResnetBlock(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
                         Residual(Rezero(LinearAttention(dim_out))),
                         Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                     ]
@@ -103,20 +95,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             )
         mid_dim = dims[-1]
-        # self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
-        # self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
-        self.mid_block1 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
+        self.mid_block1 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
         self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
-        self.mid_block2 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
+        self.mid_block2 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
                 torch.nn.ModuleList(
                     [
-                        # ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
-                        # ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
-                        ResnetBlockNew(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
-                        ResnetBlockNew(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
+                        ResnetBlock(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
+                        ResnetBlock(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
                         Residual(Rezero(LinearAttention(dim_in))),
                         Upsample(dim_in, use_conv_transpose=True),
                     ]
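
All three hunks replace the aliased ResnetBlockNew/ResnetBlockGradTTS classes with one ResnetBlock constructed through the same keyword signature. For orientation only, below is a minimal, self-contained sketch of that construction pattern; the class body, forward pass, and channel sizes are illustrative assumptions and do not reproduce the actual diffusers ResnetBlock.

# Hypothetical sketch of a single configurable ResnetBlock replacing a separate
# Grad-TTS variant. Keyword names mirror the calls in the diff; internals are made up.
import torch


class ResnetBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, temb_channels, groups=8,
                 pre_norm=False, eps=1e-5, non_linearity="mish",
                 overwrite_for_grad_tts=False):
        super().__init__()
        self.pre_norm = pre_norm
        self.overwrite_for_grad_tts = overwrite_for_grad_tts
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm = torch.nn.GroupNorm(groups, out_channels, eps=eps)
        self.act = torch.nn.Mish() if non_linearity == "mish" else torch.nn.SiLU()

    def forward(self, x, temb):
        # Project the timestep embedding onto the channel axis and add it to the features.
        h = self.conv(x) + self.temb_proj(self.act(temb))[:, :, None, None]
        return self.act(self.norm(h))


# Usage mirroring one entry of the down path above (the concrete dims here are invented).
block = ResnetBlock(in_channels=64, out_channels=128, temb_channels=64, groups=8,
                    pre_norm=False, eps=1e-5, non_linearity="mish",
                    overwrite_for_grad_tts=True)
out = block(torch.randn(1, 64, 80, 128), torch.randn(1, 64))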