Remove bogus file
This commit is contained in:
parent
4e125f72ab
commit
466214d2d6
|
@ -5,8 +5,7 @@ from ..modeling_utils import ModelMixin
|
|||
from .attention import LinearAttention
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample
|
||||
from .resnet import ResnetBlock as ResnetBlockNew
|
||||
from .resnet import ResnetBlockGradTTS as ResnetBlock
|
||||
from .resnet import ResnetBlock
|
||||
from .resnet import Upsample
|
||||
|
||||
|
||||
|
@ -82,20 +81,13 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
self.ups = torch.nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
# num_groups = 8
|
||||
# self.pre_norm = False
|
||||
# eps = 1e-5
|
||||
# non_linearity = "mish"
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
self.downs.append(
|
||||
torch.nn.ModuleList(
|
||||
[
|
||||
# ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
|
||||
# ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
|
||||
ResnetBlockNew(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
|
||||
ResnetBlockNew(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
|
||||
ResnetBlock(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
|
||||
ResnetBlock(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
|
||||
Residual(Rezero(LinearAttention(dim_out))),
|
||||
Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
|
||||
]
|
||||
|
@ -103,20 +95,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
)
|
||||
|
||||
mid_dim = dims[-1]
|
||||
# self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
|
||||
# self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
|
||||
self.mid_block1 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
|
||||
self.mid_block1 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
|
||||
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
|
||||
self.mid_block2 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
|
||||
self.mid_block2 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
self.ups.append(
|
||||
torch.nn.ModuleList(
|
||||
[
|
||||
# ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
|
||||
# ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
|
||||
ResnetBlockNew(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
|
||||
ResnetBlockNew(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
|
||||
ResnetBlock(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
|
||||
ResnetBlock(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
|
||||
Residual(Rezero(LinearAttention(dim_in))),
|
||||
Upsample(dim_in, use_conv_transpose=True),
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue