refactor all sinus embeddings

This commit is contained in:
Patrick von Platen 2022-06-27 11:37:37 +00:00
parent 02a76c2c81
commit c7a39d38ad
7 changed files with 123 additions and 210 deletions

View File

@ -11,15 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000):
def get_timestep_embedding(
timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
@ -31,16 +32,20 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
emb = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (embedding_dim // 2 - downscale_freq_shift))
emb = emb.to(device=timesteps.device)
emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
emb = torch.exp(emb * emb_coeff)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
@ -52,96 +57,6 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
return emb
#def get_timestep_embedding(timesteps, embedding_dim):
# """
# This matches the implementation in Denoising Diffusion Probabilistic Models:
# From Fairseq.
# Build sinusoidal embeddings.
# This matches the implementation in tensor2tensor, but differs slightly
# from the description in Section 3.5 of "Attention Is All You Need".
# """
# assert len(timesteps.shape) == 1
#
# half_dim = embedding_dim // 2
# emb = math.log(10000) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
# emb = emb.to(device=timesteps.device)
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
#def timestep_embedding(timesteps, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
#
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# device=timesteps.device
# )
# args = timesteps[:, None].float() * freqs[None, :]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding
#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
# half_dim = embedding_dim // 2
# magic number 10000 is from transformers
# emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = F.pad(emb, (0, 1), mode="constant")
# assert emb.shape == (timesteps.shape[0], embedding_dim)
# return emb
# unet_grad_tts.py
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super(SinusoidalPosEmb, self).__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# unet_rl.py
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# unet_sde_score_estimation.py
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
@ -153,3 +68,19 @@ class GaussianFourierProjection(nn.Module):
def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
# unet_rl.py - TODO(need test)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb

View File

@ -33,27 +33,6 @@ from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
#def get_timestep_embedding(timesteps, embedding_dim):
# """
# This matches the implementation in Denoising Diffusion Probabilistic Models:
# From Fairseq.
# Build sinusoidal embeddings.
# This matches the implementation in tensor2tensor, but differs slightly
# from the description in Section 3.5 of "Attention Is All You Need".
# """
# assert len(timesteps.shape) == 1
#
# half_dim = embedding_dim // 2
# emb = math.log(10000) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
# emb = emb.to(device=timesteps.device)
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
# return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)

View File

@ -87,27 +87,6 @@ def normalization(channels, swish=0.0):
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
# def timestep_embedding(timesteps, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
#
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# device=timesteps.device
# )
# args = timesteps[:, None].float() * freqs[None]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding
def zero_module(module):
"""
Zero out the parameters of a module and return it.
@ -628,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
"""
hs = []
emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
h = x.type(self.dtype)
for module in self.input_blocks:
@ -715,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel):
def forward(self, x, timesteps, transformer_out=None):
hs = []
emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
# project the last token
transformer_proj = self.transformer_proj(transformer_out[:, -1])
@ -807,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel):
x = torch.cat([x, upsampled], dim=1)
hs = []
emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
h = x
for module in self.input_blocks:

View File

@ -1,5 +1,3 @@
import math
import torch
@ -11,6 +9,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
class Mish(torch.nn.Module):
@ -107,21 +106,6 @@ class Residual(torch.nn.Module):
return output
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super(SinusoidalPosEmb, self).__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class UNetGradTTSModel(ModelMixin, ConfigMixin):
def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
super(UNetGradTTSModel, self).__init__()
@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
)
self.time_pos_emb = SinusoidalPosEmb(dim)
self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
@ -198,8 +181,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
if not isinstance(spk, type(None)):
s = self.spk_mlp(spk)
t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
t = self.time_pos_emb(timesteps, scale=self.pe_scale)
t = self.mlp(t)
if self.n_spks < 2:

View File

@ -317,27 +317,6 @@ def normalization(channels, swish=0.0):
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
#def timestep_embedding(timesteps, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
#
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# device=timesteps.device
# )
# args = timesteps[:, None].float() * freqs[None]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding
## go
class AttentionPool2d(nn.Module):
"""
@ -1232,7 +1211,9 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
results = []
h = x.type(self.dtype)

View File

@ -382,23 +382,6 @@ def get_act(nonlinearity):
raise NotImplementedError("activation function does not exist!")
#def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
# half_dim = embedding_dim // 2
# magic number 10000 is from transformers
# emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = F.pad(emb, (0, 1), mode="constant")
# assert emb.shape == (timesteps.shape[0], embedding_dim)
# return emb
def default_init(scale=1.0):
"""The same initialization used in DDPM."""
scale = 1e-10 if scale == 0 else scale

View File

@ -21,8 +21,7 @@ import unittest
import numpy as np
import torch
#from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding, a_get_timestep_embedding
from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.testing_utils import floats_tensor, slow, torch_device
@ -30,15 +29,87 @@ torch.backends.cuda.matmul.allow_tf32 = False
class EmbeddingsTests(unittest.TestCase):
def test_timestep_embeddings(self):
embedding_dim = 256
timesteps = torch.arange(16)
t1 = get_timestep_embedding(timesteps, embedding_dim)
# first vector should always be composed only of 0's and 1's
assert (t1[0, : embedding_dim // 2] - 0).abs().sum() < 1e-5
assert (t1[0, embedding_dim // 2 :] - 1).abs().sum() < 1e-5
# last element of each vector should be one
assert (t1[:, -1] - 1).abs().sum() < 1e-5
# For large embeddings (e.g. 128) the frequency of every vector is higher
# than the previous one which means that the gradients of later vectors are
# ALWAYS higher than the previous ones
grad_mean = np.abs(np.gradient(t1, axis=-1)).mean(axis=1)
prev_grad = 0.0
for grad in grad_mean:
assert grad > prev_grad
prev_grad = grad
def test_timestep_defaults(self):
embedding_dim = 16
timesteps = torch.arange(10)
t1 = get_timestep_embedding(timesteps, embedding_dim)
t2 = timestep_embedding(timesteps, embedding_dim)
t3 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_factor=8)
t2 = get_timestep_embedding(
timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000
)
import ipdb; ipdb.set_trace()
assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
def test_timestep_flip_sin_cos(self):
embedding_dim = 16
timesteps = torch.arange(10)
t1 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True)
t1 = torch.cat([t1[:, embedding_dim // 2 :], t1[:, : embedding_dim // 2]], dim=-1)
t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False)
assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
def test_timestep_downscale_freq_shift(self):
embedding_dim = 16
timesteps = torch.arange(10)
t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0)
t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1)
# get cosine half (vectors that are wrapped into cosine)
cosine_half = (t1 - t2)[:, embedding_dim // 2 :]
# cosine needs to be negative
assert (np.abs((cosine_half <= 0).numpy()) - 1).sum() < 1e-5
def test_sinoid_embeddings_hardcoded(self):
embedding_dim = 64
timesteps = torch.arange(128)
# standard unet, score_vde
t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1, flip_sin_to_cos=False)
# glide, ldm
t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0, flip_sin_to_cos=True)
# grad-tts
t3 = get_timestep_embedding(timesteps, embedding_dim, scale=1000)
assert torch.allclose(
t1[23:26, 47:50].flatten().cpu(),
torch.tensor([0.9646, 0.9804, 0.9892, 0.9615, 0.9787, 0.9882, 0.9582, 0.9769, 0.9872]),
1e-3,
)
assert torch.allclose(
t2[23:26, 47:50].flatten().cpu(),
torch.tensor([0.3019, 0.2280, 0.1716, 0.3146, 0.2377, 0.1790, 0.3272, 0.2474, 0.1864]),
1e-3,
)
assert torch.allclose(
t3[23:26, 47:50].flatten().cpu(),
torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
1e-3,
)