refactor all sinus embeddings
parent 02a76c2c81 · commit c7a39d38ad

@@ -11,15 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F


def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000):
def get_timestep_embedding(
    timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:


@@ -31,16 +32,20 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    emb = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (embedding_dim // 2 - downscale_freq_shift))

    emb = emb.to(device=timesteps.device)
    emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
    emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
    emb = torch.exp(emb * emb_coeff)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:

@@ -52,96 +57,6 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
    return emb


#def get_timestep_embedding(timesteps, embedding_dim):
# """
# This matches the implementation in Denoising Diffusion Probabilistic Models:
# From Fairseq.
# Build sinusoidal embeddings.
# This matches the implementation in tensor2tensor, but differs slightly
# from the description in Section 3.5 of "Attention Is All You Need".
# """
# assert len(timesteps.shape) == 1
#
# half_dim = embedding_dim // 2
# emb = math.log(10000) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
# emb = emb.to(device=timesteps.device)
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))


#def timestep_embedding(timesteps, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
#
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# device=timesteps.device
# )
# args = timesteps[:, None].float() * freqs[None, :]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding


#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
# half_dim = embedding_dim // 2
# magic number 10000 is from transformers
# emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = F.pad(emb, (0, 1), mode="constant")
# assert emb.shape == (timesteps.shape[0], embedding_dim)
# return emb


# unet_grad_tts.py
class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# unet_rl.py
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# unet_sde_score_estimation.py
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

@@ -153,3 +68,19 @@ class GaussianFourierProjection(nn.Module):
    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


# unet_rl.py - TODO(need test)
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

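For orientation, a minimal usage sketch of the refactored get_timestep_embedding above (illustrative only, not part of the diff); the argument combinations mirror the ones exercised in the model files and tests below, while the variable names and shapes here are assumptions:

import torch

from diffusers.models.embeddings import get_timestep_embedding

timesteps = torch.arange(16)  # one integer timestep per batch element
dim = 64

# DDPM-style default: [sin | cos] halves, frequency exponent shifted by downscale_freq_shift=1
emb_ddpm = get_timestep_embedding(timesteps, dim)

# GLIDE/LDM-style: cosine half first, no frequency shift
emb_glide = get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0)

# Grad-TTS-style: arguments multiplied by 1000 before taking sin/cos
emb_grad_tts = get_timestep_embedding(timesteps, dim, scale=1000)

assert emb_ddpm.shape == (16, dim)
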
@@ -33,27 +33,6 @@ from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding


#def get_timestep_embedding(timesteps, embedding_dim):
# """
# This matches the implementation in Denoising Diffusion Probabilistic Models:
# From Fairseq.
# Build sinusoidal embeddings.
# This matches the implementation in tensor2tensor, but differs slightly
# from the description in Section 3.5 of "Attention Is All You Need".
# """
# assert len(timesteps.shape) == 1
#
# half_dim = embedding_dim // 2
# emb = math.log(10000) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
# emb = emb.to(device=timesteps.device)
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
# return emb


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)

@@ -87,27 +87,6 @@ def normalization(channels, swish=0.0):
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


# def timestep_embedding(timesteps, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
#
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# device=timesteps.device
# )
# args = timesteps[:, None].float() * freqs[None]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.

@@ -628,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
        """

        hs = []
        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
        emb = self.time_embed(
            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        )

        h = x.type(self.dtype)
        for module in self.input_blocks:

@@ -715,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel):

    def forward(self, x, timesteps, transformer_out=None):
        hs = []
        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
        emb = self.time_embed(
            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        )

        # project the last token
        transformer_proj = self.transformer_proj(transformer_out[:, -1])

@@ -807,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel):
        x = torch.cat([x, upsampled], dim=1)

        hs = []
        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
        emb = self.time_embed(
            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        )

        h = x
        for module in self.input_blocks:

@@ -1,5 +1,3 @@
import math

import torch


@@ -11,6 +9,7 @@ except:

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding


class Mish(torch.nn.Module):

@@ -107,21 +106,6 @@ class Residual(torch.nn.Module):
        return output


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class UNetGradTTSModel(ModelMixin, ConfigMixin):
    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super(UNetGradTTSModel, self).__init__()

@@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
            )

        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

        dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]

@@ -198,8 +181,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        if not isinstance(spk, type(None)):
            s = self.spk_mlp(spk)

        t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)

        t = self.time_pos_emb(timesteps, scale=self.pe_scale)
        t = self.mlp(t)

        if self.n_spks < 2:

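The change above swaps Grad-TTS's module-local SinusoidalPosEmb for the shared helper. As a sanity check that the two agree, here is a sketch that inlines the removed forward as a hypothetical helper (old_grad_tts_pos_emb is illustrative, not part of the diff) and compares it against get_timestep_embedding with default flags:

import math

import torch

from diffusers.models.embeddings import get_timestep_embedding


def old_grad_tts_pos_emb(x, dim, scale=1000):
    # inlined copy of the removed SinusoidalPosEmb.forward, for comparison only
    half_dim = dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
    emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
    return torch.cat((emb.sin(), emb.cos()), dim=-1)


timesteps = torch.arange(10).float()
old = old_grad_tts_pos_emb(timesteps, 64, scale=1000)
new = get_timestep_embedding(timesteps, 64, scale=1000)
# loose tolerance: the two orderings of the float32 multiplications can differ
# slightly once the arguments are scaled by 1000
assert torch.allclose(old, new, atol=1e-2)
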
@@ -317,27 +317,6 @@ def normalization(channels, swish=0.0):
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


#def timestep_embedding(timesteps, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
#
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# device=timesteps.device
# )
# args = timesteps[:, None].float() * freqs[None]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding


## go
class AttentionPool2d(nn.Module):
    """

@@ -1232,7 +1211,9 @@ class EncoderUNetModel(nn.Module):
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
        emb = self.time_embed(
            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        )

        results = []
        h = x.type(self.dtype)

@@ -382,23 +382,6 @@ def get_act(nonlinearity):
        raise NotImplementedError("activation function does not exist!")


#def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
# half_dim = embedding_dim // 2
# magic number 10000 is from transformers
# emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = F.pad(emb, (0, 1), mode="constant")
# assert emb.shape == (timesteps.shape[0], embedding_dim)
# return emb


def default_init(scale=1.0):
    """The same initialization used in DDPM."""
    scale = 1e-10 if scale == 0 else scale

@@ -21,8 +21,7 @@ import unittest
import numpy as np
import torch

#from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding, a_get_timestep_embedding
from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.testing_utils import floats_tensor, slow, torch_device


@@ -30,15 +29,87 @@ torch.backends.cuda.matmul.allow_tf32 = False


class EmbeddingsTests(unittest.TestCase):

    def test_timestep_embeddings(self):
        embedding_dim = 256
        timesteps = torch.arange(16)

        t1 = get_timestep_embedding(timesteps, embedding_dim)

        # first vector should always be composed only of 0's and 1's
        assert (t1[0, : embedding_dim // 2] - 0).abs().sum() < 1e-5
        assert (t1[0, embedding_dim // 2 :] - 1).abs().sum() < 1e-5

        # last element of each vector should be one
        assert (t1[:, -1] - 1).abs().sum() < 1e-5

        # For large embeddings (e.g. 128) the frequency of every vector is higher
        # than the previous one which means that the gradients of later vectors are
        # ALWAYS higher than the previous ones
        grad_mean = np.abs(np.gradient(t1, axis=-1)).mean(axis=1)

        prev_grad = 0.0
        for grad in grad_mean:
            assert grad > prev_grad
            prev_grad = grad

    def test_timestep_defaults(self):
        embedding_dim = 16
        timesteps = torch.arange(10)

        t1 = get_timestep_embedding(timesteps, embedding_dim)
        t2 = timestep_embedding(timesteps, embedding_dim)
        t3 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_factor=8)
        t2 = get_timestep_embedding(
            timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000
        )

        import ipdb; ipdb.set_trace()
        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)

    def test_timestep_flip_sin_cos(self):
        embedding_dim = 16
        timesteps = torch.arange(10)

        t1 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True)
        t1 = torch.cat([t1[:, embedding_dim // 2 :], t1[:, : embedding_dim // 2]], dim=-1)

        t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False)

        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)

    def test_timestep_downscale_freq_shift(self):
        embedding_dim = 16
        timesteps = torch.arange(10)

        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0)
        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1)

        # get cosine half (vectors that are wrapped into cosine)
        cosine_half = (t1 - t2)[:, embedding_dim // 2 :]

        # cosine needs to be negative
        assert (np.abs((cosine_half <= 0).numpy()) - 1).sum() < 1e-5

    def test_sinoid_embeddings_hardcoded(self):
        embedding_dim = 64
        timesteps = torch.arange(128)

        # standard unet, score_vde
        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1, flip_sin_to_cos=False)
        # glide, ldm
        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0, flip_sin_to_cos=True)
        # grad-tts
        t3 = get_timestep_embedding(timesteps, embedding_dim, scale=1000)

        assert torch.allclose(
            t1[23:26, 47:50].flatten().cpu(),
            torch.tensor([0.9646, 0.9804, 0.9892, 0.9615, 0.9787, 0.9882, 0.9582, 0.9769, 0.9872]),
            1e-3,
        )
        assert torch.allclose(
            t2[23:26, 47:50].flatten().cpu(),
            torch.tensor([0.3019, 0.2280, 0.1716, 0.3146, 0.2377, 0.1790, 0.3272, 0.2474, 0.1864]),
            1e-3,
        )
        assert torch.allclose(
            t3[23:26, 47:50].flatten().cpu(),
            torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
            1e-3,
        )

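As a cross-check of the [sin | cos] layout these tests pin down, a small NumPy re-derivation of the default configuration (a sketch assuming the default arguments; reference_embedding is illustrative and not part of the diff):

import numpy as np
import torch

from diffusers.models.embeddings import get_timestep_embedding


def reference_embedding(timesteps, dim, max_period=10000, shift=1):
    # NumPy version of the default [sin | cos] embedding, for comparison only
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / (half - shift))
    args = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)


timesteps = torch.arange(128)
ours = get_timestep_embedding(timesteps, 64).numpy()
ref = reference_embedding(timesteps.numpy(), 64)
# float32 (torch) vs float64 (NumPy) leaves only tiny rounding differences
assert np.allclose(ours, ref, atol=1e-3)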