diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 016a33d10..58c5e5d5b 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -41,7 +41,7 @@ parser.add_argument("--lowvram", action='store_true', help="enable stable diffus parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") -parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast") parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) diff --git a/modules/devices.py b/modules/devices.py index e4f671ac6..7de34ac51 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -114,6 +114,9 @@ errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") fp8: bool = False +# Force fp16 for all models in inference. No casting during inference. +# This flag is controlled by "--precision half" command line arg. +force_fp16: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -127,6 +130,8 @@ unet_needs_upcast = False def cond_cast_unet(input): + if force_fp16: + return input.to(torch.float16) return input.to(dtype_unet) if unet_needs_upcast else input @@ -206,6 +211,11 @@ def autocast(disable=False): if disable: return contextlib.nullcontext() + if force_fp16: + # No casting during inference if force_fp16 is enabled. + # All tensor dtype conversion happens before inference. + return contextlib.nullcontext() + if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) @@ -269,3 +279,17 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) + + +def force_model_fp16(): + """ + ldm and sgm has modules.diffusionmodules.util.GroupNorm32.forward, which + force conversion of input to float32. If force_fp16 is enabled, we need to + prevent this casting. + """ + assert force_fp16 + import sgm.modules.diffusionmodules.util as sgm_util + import ldm.modules.diffusionmodules.util as ldm_util + sgm_util.GroupNorm32 = torch.nn.GroupNorm + ldm_util.GroupNorm32 = torch.nn.GroupNorm + print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.") diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 2101f1a04..41955313a 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -36,7 +36,7 @@ th = TorchHijackForUnet() # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): - + """Always make sure inputs to unet are in correct dtype.""" if isinstance(cond, dict): for y in cond.keys(): if isinstance(cond[y], list): @@ -45,7 +45,11 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] with devices.autocast(): - return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) + if devices.unet_needs_upcast: + return result.float() + else: + return result class GELUHijack(torch.nn.GELU, torch.nn.Module): @@ -64,12 +68,11 @@ def hijack_ddpm_edit(): if not ddpm_edit_hijack: CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) - ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) + ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model) unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) + if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) @@ -81,5 +84,17 @@ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_s CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) -CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast) -CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) +CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) + + +def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): + if devices.unet_needs_upcast and timesteps.dtype == torch.int64: + dtype = torch.float32 + else: + dtype = devices.dtype_unet + return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) + + +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) +CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index 79bf6e468..546f2eda4 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -1,7 +1,11 @@ import importlib + +always_true_func = lambda *args, **kwargs: True + + class CondFunc: - def __new__(cls, orig_func, sub_func, cond_func): + def __new__(cls, orig_func, sub_func, cond_func=always_true_func): self = super(CondFunc, cls).__new__(cls) if isinstance(orig_func, str): func_path = orig_func.split('.') @@ -20,13 +24,13 @@ class CondFunc: print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack") pass self.__init__(orig_func, sub_func, cond_func) - return lambda *args, **kwargs: self(*args, **kwargs) - def __init__(self, orig_func, sub_func, cond_func): - self.__orig_func = orig_func - self.__sub_func = sub_func - self.__cond_func = cond_func - def __call__(self, *args, **kwargs): - if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): - return self.__sub_func(self.__orig_func, *args, **kwargs) - else: - return self.__orig_func(*args, **kwargs) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index ff245b7a6..9c5909168 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -403,6 +403,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.float() model.alphas_cumprod_original = model.alphas_cumprod devices.dtype_unet = torch.float32 + assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half" timer.record("apply float()") else: vae = model.first_stage_model diff --git a/modules/shared_init.py b/modules/shared_init.py index 935e3a21c..a6ad0433d 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -31,6 +31,14 @@ def initialize(): devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype + if cmd_opts.precision == "half": + msg = "--no-half and --no-half-vae conflict with --precision half" + assert devices.dtype == torch.float16, msg + assert devices.dtype_vae == torch.float16, msg + assert devices.dtype_inference == torch.float16, msg + devices.force_fp16 = True + devices.force_model_fp16() + shared.device = devices.device shared.weight_load_location = None if cmd_opts.lowram else "cpu"