add speaker emb in unet

This commit is contained in:
patil-suraj 2022-06-16 16:48:00 +02:00
parent 3f2d46a14e
commit 71ecc7aed8
1 changed files with 5 additions and 0 deletions

View File

@ -154,6 +154,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.pe_scale = pe_scale
if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
torch.nn.Linear(spk_emb_dim * 4, n_feats))
self.time_pos_emb = SinusoidalPosEmb(dim)
@ -189,6 +190,10 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.final_conv = torch.nn.Conv2d(dim, 1, 1)
def forward(self, x, mask, mu, t, spk=None):
if self.n_spks > 1:
# Get speaker embedding
spk = self.spk_emb(spk)
if not isinstance(spk, type(None)):
s = self.spk_mlp(spk)