Add the K-LMS scheduler to the inpainting pipeline + tests (#587)
* Add the K-LMS scheduler to the inpainting pipeline + tests * Remove redundant casts
This commit is contained in:
parent
a45dca077c
commit
8a6833b85c
|
@ -213,7 +213,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_input = self.tokenizer(
|
||||
|
@ -265,8 +265,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
sigma = self.scheduler.sigmas[t_index]
|
||||
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
latent_model_input = latent_model_input.to(self.unet.dtype)
|
||||
t = t.to(self.unet.dtype)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
@ -284,7 +282,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
|
|
@ -12,7 +12,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler, PNDMScheduler
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import logging
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
@ -78,7 +78,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler],
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
|
@ -241,8 +241,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
timesteps = torch.tensor(
|
||||
[num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
|
||||
)
|
||||
else:
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
|
||||
|
@ -287,8 +292,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
latents = init_latents
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
|
||||
t_index = t_start + i
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
sigma = self.scheduler.sigmas[t_index]
|
||||
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
@ -299,10 +309,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
|
||||
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
|
|
|
@ -1325,6 +1325,53 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
assert image.shape == (512, 512, 3)
|
||||
assert np.abs(expected_image - image).max() < 1e-2
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
def test_stable_diffusion_inpaint_pipeline_k_lms(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
)
|
||||
expected_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/red_cat_sitting_on_a_park_bench_k_lms.png"
|
||||
)
|
||||
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
|
||||
|
||||
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
model_id,
|
||||
scheduler=lms,
|
||||
safety_checker=self.dummy_safety_checker,
|
||||
use_auth_token=True,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "A red cat sitting on a park bench"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
init_image=init_image,
|
||||
mask_image=mask_image,
|
||||
strength=0.75,
|
||||
guidance_scale=7.5,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
assert np.abs(expected_image - image).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_stable_diffusion_onnx(self):
|
||||
from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models
|
||||
|
|
Loading…
Reference in New Issue