fix some tests on gpu

This commit is contained in:
Patrick von Platen 2022-06-27 09:07:57 +00:00
parent 45a09bebf3
commit 1a0331a78a
3 changed files with 60 additions and 2 deletions

View File

@ -0,0 +1,56 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# unet.py
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
# unet_glide.py
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding

View File

@ -198,6 +198,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
if not isinstance(spk, type(None)): if not isinstance(spk, type(None)):
s = self.spk_mlp(spk) s = self.spk_mlp(spk)
t = self.time_pos_emb(timesteps, scale=self.pe_scale) t = self.time_pos_emb(timesteps, scale=self.pe_scale)
t = self.mlp(t) t = self.mlp(t)

View File

@ -113,7 +113,7 @@ class ModelTesterMixin:
new_image = new_model(**inputs_dict) new_image = new_model(**inputs_dict)
max_diff = (image - new_image).abs().sum().item() max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes") self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
def test_determinism(self): def test_determinism(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@ -431,11 +431,12 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device) emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
output = model(noise, time_step, emb) output = model(noise, time_step, emb)
output, _ = torch.split(output, 3, dim=1) output, _ = torch.split(output, 3, dim=1)
output_slice = output[0, -1, -3:, -3:].flatten() output_slice = output[0, -1, -3:, -3:].cpu().flatten()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845]) expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845])
# fmt: on # fmt: on