diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index 2c86e913..14d9ee63 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -5,7 +5,14 @@ import torch from torch import nn from torch.nn import functional as F -from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from torchvision import transforms from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer @@ -56,7 +63,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): clip_model: CLIPModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], feature_extractor: CLIPFeatureExtractor, ): super().__init__() @@ -123,7 +130,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): # predict the noise residual noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - if isinstance(self.scheduler, PNDMScheduler): + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): alpha_prod_t = self.scheduler.alphas_cumprod[timestep] beta_prod_t = 1 - alpha_prod_t # compute predicted original sample from predicted noise also called @@ -176,6 +183,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, clip_guidance_scale: Optional[float] = 100, clip_prompt: Optional[Union[str, List[str]]] = None, num_cutouts: Optional[int] = 4, @@ -275,6 +283,20 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -306,7 +328,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents