fix some tests on gpu
This commit is contained in:
parent
45a09bebf3
commit
1a0331a78a
|
@ -0,0 +1,56 @@
|
||||||
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
# unet.py
|
||||||
|
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic Models (from Fairseq). It also matches the
    implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All
    You Need".

    :param timesteps: a 1-D Tensor of N timesteps, one per batch element.
    :param embedding_dim: the dimension of the output.
    :return: an [N x embedding_dim] float Tensor: the sine half, then the cosine half, then a single zero column
             when embedding_dim is odd.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # The log-spacing step between consecutive frequencies is log(10000) / (half_dim - 1). With a single frequency
    # (embedding_dim of 2 or 3) the step is never used because the only exponent is 0, so clamp the denominator to
    # 1 instead of dividing by zero.
    emb = math.log(10000) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    # Outer product: one row per timestep, one column per frequency.
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
||||||
|
|
||||||
|
# unet_glide.py
|
||||||
|
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings: the cosine half, then the sine half, then a single zero
             column when dim is odd.
    """
    half = dim // 2
    # Frequencies are exp(-log(max_period) * i / half) for i in [0, half).
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        device=timesteps.device
    )
    # Outer product: one row per timestep, one column per frequency.
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with an explicit [N x 1] zero column. Slicing the pad out of `embedding` itself yields an empty
        # column when dim == 1 (half == 0), which would make the result [N x 0] instead of [N x dim].
        zero_column = embedding.new_zeros((embedding.shape[0], 1))
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding
|
|
@ -198,6 +198,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
||||||
if not isinstance(spk, type(None)):
|
if not isinstance(spk, type(None)):
|
||||||
s = self.spk_mlp(spk)
|
s = self.spk_mlp(spk)
|
||||||
|
|
||||||
|
|
||||||
t = self.time_pos_emb(timesteps, scale=self.pe_scale)
|
t = self.time_pos_emb(timesteps, scale=self.pe_scale)
|
||||||
t = self.mlp(t)
|
t = self.mlp(t)
|
||||||
|
|
||||||
|
|
|
@ -113,7 +113,7 @@ class ModelTesterMixin:
|
||||||
new_image = new_model(**inputs_dict)
|
new_image = new_model(**inputs_dict)
|
||||||
|
|
||||||
max_diff = (image - new_image).abs().sum().item()
|
max_diff = (image - new_image).abs().sum().item()
|
||||||
self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes")
|
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
|
||||||
|
|
||||||
def test_determinism(self):
|
def test_determinism(self):
|
||||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
@ -431,11 +431,12 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device)
|
emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device)
|
||||||
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
|
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(noise, time_step, emb)
|
output = model(noise, time_step, emb)
|
||||||
|
|
||||||
output, _ = torch.split(output, 3, dim=1)
|
output, _ = torch.split(output, 3, dim=1)
|
||||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
output_slice = output[0, -1, -3:, -3:].cpu().flatten()
|
||||||
# fmt: off
|
# fmt: off
|
||||||
expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845])
|
expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845])
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
Loading…
Reference in New Issue