diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a9143053..2c26340f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -11,15 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch import math + import numpy as np - +import torch from torch import nn -import torch.nn.functional as F -def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000): +def get_timestep_embedding( + timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000 +): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -31,18 +32,22 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. """ - assert len(timesteps.shape) == 1 + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - emb = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (embedding_dim // 2 - downscale_freq_shift)) - emb = emb.to(device=timesteps.device) + emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift) + emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + emb = torch.exp(emb * emb_coeff) emb = timesteps[:, None].float() * emb[None, :] - # concat sine and cosine embeddings - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + # scale embeddings + emb = scale * emb - # flip sine and cosine embeddings + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings if flip_sin_to_cos: emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) @@ -52,96 +57,6 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down return emb -#def get_timestep_embedding(timesteps, embedding_dim): -# """ -# This matches the implementation in Denoising Diffusion Probabilistic Models: -# From Fairseq. -# Build sinusoidal embeddings. -# This matches the implementation in tensor2tensor, but differs slightly -# from the description in Section 3.5 of "Attention Is All You Need". -# """ -# assert len(timesteps.shape) == 1 -# -# half_dim = embedding_dim // 2 -# emb = math.log(10000) / (half_dim - 1) -# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) -# emb = emb.to(device=timesteps.device) -# emb = timesteps.float()[:, None] * emb[None, :] -# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) -# if embedding_dim % 2 == 1: # zero pad -# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - - -#def timestep_embedding(timesteps, dim, max_period=10000): -# """ -# Create sinusoidal timestep embeddings. -# -# :param timesteps: a 1-D Tensor of N indices, one per batch element. -# These may be fractional. -# :param dim: the dimension of the output. -# :param max_period: controls the minimum frequency of the embeddings. -# :return: an [N x dim] Tensor of positional embeddings. -# """ -# half = dim // 2 -# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( -# device=timesteps.device -# ) -# args = timesteps[:, None].float() * freqs[None, :] -# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) -# if dim % 2: -# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) -# return embedding - - -#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): -# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 -# half_dim = embedding_dim // 2 - # magic number 10000 is from transformers -# emb = math.log(max_positions) / (half_dim - 1) - # emb = math.log(2.) / (half_dim - 1) -# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) - # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] - # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] -# emb = timesteps.float()[:, None] * emb[None, :] -# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) -# if embedding_dim % 2 == 1: # zero pad -# emb = F.pad(emb, (0, 1), mode="constant") -# assert emb.shape == (timesteps.shape[0], embedding_dim) -# return emb - - -# unet_grad_tts.py -class SinusoidalPosEmb(torch.nn.Module): - def __init__(self, dim): - super(SinusoidalPosEmb, self).__init__() - self.dim = dim - - def forward(self, x, scale=1000): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) - emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - -# unet_rl.py -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - # unet_sde_score_estimation.py class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" @@ -153,3 +68,19 @@ class GaussianFourierProjection(nn.Module): def forward(self, x): x_proj = x[:, None] * self.W[None, :] * 2 * np.pi return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + + +# unet_rl.py - TODO(need test) +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 7d5eebfd..1749def9 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -33,27 +33,6 @@ from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding -#def get_timestep_embedding(timesteps, embedding_dim): -# """ -# This matches the implementation in Denoising Diffusion Probabilistic Models: -# From Fairseq. -# Build sinusoidal embeddings. -# This matches the implementation in tensor2tensor, but differs slightly -# from the description in Section 3.5 of "Attention Is All You Need". -# """ -# assert len(timesteps.shape) == 1 -# -# half_dim = embedding_dim // 2 -# emb = math.log(10000) / (half_dim - 1) -# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) -# emb = emb.to(device=timesteps.device) -# emb = timesteps.float()[:, None] * emb[None, :] -# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) -# if embedding_dim % 2 == 1: # zero pad -# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) -# return emb - - def nonlinearity(x): # swish return x * torch.sigmoid(x) diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 0e045377..c154db92 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -87,27 +87,6 @@ def normalization(channels, swish=0.0): return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) -# def timestep_embedding(timesteps, dim, max_period=10000): -# """ -# Create sinusoidal timestep embeddings. -# -# :param timesteps: a 1-D Tensor of N indices, one per batch element. -# These may be fractional. -# :param dim: the dimension of the output. -# :param max_period: controls the minimum frequency of the embeddings. -# :return: an [N x dim] Tensor of positional embeddings. -# """ -# half = dim // 2 -# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( -# device=timesteps.device -# ) -# args = timesteps[:, None].float() * freqs[None] -# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) -# if dim % 2: -# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) -# return embedding - - def zero_module(module): """ Zero out the parameters of a module and return it. @@ -628,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin): """ hs = [] - emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) h = x.type(self.dtype) for module in self.input_blocks: @@ -715,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel): def forward(self, x, timesteps, transformer_out=None): hs = [] - emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) # project the last token transformer_proj = self.transformer_proj(transformer_out[:, -1]) @@ -807,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel): x = torch.cat([x, upsampled], dim=1) hs = [] - emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) h = x for module in self.input_blocks: diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 81719f08..ccae3133 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -1,5 +1,3 @@ -import math - import torch @@ -11,6 +9,7 @@ except: from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .embeddings import get_timestep_embedding class Mish(torch.nn.Module): @@ -107,21 +106,6 @@ class Residual(torch.nn.Module): return output -class SinusoidalPosEmb(torch.nn.Module): - def __init__(self, dim): - super(SinusoidalPosEmb, self).__init__() - self.dim = dim - - def forward(self, x, scale=1000): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) - emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - class UNetGradTTSModel(ModelMixin, ConfigMixin): def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000): super(UNetGradTTSModel, self).__init__() @@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats) ) - self.time_pos_emb = SinusoidalPosEmb(dim) self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)) dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)] @@ -198,8 +181,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): if not isinstance(spk, type(None)): s = self.spk_mlp(spk) - - t = self.time_pos_emb(timesteps, scale=self.pe_scale) + t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale) + t = self.mlp(t) if self.n_spks < 2: diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index cfc200bf..da84391a 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -317,27 +317,6 @@ def normalization(channels, swish=0.0): return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) -#def timestep_embedding(timesteps, dim, max_period=10000): -# """ -# Create sinusoidal timestep embeddings. -# -# :param timesteps: a 1-D Tensor of N indices, one per batch element. -# These may be fractional. -# :param dim: the dimension of the output. -# :param max_period: controls the minimum frequency of the embeddings. -# :return: an [N x dim] Tensor of positional embeddings. -# """ -# half = dim // 2 -# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( -# device=timesteps.device -# ) -# args = timesteps[:, None].float() * freqs[None] -# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) -# if dim % 2: -# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) -# return embedding - - ## go class AttentionPool2d(nn.Module): """ @@ -1232,7 +1211,9 @@ class EncoderUNetModel(nn.Module): :param timesteps: a 1-D batch of timesteps. :return: an [N x K] Tensor of outputs. """ - emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) results = [] h = x.type(self.dtype) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 7d00eb21..0f0cc4b7 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -382,23 +382,6 @@ def get_act(nonlinearity): raise NotImplementedError("activation function does not exist!") -#def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): -# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 -# half_dim = embedding_dim // 2 - # magic number 10000 is from transformers -# emb = math.log(max_positions) / (half_dim - 1) - # emb = math.log(2.) / (half_dim - 1) -# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) - # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] - # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] -# emb = timesteps.float()[:, None] * emb[None, :] -# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) -# if embedding_dim % 2 == 1: # zero pad -# emb = F.pad(emb, (0, 1), mode="constant") -# assert emb.shape == (timesteps.shape[0], embedding_dim) -# return emb - - def default_init(scale=1.0): """The same initialization used in DDPM.""" scale = 1e-10 if scale == 0 else scale diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index 0b50e7bc..42a42610 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -21,8 +21,7 @@ import unittest import numpy as np import torch -#from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding, a_get_timestep_embedding -from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding +from diffusers.models.embeddings import get_timestep_embedding from diffusers.testing_utils import floats_tensor, slow, torch_device @@ -30,15 +29,87 @@ torch.backends.cuda.matmul.allow_tf32 = False class EmbeddingsTests(unittest.TestCase): - def test_timestep_embeddings(self): + embedding_dim = 256 + timesteps = torch.arange(16) + + t1 = get_timestep_embedding(timesteps, embedding_dim) + + # first vector should always be composed only of 0's and 1's + assert (t1[0, : embedding_dim // 2] - 0).abs().sum() < 1e-5 + assert (t1[0, embedding_dim // 2 :] - 1).abs().sum() < 1e-5 + + # last element of each vector should be one + assert (t1[:, -1] - 1).abs().sum() < 1e-5 + + # For large embeddings (e.g. 128) the frequency of every vector is higher + # than the previous one which means that the gradients of later vectors are + # ALWAYS higher than the previous ones + grad_mean = np.abs(np.gradient(t1, axis=-1)).mean(axis=1) + + prev_grad = 0.0 + for grad in grad_mean: + assert grad > prev_grad + prev_grad = grad + + def test_timestep_defaults(self): embedding_dim = 16 timesteps = torch.arange(10) t1 = get_timestep_embedding(timesteps, embedding_dim) - t2 = timestep_embedding(timesteps, embedding_dim) - t3 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_factor=8) + t2 = get_timestep_embedding( + timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000 + ) - import ipdb; ipdb.set_trace() + assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3) + def test_timestep_flip_sin_cos(self): + embedding_dim = 16 + timesteps = torch.arange(10) + t1 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True) + t1 = torch.cat([t1[:, embedding_dim // 2 :], t1[:, : embedding_dim // 2]], dim=-1) + + t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False) + + assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3) + + def test_timestep_downscale_freq_shift(self): + embedding_dim = 16 + timesteps = torch.arange(10) + + t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0) + t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1) + + # get cosine half (vectors that are wrapped into cosine) + cosine_half = (t1 - t2)[:, embedding_dim // 2 :] + + # cosine needs to be negative + assert (np.abs((cosine_half <= 0).numpy()) - 1).sum() < 1e-5 + + def test_sinoid_embeddings_hardcoded(self): + embedding_dim = 64 + timesteps = torch.arange(128) + + # standard unet, score_vde + t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1, flip_sin_to_cos=False) + # glide, ldm + t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0, flip_sin_to_cos=True) + # grad-tts + t3 = get_timestep_embedding(timesteps, embedding_dim, scale=1000) + + assert torch.allclose( + t1[23:26, 47:50].flatten().cpu(), + torch.tensor([0.9646, 0.9804, 0.9892, 0.9615, 0.9787, 0.9882, 0.9582, 0.9769, 0.9872]), + 1e-3, + ) + assert torch.allclose( + t2[23:26, 47:50].flatten().cpu(), + torch.tensor([0.3019, 0.2280, 0.1716, 0.3146, 0.2377, 0.1790, 0.3272, 0.2474, 0.1864]), + 1e-3, + ) + assert torch.allclose( + t3[23:26, 47:50].flatten().cpu(), + torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]), + 1e-3, + )