From 9eb2f786316c0f7e94c3df5f5e8bda203e6b875d Mon Sep 17 00:00:00 2001 From: huchenlei Date: Wed, 15 May 2024 16:32:29 -0400 Subject: [PATCH 1/3] Precompute is_sdxl_inpaint flag --- modules/processing.py | 28 +++++++++++----------------- modules/sd_models.py | 7 +++++++ modules/sd_models_xl.py | 9 ++++----- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 76557dd7f..d82cb24fb 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -115,20 +115,17 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: - sd = sd_model.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - # The "masked-image" in this case will just be all 0.5 since the entire image is masked. - image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 - image_conditioning = images_tensor_to_samples(image_conditioning, - approximation_indexes.get(opts.sd_vae_encode_method)) + if sd_model.model.is_sdxl_inpaint: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) - return image_conditioning + return image_conditioning # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. @@ -390,11 +387,8 @@ class StableDiffusionProcessing: if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) - sd = self.sampler.model_wrap.inner_model.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + if self.sampler.model_wrap.inner_model.model.is_sdxl_inpaint: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models.py b/modules/sd_models.py index ff245b7a6..62e74d27a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -380,6 +380,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd1 = not model.is_sdxl and not model.is_sd2 model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() + # Set is_sdxl_inpaint flag. + diffusion_model_input = state_dict.get('diffusion_model.input_blocks.0.0.weight', None) + model.is_sdxl_inpaint = ( + model.is_sdxl and + diffusion_model_input is not None and + diffusion_model_input.shape[1] == 9 + ) if model.is_sdxl: sd_models_xl.extend_sdxl(model) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 94ff973fb..35e21f6e4 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -35,11 +35,10 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): - sd = self.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - x = torch.cat([x] + cond['c_concat'], dim=1) + """WARNING: This function is called once per denoising iteration. DO NOT add + expensive functionc calls such as `model.state_dict`. """ + if self.model.is_sdxl_inpaint: + x = torch.cat([x] + cond['c_concat'], dim=1) return self.model(x, t, cond) From 6a48476502d6cdd19cb3d0c7f2a0b92aacd7c01f Mon Sep 17 00:00:00 2001 From: huchenlei Date: Wed, 15 May 2024 16:54:26 -0400 Subject: [PATCH 2/3] Fix flag check for SD15 --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index d82cb24fb..fff2595e7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -115,7 +115,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: - if sd_model.model.is_sdxl_inpaint: + if getattr(sd_model.model, "is_sdxl_inpaint", False): # The "masked-image" in this case will just be all 0.5 since the entire image is masked. image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 image_conditioning = images_tensor_to_samples(image_conditioning, @@ -387,7 +387,7 @@ class StableDiffusionProcessing: if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) - if self.sampler.model_wrap.inner_model.model.is_sdxl_inpaint: + if getattr(self.sampler.model_wrap.inner_model.model, "is_sdxl_inpaint", False): return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. From 3e20b36e8f1b26f24db0c149732fb5479bff68bc Mon Sep 17 00:00:00 2001 From: huchenlei Date: Wed, 15 May 2024 17:27:01 -0400 Subject: [PATCH 3/3] Fix attr access --- modules/sd_models_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 35e21f6e4..1242a5936 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -37,7 +37,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): """WARNING: This function is called once per denoising iteration. DO NOT add expensive functionc calls such as `model.state_dict`. """ - if self.model.is_sdxl_inpaint: + if self.is_sdxl_inpaint: x = torch.cat([x] + cond['c_concat'], dim=1) return self.model(x, t, cond)