[CLIPGuidedStableDiffusion] support DDIM scheduler (#1190)
add ddim in clip guided
This commit is contained in:
parent
663f0c1963
commit
cd77a03651
|
@ -5,7 +5,14 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||
|
@ -56,7 +63,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
|||
clip_model: CLIPModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
|
||||
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -123,7 +130,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
|||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
if isinstance(self.scheduler, PNDMScheduler):
|
||||
if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
|
||||
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
# compute predicted original sample from predicted noise also called
|
||||
|
@ -176,6 +183,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
|||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
clip_guidance_scale: Optional[float] = 100,
|
||||
clip_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_cutouts: Optional[int] = 4,
|
||||
|
@ -275,6 +283,20 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
|||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
@ -306,7 +328,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
|||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
|
|
Loading…
Reference in New Issue