From 2a8a60c2c50473f0ece5804d4a2cde0d1ff3d35e Mon Sep 17 00:00:00 2001 From: huchenlei Date: Thu, 16 May 2024 19:50:06 -0400 Subject: [PATCH 1/3] Add --precision half cmd option --- modules/cmd_args.py | 2 +- modules/devices.py | 24 ++++++++++++++++++++++++ modules/sd_hijack_unet.py | 29 ++++++++++++++++++++++------- modules/sd_hijack_utils.py | 26 +++++++++++++++----------- modules/sd_models.py | 1 + modules/shared_init.py | 8 ++++++++ 6 files changed, 71 insertions(+), 19 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 016a33d10..58c5e5d5b 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -41,7 +41,7 @@ parser.add_argument("--lowvram", action='store_true', help="enable stable diffus parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") -parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast") parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) diff --git a/modules/devices.py b/modules/devices.py index e4f671ac6..7de34ac51 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -114,6 +114,9 @@ errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") fp8: bool = False +# Force fp16 for all models in inference. No casting during inference. +# This flag is controlled by "--precision half" command line arg. +force_fp16: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -127,6 +130,8 @@ unet_needs_upcast = False def cond_cast_unet(input): + if force_fp16: + return input.to(torch.float16) return input.to(dtype_unet) if unet_needs_upcast else input @@ -206,6 +211,11 @@ def autocast(disable=False): if disable: return contextlib.nullcontext() + if force_fp16: + # No casting during inference if force_fp16 is enabled. + # All tensor dtype conversion happens before inference. + return contextlib.nullcontext() + if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) @@ -269,3 +279,17 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) + + +def force_model_fp16(): + """ + ldm and sgm has modules.diffusionmodules.util.GroupNorm32.forward, which + force conversion of input to float32. If force_fp16 is enabled, we need to + prevent this casting. + """ + assert force_fp16 + import sgm.modules.diffusionmodules.util as sgm_util + import ldm.modules.diffusionmodules.util as ldm_util + sgm_util.GroupNorm32 = torch.nn.GroupNorm + ldm_util.GroupNorm32 = torch.nn.GroupNorm + print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.") diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 2101f1a04..41955313a 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -36,7 +36,7 @@ th = TorchHijackForUnet() # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): - + """Always make sure inputs to unet are in correct dtype.""" if isinstance(cond, dict): for y in cond.keys(): if isinstance(cond[y], list): @@ -45,7 +45,11 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] with devices.autocast(): - return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) + if devices.unet_needs_upcast: + return result.float() + else: + return result class GELUHijack(torch.nn.GELU, torch.nn.Module): @@ -64,12 +68,11 @@ def hijack_ddpm_edit(): if not ddpm_edit_hijack: CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) - ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) + ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model) unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) + if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) @@ -81,5 +84,17 @@ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_s CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) -CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast) -CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) +CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) + + +def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): + if devices.unet_needs_upcast and timesteps.dtype == torch.int64: + dtype = torch.float32 + else: + dtype = devices.dtype_unet + return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) + + +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) +CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index 79bf6e468..546f2eda4 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -1,7 +1,11 @@ import importlib + +always_true_func = lambda *args, **kwargs: True + + class CondFunc: - def __new__(cls, orig_func, sub_func, cond_func): + def __new__(cls, orig_func, sub_func, cond_func=always_true_func): self = super(CondFunc, cls).__new__(cls) if isinstance(orig_func, str): func_path = orig_func.split('.') @@ -20,13 +24,13 @@ class CondFunc: print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack") pass self.__init__(orig_func, sub_func, cond_func) - return lambda *args, **kwargs: self(*args, **kwargs) - def __init__(self, orig_func, sub_func, cond_func): - self.__orig_func = orig_func - self.__sub_func = sub_func - self.__cond_func = cond_func - def __call__(self, *args, **kwargs): - if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): - return self.__sub_func(self.__orig_func, *args, **kwargs) - else: - return self.__orig_func(*args, **kwargs) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index ff245b7a6..9c5909168 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -403,6 +403,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.float() model.alphas_cumprod_original = model.alphas_cumprod devices.dtype_unet = torch.float32 + assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half" timer.record("apply float()") else: vae = model.first_stage_model diff --git a/modules/shared_init.py b/modules/shared_init.py index 935e3a21c..a6ad0433d 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -31,6 +31,14 @@ def initialize(): devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype + if cmd_opts.precision == "half": + msg = "--no-half and --no-half-vae conflict with --precision half" + assert devices.dtype == torch.float16, msg + assert devices.dtype_vae == torch.float16, msg + assert devices.dtype_inference == torch.float16, msg + devices.force_fp16 = True + devices.force_model_fp16() + shared.device = devices.device shared.weight_load_location = None if cmd_opts.lowram else "cpu" From dca9007ac7a9852752d91d34d2ed1feaef6a03f2 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Fri, 17 May 2024 13:23:12 -0400 Subject: [PATCH 2/3] Fix SD15 dtype --- modules/sd_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 9c5909168..7d4ab0fd8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -733,6 +733,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = instantiate_from_config(sd_config.model) sd_model.used_config = checkpoint_config + # ldm's Unet is using self.dtype to cast input tensor. If we do not overwrite + # UnetModel.dtype, it will be the default dtype from config. + # sgm's Unet is not using dtype for casting. The value will be ignored. + sd_model.model.diffusion_model.dtype = devices.dtype_unet timer.record("create model") From b57a70f37322142939f7429f287599e027108bfc Mon Sep 17 00:00:00 2001 From: huchenlei Date: Fri, 17 May 2024 13:34:04 -0400 Subject: [PATCH 3/3] Proper fix of SD15 dtype --- modules/sd_models.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 7d4ab0fd8..26a5127cd 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -541,7 +541,7 @@ def repair_config(sd_config): if hasattr(sd_config.model.params, 'unet_config'): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: + elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half": sd_config.model.params.unet_config.params.use_fp16 = True if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: @@ -733,10 +733,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = instantiate_from_config(sd_config.model) sd_model.used_config = checkpoint_config - # ldm's Unet is using self.dtype to cast input tensor. If we do not overwrite - # UnetModel.dtype, it will be the default dtype from config. - # sgm's Unet is not using dtype for casting. The value will be ignored. - sd_model.model.diffusion_model.dtype = devices.dtype_unet timer.record("create model")