diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e374e3ae..2f4d2ab6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -8,6 +8,7 @@ from .modeling_utils import ModelMixin from .models.unet import UNetModel from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel from .models.unet_ldm import UNetLDMModel +from .models.unet_grad_tts import UNetGradTTSModel from .pipeline_utils import DiffusionPipeline from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index dc98e2bb..9104bb90 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -19,3 +19,4 @@ from .unet import UNetModel from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel from .unet_ldm import UNetLDMModel +from .unet_grad_tts import UNetGradTTSModel \ No newline at end of file diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py new file mode 100644 index 00000000..de2d6aa2 --- /dev/null +++ b/src/diffusers/models/unet_grad_tts.py @@ -0,0 +1,233 @@ +import math + +import torch + +try: + from einops import rearrange, repeat +except: + print("Einops is not installed") + pass + +from ..configuration_utils import ConfigMixin +from ..modeling_utils import ModelMixin + +class Mish(torch.nn.Module): + def forward(self, x): + return x * torch.tanh(torch.nn.functional.softplus(x)) + + +class Upsample(torch.nn.Module): + def __init__(self, dim): + super(Upsample, self).__init__() + self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Downsample(torch.nn.Module): + def __init__(self, dim): + super(Downsample, self).__init__() + self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Rezero(torch.nn.Module): + def __init__(self, fn): + super(Rezero, self).__init__() + self.fn = fn + self.g = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return self.fn(x) * self.g + + +class Block(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super(Block, self).__init__() + self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, + padding=1), torch.nn.GroupNorm( + groups, dim_out), Mish()) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super(ResnetBlock, self).__init__() + self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, + dim_out)) + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + if dim != dim_out: + self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) + else: + self.res_conv = torch.nn.Identity() + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class LinearAttention(torch.nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super(LinearAttention, self).__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', + heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', + heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class Residual(torch.nn.Module): + def __init__(self, fn): + super(Residual, self).__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + output = self.fn(x, *args, **kwargs) + x + return output + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super(SinusoidalPosEmb, self).__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class UNetGradTTSModel(ModelMixin, ConfigMixin): + def __init__( + self, + dim, + dim_mults=(1, 2, 4), + groups=8, + n_spks=None, + spk_emb_dim=64, + n_feats=80, + pe_scale=1000 + ): + super(UNetGradTTSModel, self).__init__() + + self.register( + dim=dim, + dim_mults=dim_mults, + groups=groups, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + n_feats=n_feats, + pe_scale=pe_scale + ) + + self.dim = dim + self.dim_mults = dim_mults + self.groups = groups + self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1 + self.spk_emb_dim = spk_emb_dim + self.pe_scale = pe_scale + + if n_spks > 1: + self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), + torch.nn.Linear(spk_emb_dim * 4, n_feats)) + self.time_pos_emb = SinusoidalPosEmb(dim) + self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), + torch.nn.Linear(dim * 4, dim)) + + dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + self.downs = torch.nn.ModuleList([]) + self.ups = torch.nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + self.downs.append(torch.nn.ModuleList([ + ResnetBlock(dim_in, dim_out, time_emb_dim=dim), + ResnetBlock(dim_out, dim_out, time_emb_dim=dim), + Residual(Rezero(LinearAttention(dim_out))), + Downsample(dim_out) if not is_last else torch.nn.Identity()])) + + mid_dim = dims[-1] + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) + self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) + self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + self.ups.append(torch.nn.ModuleList([ + ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), + ResnetBlock(dim_in, dim_in, time_emb_dim=dim), + Residual(Rezero(LinearAttention(dim_in))), + Upsample(dim_in)])) + self.final_block = Block(dim, dim) + self.final_conv = torch.nn.Conv2d(dim, 1, 1) + + def forward(self, x, mask, mu, t, spk=None): + if not isinstance(spk, type(None)): + s = self.spk_mlp(spk) + + t = self.time_pos_emb(t, scale=self.pe_scale) + t = self.mlp(t) + + if self.n_spks < 2: + x = torch.stack([mu, x], 1) + else: + s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1]) + x = torch.stack([mu, x, s], 1) + mask = mask.unsqueeze(1) + + hiddens = [] + masks = [mask] + for resnet1, resnet2, attn, downsample in self.downs: + mask_down = masks[-1] + x = resnet1(x, mask_down, t) + x = resnet2(x, mask_down, t) + x = attn(x) + hiddens.append(x) + x = downsample(x * mask_down) + masks.append(mask_down[:, :, :, ::2]) + + masks = masks[:-1] + mask_mid = masks[-1] + x = self.mid_block1(x, mask_mid, t) + x = self.mid_attn(x) + x = self.mid_block2(x, mask_mid, t) + + for resnet1, resnet2, attn, upsample in self.ups: + mask_up = masks.pop() + x = torch.cat((x, hiddens.pop()), dim=1) + x = resnet1(x, mask_up, t) + x = resnet2(x, mask_up, t) + x = attn(x) + x = upsample(x * mask_up) + + x = self.final_block(x, mask) + output = self.final_conv(x * mask) + + return (output * mask).squeeze(1) \ No newline at end of file diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py new file mode 100644 index 00000000..2d8f6946 --- /dev/null +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -0,0 +1,385 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import math + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin +from diffusers.modeling_utils import ModelMixin + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(int(max_length), dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def fix_len_compatibility(length, num_downsamplings_in_unet=2): + while True: + if length % (2**num_downsamplings_in_unet) == 0: + return length + length += 1 + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + device = duration.device + + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], + [1, 0], [0, 0]]))[:, :-1] + path = path * mask + return path + + +def duration_loss(logw, logw_, lengths): + loss = torch.sum((logw - logw_)**2) / torch.sum(lengths) + return loss + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super(LayerNorm, self).__init__() + self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean)**2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, + n_layers, p_dropout): + super(ConvReluNorm, self).__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, + kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, + kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): + super(DurationPredictor, self).__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, + kernel_size, padding=kernel_size//2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, + kernel_size, padding=kernel_size//2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = torch.nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, window_size=None, + heads_share=True, p_dropout=0.0, proximal_bias=False, + proximal_init=False): + super(MultiHeadAttention, self).__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.window_size = window_size + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, + window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, + window_size * 2 + 1, self.k_channels) * rel_stddev) + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) + scores_local = rel_logits / math.sqrt(self.k_channels) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, + dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, + value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = torch.nn.functional.pad( + relative_embeddings, convert_pad_shape([[0, 0], + [pad_length, pad_length], [0, 0]])) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, + slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + batch, heads, length, _ = x.size() + x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]])) + x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + batch, heads, length, _ = x.size() + x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) + x_flat = x.view([batch, heads, length**2 + length*(length - 1)]) + x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] + return x_final + + def _attention_bias_proximal(self, length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, + p_dropout=0.0): + super(FFN, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, + padding=kernel_size//2) + self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, + padding=kernel_size//2) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, + kernel_size=1, p_dropout=0.0, window_size=None, **kwargs): + super(Encoder, self).__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, + n_heads, window_size=window_size, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, + filter_channels, kernel_size, p_dropout=p_dropout)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class TextEncoder(ModelMixin, ConfigMixin): + def __init__(self, n_vocab, n_feats, n_channels, filter_channels, + filter_channels_dp, n_heads, n_layers, kernel_size, + p_dropout, window_size=None, spk_emb_dim=64, n_spks=1): + super(TextEncoder, self).__init__() + + self.register( + n_vocab=n_vocab, + n_feats=n_feats, + n_channels=n_channels, + filter_channels=filter_channels, + filter_channels_dp=filter_channels_dp, + n_heads=n_heads, + n_layers=n_layers, + kernel_size=kernel_size, + p_dropout=p_dropout, + window_size=window_size, + spk_emb_dim=spk_emb_dim, + n_spks=n_spks + ) + + + self.n_vocab = n_vocab + self.n_feats = n_feats + self.n_channels = n_channels + self.filter_channels = filter_channels + self.filter_channels_dp = filter_channels_dp + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + self.spk_emb_dim = spk_emb_dim + self.n_spks = n_spks + + self.emb = torch.nn.Embedding(n_vocab, n_channels) + torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5) + + self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, + kernel_size=5, n_layers=3, p_dropout=0.5) + + self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers, + kernel_size, p_dropout, window_size=window_size) + + self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1) + self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, + kernel_size, p_dropout) + + def forward(self, x, x_lengths, spk=None): + x = self.emb(x) * math.sqrt(self.n_channels) + x = torch.transpose(x, 1, -1) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.prenet(x, x_mask) + if self.n_spks > 1: + x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) + x = self.encoder(x, x_mask) + mu = self.proj_m(x) * x_mask + + x_dp = torch.detach(x) + logw = self.proj_w(x_dp, x_mask) + + return mu, logw, x_mask