[Img2Img2] Re-add K LMS scheduler (#340)
This commit is contained in:
parent
e49dd03d2d
commit
9b704f7688
|
@ -5,12 +5,11 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
import PIL
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler, PNDMScheduler
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
|
@ -31,7 +30,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler],
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
|
@ -93,12 +92,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
# get the original timestep using init_timestep
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
timesteps = torch.tensor(
|
||||
[num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
|
||||
)
|
||||
else:
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_input = self.tokenizer(
|
||||
|
@ -137,11 +141,22 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
latents = init_latents
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
|
||||
t_index = t_start + i
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
# if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
sigma = self.scheduler.sigmas[t_index]
|
||||
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
latent_model_input = latent_model_input.to(self.unet.dtype)
|
||||
t = t.to(self.unet.dtype)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
|
||||
|
||||
|
@ -151,11 +166,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents)
|
||||
image = self.vae.decode(latents.to(self.vae.dtype))
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
|
|
@ -445,6 +445,49 @@ class PipelineFastTests(unittest.TestCase):
|
|||
expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_img2img_k_lms(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
init_image = self.dummy_image.to(device)
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionImg2ImgPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=self.dummy_safety_checker,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
init_image=init_image,
|
||||
)
|
||||
|
||||
image = output["sample"]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_inpaint(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
|
@ -892,7 +935,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
def test_stable_diffusion_img2img_pipeline(self):
|
||||
ds = load_dataset("hf-internal-testing/diffusers-images", split="train")
|
||||
|
||||
init_image = ds[1]["image"].resize((768, 512))
|
||||
init_image = ds[2]["image"].resize((768, 512))
|
||||
output_image = ds[0]["image"].resize((768, 512))
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
|
@ -915,12 +958,40 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
def test_stable_diffusion_in_paint_pipeline(self):
|
||||
def test_stable_diffusion_img2img_pipeline_k_lms(self):
|
||||
ds = load_dataset("hf-internal-testing/diffusers-images", split="train")
|
||||
|
||||
init_image = ds[2]["image"].resize((768, 512))
|
||||
mask_image = ds[3]["image"].resize((768, 512))
|
||||
output_image = ds[4]["image"].resize((768, 512))
|
||||
output_image = ds[1]["image"].resize((768, 512))
|
||||
|
||||
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, scheduler=lms, use_auth_token=True)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)[
|
||||
"sample"
|
||||
][0]
|
||||
|
||||
expected_array = np.array(output_image)
|
||||
sampled_array = np.array(image)
|
||||
|
||||
assert sampled_array.shape == (512, 768, 3)
|
||||
assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
def test_stable_diffusion_in_paint_pipeline(self):
|
||||
ds = load_dataset("hf-internal-testing/diffusers-images", split="train")
|
||||
|
||||
init_image = ds[3]["image"].resize((768, 512))
|
||||
mask_image = ds[4]["image"].resize((768, 512))
|
||||
output_image = ds[5]["image"].resize((768, 512))
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, use_auth_token=True)
|
||||
|
|
Loading…
Reference in New Issue