diff --git a/modules/processing.py b/modules/processing.py index 0ff6a45c0..dc5382721 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -115,7 +115,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: - if getattr(sd_model.model, "is_sdxl_inpaint", False): + if sd_model.is_sdxl_inpaint: # The "masked-image" in this case will just be all 0.5 since the entire image is masked. image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 image_conditioning = images_tensor_to_samples(image_conditioning, @@ -389,7 +389,7 @@ class StableDiffusionProcessing: if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) - if getattr(self.sampler.model_wrap.inner_model.model, "is_sdxl_inpaint", False): + if self.sampler.model_wrap.inner_model.is_sdxl_inpaint: return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. diff --git a/modules/sd_models.py b/modules/sd_models.py index 61bd15d8f..93ff6c5fe 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -386,13 +386,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd1 = not model.is_sdxl and not model.is_sd2 model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() - # Set is_sdxl_inpaint flag. - diffusion_model_input = state_dict.get('diffusion_model.input_blocks.0.0.weight', None) - model.is_sdxl_inpaint = ( - model.is_sdxl and - diffusion_model_input is not None and - diffusion_model_input.shape[1] == 9 - ) if model.is_sdxl: sd_models_xl.extend_sdxl(model) @@ -408,6 +401,18 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer del state_dict + # Set is_sdxl_inpaint flag. + # Perform this check after model initialization to make sure state_dict + # structure is already known. + diffusion_model_input = model.model.state_dict().get( + 'diffusion_model.input_blocks.0.0.weight' + ) + model.is_sdxl_inpaint = ( + model.is_sdxl and + diffusion_model_input is not None and + diffusion_model_input.shape[1] == 9 + ) + if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) timer.record("apply channels_last")