add speaker emb in unet
This commit is contained in:
parent
3f2d46a14e
commit
71ecc7aed8
|
@ -154,6 +154,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
self.pe_scale = pe_scale
|
||||
|
||||
if n_spks > 1:
|
||||
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
|
||||
self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
|
||||
torch.nn.Linear(spk_emb_dim * 4, n_feats))
|
||||
self.time_pos_emb = SinusoidalPosEmb(dim)
|
||||
|
@ -189,6 +190,10 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
self.final_conv = torch.nn.Conv2d(dim, 1, 1)
|
||||
|
||||
def forward(self, x, mask, mu, t, spk=None):
|
||||
if self.n_spks > 1:
|
||||
# Get speaker embedding
|
||||
spk = self.spk_emb(spk)
|
||||
|
||||
if not isinstance(spk, type(None)):
|
||||
s = self.spk_mlp(spk)
|
||||
|
||||
|
|
Loading…
Reference in New Issue