diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index de2d6aa2..7bd6414b 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -154,6 +154,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): self.pe_scale = pe_scale if n_spks > 1: + self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)) self.time_pos_emb = SinusoidalPosEmb(dim) @@ -189,6 +190,10 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): self.final_conv = torch.nn.Conv2d(dim, 1, 1) def forward(self, x, mask, mu, t, spk=None): + if self.n_spks > 1: + # Get speaker embedding + spk = self.spk_emb(spk) + if not isinstance(spk, type(None)): s = self.spk_mlp(spk)