Merge pull request #15820 from huchenlei/force_half
[Performance 6/6] Add --precision half option to avoid casting during inference
This commit is contained in:
commit
33b73c473c
|
@ -41,7 +41,7 @@ parser.add_argument("--lowvram", action='store_true', help="enable stable diffus
|
||||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
|
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
|
||||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast")
|
||||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||||
|
|
|
@ -114,6 +114,9 @@ errors.run(enable_tf32, "Enabling TF32")
|
||||||
|
|
||||||
cpu: torch.device = torch.device("cpu")
|
cpu: torch.device = torch.device("cpu")
|
||||||
fp8: bool = False
|
fp8: bool = False
|
||||||
|
# Force fp16 for all models in inference. No casting during inference.
|
||||||
|
# This flag is controlled by "--precision half" command line arg.
|
||||||
|
force_fp16: bool = False
|
||||||
device: torch.device = None
|
device: torch.device = None
|
||||||
device_interrogate: torch.device = None
|
device_interrogate: torch.device = None
|
||||||
device_gfpgan: torch.device = None
|
device_gfpgan: torch.device = None
|
||||||
|
@ -127,6 +130,8 @@ unet_needs_upcast = False
|
||||||
|
|
||||||
|
|
||||||
def cond_cast_unet(input):
    """Cast a tensor to the dtype the UNet expects, when a cast is required.

    With ``force_fp16`` set the tensor is always converted to ``torch.float16``.
    Otherwise it is converted to ``dtype_unet`` only while ``unet_needs_upcast``
    is truthy, and handed back unchanged in every other case.
    """
    if force_fp16:
        target_dtype = torch.float16
    elif unet_needs_upcast:
        target_dtype = dtype_unet
    else:
        # Neither mode is active: return the very same object, no conversion.
        return input
    return input.to(target_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@ -206,6 +211,11 @@ def autocast(disable=False):
|
||||||
if disable:
|
if disable:
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
if force_fp16:
|
||||||
|
# No casting during inference if force_fp16 is enabled.
|
||||||
|
# All tensor dtype conversion happens before inference.
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
if fp8 and device==cpu:
|
if fp8 and device==cpu:
|
||||||
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
||||||
|
|
||||||
|
@ -269,3 +279,17 @@ def first_time_calculation():
|
||||||
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
||||||
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
||||||
conv2d(x)
|
conv2d(x)
|
||||||
|
|
||||||
|
|
||||||
|
def force_model_fp16():
    """
    Replace the GroupNorm32 classes of ldm and sgm with torch.nn.GroupNorm.

    ldm and sgm both have modules.diffusionmodules.util.GroupNorm32.forward, which
    forces conversion of the input to float32. If force_fp16 is enabled, we need to
    prevent this casting.

    Raises:
        AssertionError: if called while force_fp16 is not enabled.
    """
    # Raised explicitly instead of using an `assert` statement so that the
    # precondition is still enforced when Python runs with -O, which strips
    # assert statements. The exception type is unchanged.
    if not force_fp16:
        raise AssertionError("force_model_fp16() requires force_fp16 to be enabled")

    # Function-scope imports, as in the original: the modules are only loaded
    # when the patch is actually applied.
    import sgm.modules.diffusionmodules.util as sgm_util
    import ldm.modules.diffusionmodules.util as ldm_util

    # Rebind the module-level name; code that looks up GroupNorm32 through
    # these modules afterwards gets the plain torch.nn.GroupNorm class.
    sgm_util.GroupNorm32 = torch.nn.GroupNorm
    ldm_util.GroupNorm32 = torch.nn.GroupNorm
    print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.")
|
||||||
|
|
|
@ -36,7 +36,7 @@ th = TorchHijackForUnet()
|
||||||
|
|
||||||
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||||
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||||
|
"""Always make sure inputs to unet are in correct dtype."""
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
for y in cond.keys():
|
for y in cond.keys():
|
||||||
if isinstance(cond[y], list):
|
if isinstance(cond[y], list):
|
||||||
|
@ -45,7 +45,11 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||||
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
||||||
|
if devices.unet_needs_upcast:
|
||||||
|
return result.float()
|
||||||
|
else:
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||||
|
@ -64,12 +68,11 @@ def hijack_ddpm_edit():
|
||||||
if not ddpm_edit_hijack:
|
if not ddpm_edit_hijack:
|
||||||
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
|
||||||
|
|
||||||
|
|
||||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
|
||||||
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||||
|
@ -81,5 +84,17 @@ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_s
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||||
|
|
||||||
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
||||||
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
||||||
|
|
||||||
|
|
||||||
|
def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
    """Call the wrapped timestep-embedding function and cast what it returns.

    The result is converted to ``torch.float32`` when ``devices.unet_needs_upcast``
    is truthy and ``timesteps`` has dtype ``torch.int64``; in every other case it
    is converted to ``devices.dtype_unet``.

    ``orig_func`` is the function being wrapped; ``timesteps`` and any extra
    positional/keyword arguments are forwarded to it untouched.
    """
    keep_float32 = devices.unet_needs_upcast and timesteps.dtype == torch.int64
    target_dtype = torch.float32 if keep_float32 else devices.dtype_unet
    embedding = orig_func(timesteps, *args, **kwargs)
    return embedding.to(dtype=target_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||||
|
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
|
|
||||||
|
def always_true_func(*args, **kwargs):
    """Return True no matter which arguments are passed.

    Used as the default ``cond_func`` of ``CondFunc``, so a substitution
    registered without an explicit condition is always applied.
    """
    return True
|
||||||
|
|
||||||
|
|
||||||
class CondFunc:
|
class CondFunc:
|
||||||
def __new__(cls, orig_func, sub_func, cond_func):
|
def __new__(cls, orig_func, sub_func, cond_func=always_true_func):
|
||||||
self = super(CondFunc, cls).__new__(cls)
|
self = super(CondFunc, cls).__new__(cls)
|
||||||
if isinstance(orig_func, str):
|
if isinstance(orig_func, str):
|
||||||
func_path = orig_func.split('.')
|
func_path = orig_func.split('.')
|
||||||
|
|
|
@ -403,6 +403,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||||
model.float()
|
model.float()
|
||||||
model.alphas_cumprod_original = model.alphas_cumprod
|
model.alphas_cumprod_original = model.alphas_cumprod
|
||||||
devices.dtype_unet = torch.float32
|
devices.dtype_unet = torch.float32
|
||||||
|
assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half"
|
||||||
timer.record("apply float()")
|
timer.record("apply float()")
|
||||||
else:
|
else:
|
||||||
vae = model.first_stage_model
|
vae = model.first_stage_model
|
||||||
|
@ -540,7 +541,7 @@ def repair_config(sd_config):
|
||||||
if hasattr(sd_config.model.params, 'unet_config'):
|
if hasattr(sd_config.model.params, 'unet_config'):
|
||||||
if shared.cmd_opts.no_half:
|
if shared.cmd_opts.no_half:
|
||||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||||
elif shared.cmd_opts.upcast_sampling:
|
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
|
||||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||||
|
|
||||||
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
|
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
|
||||||
|
|
|
@ -31,6 +31,14 @@ def initialize():
|
||||||
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
||||||
devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
|
devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
|
||||||
|
|
||||||
|
if cmd_opts.precision == "half":
|
||||||
|
msg = "--no-half and --no-half-vae conflict with --precision half"
|
||||||
|
assert devices.dtype == torch.float16, msg
|
||||||
|
assert devices.dtype_vae == torch.float16, msg
|
||||||
|
assert devices.dtype_inference == torch.float16, msg
|
||||||
|
devices.force_fp16 = True
|
||||||
|
devices.force_model_fp16()
|
||||||
|
|
||||||
shared.device = devices.device
|
shared.device = devices.device
|
||||||
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
|
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue