From 8a6833b85c1e781d807d0991eb6d6f6f645e309e Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Tue, 20 Sep 2022 19:10:44 +0200 Subject: [PATCH] Add the K-LMS scheduler to the inpainting pipeline + tests (#587) * Add the K-LMS scheduler to the inpainting pipeline + tests * Remove redundant casts --- .../pipeline_stable_diffusion_img2img.py | 6 +-- .../pipeline_stable_diffusion_inpaint.py | 29 +++++++++--- tests/test_pipelines.py | 47 +++++++++++++++++++ 3 files changed, 71 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index bbb553f1..e2affac6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -213,7 +213,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) # get prompt text embeddings text_input = self.tokenizer( @@ -265,8 +265,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): sigma = self.scheduler.sigmas[t_index] # the model input needs to be scaled to match the continuous ODE formulation in K-LMS latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - latent_model_input = latent_model_input.to(self.unet.dtype) - t = t.to(self.unet.dtype) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -284,7 +282,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): # scale and decode the image latents with vae latents = 1 / 0.18215 * latents - image = self.vae.decode(latents.to(self.vae.dtype)).sample + image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b152c77c..8d18d2f3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -12,7 +12,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, PNDMScheduler +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -78,7 +78,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler], + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, ): @@ -241,8 +241,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + if isinstance(self.scheduler, LMSDiscreteScheduler): + timesteps = torch.tensor( + [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device + ) + else: + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) @@ -287,8 +292,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): + t_index = t_start + i # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[t_index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -299,10 +309,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index)) + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) latents = (init_latents_proper * mask) + (latents * (1 - mask)) # scale and decode the image latents with vae diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 102a55a9..7689126b 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1325,6 +1325,53 @@ class PipelineTesterMixin(unittest.TestCase): assert image.shape == (512, 512, 3) assert np.abs(expected_image - image).max() < 1e-2 + @slow + @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + def test_stable_diffusion_inpaint_pipeline_k_lms(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/red_cat_sitting_on_a_park_bench_k_lms.png" + ) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 + + lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + scheduler=lms, + safety_checker=self.dummy_safety_checker, + use_auth_token=True, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A red cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 + @slow def test_stable_diffusion_onnx(self): from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models