This commit is contained in:
Patrick von Platen 2022-10-28 16:55:21 +00:00
parent c4ef1efe46
commit cbbb29398a
1 changed files with 5 additions and 5 deletions

View File

@ -20,7 +20,7 @@ import torch
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
from diffusers.modeling_utils import ModelMixin from diffusers.modeling_utils import ModelMixin
from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_all_close, torch_device from diffusers.utils import floats_tensor, load_numpy, require_torch_gpu, slow, torch_all_close, torch_device
from parameterized import parameterized from parameterized import parameterized
from ..test_modeling_common import ModelTesterMixin from ..test_modeling_common import ModelTesterMixin
@ -136,6 +136,9 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
@slow @slow
class AutoencoderKLIntegrationTests(unittest.TestCase): class AutoencoderKLIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
def tearDown(self): def tearDown(self):
# clean up the VRAM after each test # clean up the VRAM after each test
super().tearDown() super().tearDown()
@ -143,11 +146,8 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False): def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
batch_size, channels, height, width = shape
generator = torch.Generator(device=torch_device).manual_seed(seed)
dtype = torch.float16 if fp16 else torch.float32 dtype = torch.float16 if fp16 else torch.float32
image = torch.randn(batch_size, channels, height, width, device=torch_device, generator=generator, dtype=dtype) image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return image return image
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False): def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):