From ee71d9d03ddb1489b87662cd6f69a6d3ef70b6a9 Mon Sep 17 00:00:00 2001 From: clarencechen Date: Tue, 14 Mar 2023 13:25:12 -0700 Subject: [PATCH] Add support for different model prediction types in DDIMInverseScheduler (#2619) * Add support for different model prediction types in DDIMInverseScheduler Resolve alpha_prod_t_prev index issue for final step of inversion * Fix old bug introduced when prediction type is "sample" * Add support for sample clipping for numerical stability and deprecate old kwarg * Detach sample, alphas, betas Derive predicted noise from model output before dist. regularization Style cleanup * Log loss for debugging * Revert "Log loss for debugging" This reverts commit 76ea9c856f99f4c8eca45a0b1801593bb982584b. * Add comments * Add inversion equivalence test * Add expected data for Pix2PixZero pipeline tests with SD 2 * Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py * Remove cruft and add more explanatory comments --------- Co-authored-by: Patrick von Platen --- .../pipeline_stable_diffusion_pix2pix_zero.py | 34 +++++++- .../schedulers/scheduling_ddim_inverse.py | 87 ++++++++++++++----- .../test_stable_diffusion_pix2pix_zero.py | 65 +++++++++++++- 3 files changed, 157 insertions(+), 29 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index b5a352c7..0e58701d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -153,6 +153,8 @@ EXAMPLE_INVERT_DOC_STRING = """ >>> source_embeds = pipeline.get_embeds(source_prompts) >>> target_embeds = pipeline.get_embeds(target_prompts) >>> # the latents can then be used to edit a real image + >>> # when using Stable Diffusion 2 or other models that use v-prediction + >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion >>> image = pipeline( ... caption, @@ -730,6 +732,23 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): return latents + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + def auto_corr_loss(self, hidden_states, generator=None): batch_size, channel, height, width = hidden_states.shape if batch_size > 1: @@ -1156,8 +1175,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): # 7. Denoising loop where we obtain the cross-attention maps. num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order - with self.progress_bar(total=num_inference_steps - 2) as progress_bar: - for i, t in enumerate(timesteps[1:-1]): + with self.progress_bar(total=num_inference_steps - 1) as progress_bar: + for i, t in enumerate(timesteps[:-1]): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) @@ -1181,7 +1200,11 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): if lambda_auto_corr > 0: for _ in range(num_auto_corr_rolls): var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) - l_ac = self.auto_corr_loss(var, generator=generator) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = self.auto_corr_loss(var_epsilon, generator=generator) l_ac.backward() grad = var.grad.detach() / num_auto_corr_rolls @@ -1190,7 +1213,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): if lambda_kl > 0: var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) - l_kld = self.kl_divergence(var) + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = self.kl_divergence(var_epsilon) l_kld.backward() grad = var.grad.detach() diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index 7006bd13..2c9fc036 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -23,7 +23,7 @@ import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin -from diffusers.utils import BaseOutput +from diffusers.utils import BaseOutput, deprecate @dataclass @@ -96,15 +96,17 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. - set_alpha_to_one (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_zero (`bool`, default `True`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final - step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the value of alpha at step 0. + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `0`, + otherwise it uses the value of alpha at step `num_train_timesteps - 1`. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. + `set_alpha_to_zero=False`, to make the last step use step `num_train_timesteps - 1` for the previous alpha + product. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 @@ -122,10 +124,18 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, - set_alpha_to_one: bool = True, + set_alpha_to_zero: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + **kwargs, ): + if kwargs.get("set_alpha_to_one", None) is not None: + deprecation_message = ( + "The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead." + ) + deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False) + set_alpha_to_zero = kwargs["set_alpha_to_one"] if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -144,11 +154,12 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # At every step in inverted ddim, we are looking into the next alphas_cumprod + # For the final step, there is no next alphas_cumprod, and the index is out of bounds + # `set_alpha_to_zero` decides whether we set this parameter simply to zero + # in this case, self.step() just output the predicted noise + # or whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -157,6 +168,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64)) + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -205,23 +217,52 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): variance_noise: Optional[torch.FloatTensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: - e_t = model_output - - x = sample + # 1. get previous step value (=t+1) prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps - a_t = self.alphas_cumprod[timestep - 1] - a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod + # 2. compute alphas, betas + # change original implementation to exactly match noise levels for analogous forward process + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] + if prev_timestep < self.config.num_train_timesteps + else self.final_alpha_cumprod + ) - pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt() + beta_prod_t = 1 - alpha_prod_t - dir_xt = (1.0 - a_prev).sqrt() * e_t + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) - prev_sample = a_prev.sqrt() * pred_x0 + dir_xt + # 4. Clip or threshold "predicted x_0" + if self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + + # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if not return_dict: - return (prev_sample, pred_x0) - return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0) + return (prev_sample, pred_original_sample) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def __len__(self): return self.config.num_train_timesteps diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 9e80ef74..3830426a 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -347,7 +347,6 @@ class InversionPipelineSlowTests(unittest.TestCase): pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 ) - pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config) pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) caption = "a photography of a cat with flowers" @@ -366,6 +365,28 @@ class InversionPipelineSlowTests(unittest.TestCase): assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 5e-2 + def test_stable_diffusion_2_pix2pix_inversion(self): + pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) + + caption = "a photography of a cat with flowers" + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10) + inv_latents = output[0] + + image_slice = inv_latents[0, -3:, -3:, -1].flatten() + + assert inv_latents.shape == (1, 4, 64, 64) + expected_slice = np.array([0.7515, -0.2397, 0.4922, -0.9736, -0.7031, 0.4846, -1.0781, 1.1309, -0.6973]) + + assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 5e-2 + def test_stable_diffusion_pix2pix_full(self): # numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png expected_image = load_numpy( @@ -375,7 +396,6 @@ class InversionPipelineSlowTests(unittest.TestCase): pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 ) - pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config) pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) caption = "a photography of a cat with flowers" @@ -407,3 +427,44 @@ class InversionPipelineSlowTests(unittest.TestCase): max_diff = np.abs(expected_image - image).mean() assert max_diff < 0.05 + + def test_stable_diffusion_2_pix2pix_full(self): + # numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog_2.png + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog_2.npy" + ) + + pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) + + caption = "a photography of a cat with flowers" + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + output = pipe.invert(caption, image=self.raw_image, generator=generator) + inv_latents = output[0] + + source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] + target_prompts = 4 * ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] + + source_embeds = pipe.get_embeds(source_prompts) + target_embeds = pipe.get_embeds(target_prompts) + + image = pipe( + caption, + source_embeds=source_embeds, + target_embeds=target_embeds, + num_inference_steps=125, + cross_attention_guidance_amount=0.015, + generator=generator, + latents=inv_latents, + negative_prompt=caption, + output_type="np", + ).images + + max_diff = np.abs(expected_image - image).mean() + assert max_diff < 0.05