hot fix
This commit is contained in:
parent
c4ef1efe46
commit
cbbb29398a
|
@ -20,7 +20,7 @@ import torch
|
||||||
|
|
||||||
from diffusers import AutoencoderKL
|
from diffusers import AutoencoderKL
|
||||||
from diffusers.modeling_utils import ModelMixin
|
from diffusers.modeling_utils import ModelMixin
|
||||||
from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_all_close, torch_device
|
from diffusers.utils import floats_tensor, load_numpy, require_torch_gpu, slow, torch_all_close, torch_device
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from ..test_modeling_common import ModelTesterMixin
|
from ..test_modeling_common import ModelTesterMixin
|
||||||
|
@ -136,6 +136,9 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
class AutoencoderKLIntegrationTests(unittest.TestCase):
|
class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||||
|
def get_file_format(self, seed, shape):
|
||||||
|
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# clean up the VRAM after each test
|
# clean up the VRAM after each test
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
|
@ -143,11 +146,8 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||||
batch_size, channels, height, width = shape
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
|
||||||
dtype = torch.float16 if fp16 else torch.float32
|
dtype = torch.float16 if fp16 else torch.float32
|
||||||
image = torch.randn(batch_size, channels, height, width, device=torch_device, generator=generator, dtype=dtype)
|
image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
|
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
|
||||||
|
|
Loading…
Reference in New Issue