diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index b4f03b138..6d6575119 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -138,6 +138,7 @@ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_s CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) +# Always make sure the inputs to the unet are in the correct dtype CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) @@ -150,5 +151,6 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) +# Always make sure the timestep calculation is in the correct dtype CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 9cec4f13d..928beb57e 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -54,14 +54,19 @@ def is_using_v_parameterization_for_sd2(state_dict): unet.eval() with torch.no_grad(): + unet_dtype = torch.float + original_unet_dtype = devices.dtype_unet + unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." 
in k} unet.load_state_dict(unet_sd, strict=True) - unet.to(device=device, dtype=torch.float) + unet.to(device=device, dtype=unet_dtype) + devices.dtype_unet = unet_dtype test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() + devices.dtype_unet = original_unet_dtype return out < -1