[Tests] Fix slow tests (#1087)

This commit is contained in:
Patrick von Platen 2022-10-31 18:59:58 +01:00 committed by GitHub
parent 010bc4ea19
commit 17c2c0600b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 13 deletions

View File

@ -323,7 +323,7 @@ class ConfigMixin:
# remove attributes from orig class that cannot be expected
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
if orig_cls_name != cls.__name__:
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
orig_cls = getattr(diffusers_library, orig_cls_name)
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}

View File

@ -90,17 +90,18 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
" sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
" Hub, it would be very nice if you could open a Pull request for the"
" `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
new_config["skip_prk_steps"] = True
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:

View File

@ -640,13 +640,14 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_fast_ddim(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
sd_pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-1", scheduler=scheduler, device_map="auto"
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
sd_pipe.scheduler = scheduler
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)