From ffead92d4e36a5082fa6ac5dd54c88477c9b524e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 6 Jul 2024 10:40:48 +0300 Subject: [PATCH] Revert "Merge pull request #16078 from huchenlei/fix_sd2" This reverts commit 4cc3add770b10cb8e8f7aa980c0d50e5b637ab2b, reversing changes made to 50514ce414ee4fad9aa4780ef0b97116c7d7c970. --- modules/sd_hijack_unet.py | 2 -- modules/sd_models_config.py | 7 +------ 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 6d6575119..b4f03b138 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -138,7 +138,6 @@ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_s CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) -# Always make sure inputs to unet are in correct dtype CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) @@ -151,6 +150,5 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) -# Always make sure timestep calculation is in correct dtype CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index e9a80ebaf..7cfeca67f 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -56,19 +56,14 @@ def is_using_v_parameterization_for_sd2(state_dict): unet.eval() with torch.no_grad(): - unet_dtype = torch.float - original_unet_dtype = devices.dtype_unet - unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k} unet.load_state_dict(unet_sd, strict=True) - unet.to(device=device, dtype=unet_dtype) - devices.dtype_unet = unet_dtype + unet.to(device=device, dtype=torch.float) test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() - devices.dtype_unet = original_unet_dtype return out < -1