hot fix
This commit is contained in:
parent
c4ef1efe46
commit
cbbb29398a
|
@ -20,7 +20,7 @@ import torch
|
|||
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.modeling_utils import ModelMixin
|
||||
from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_all_close, torch_device
|
||||
from diffusers.utils import floats_tensor, load_numpy, require_torch_gpu, slow, torch_all_close, torch_device
|
||||
from parameterized import parameterized
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
@ -136,6 +136,9 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
|
|||
|
||||
@slow
|
||||
class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
@ -143,11 +146,8 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||
batch_size, channels, height, width = shape
|
||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.randn(batch_size, channels, height, width, device=torch_device, generator=generator, dtype=dtype)
|
||||
|
||||
image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return image
|
||||
|
||||
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
|
||||
|
|
Loading…
Reference in New Issue