Merge pull request #14031 from AUTOMATIC1111/test-fp8

A big improvement for dtype casting system with fp8 storage type and manual cast
This commit is contained in:
AUTOMATIC1111 2023-12-16 10:22:51 +03:00 committed by GitHub
commit c121f8c315
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 160 additions and 40 deletions

View File

@ -137,7 +137,7 @@ class NetworkModule:
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:

View File

@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.weight.shape
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
updown = self.weight.to(orig_weight.device)
if self.ex_bias is not None:
ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
ex_bias = self.ex_bias.to(orig_weight.device)
else:
ex_bias = None

View File

@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
self.w2b = weights.w["b2.weight"]
def calc_updown(self, orig_weight):
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
w1a = self.w1a.to(orig_weight.device)
w1b = self.w1b.to(orig_weight.device)
w2a = self.w2a.to(orig_weight.device)
w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
return self.finalize_updown(updown, orig_weight, output_shape)

View File

@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
self.t2 = weights.w.get("hada_t2")
def calc_updown(self, orig_weight):
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
w1a = self.w1a.to(orig_weight.device)
w1b = self.w1b.to(orig_weight.device)
w2a = self.w2a.to(orig_weight.device)
w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
if self.t1 is not None:
output_shape = [w1a.size(1), w1b.size(1)]
t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
t1 = self.t1.to(orig_weight.device)
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
output_shape += t1.shape[2:]
else:
@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
if self.t2 is not None:
t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
t2 = self.t2.to(orig_weight.device)
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
else:
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)

View File

@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
self.on_input = weights.w["on_input"].item()
def calc_updown(self, orig_weight):
w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
w = self.w.to(orig_weight.device)
output_shape = [w.size(0), orig_weight.size(1)]
if self.on_input:

View File

@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
def calc_updown(self, orig_weight):
if self.w1 is not None:
w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
w1 = self.w1.to(orig_weight.device)
else:
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
w1a = self.w1a.to(orig_weight.device)
w1b = self.w1b.to(orig_weight.device)
w1 = w1a @ w1b
if self.w2 is not None:
w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
w2 = self.w2.to(orig_weight.device)
elif self.t2 is None:
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
w2a = self.w2a.to(orig_weight.device)
w2b = self.w2b.to(orig_weight.device)
w2 = w2a @ w2b
else:
t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
t2 = self.t2.to(orig_weight.device)
w2a = self.w2a.to(orig_weight.device)
w2b = self.w2b.to(orig_weight.device)
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]

View File

@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
return module
def calc_updown(self, orig_weight):
up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
up = self.up_model.weight.to(orig_weight.device)
down = self.down_model.weight.to(orig_weight.device)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
mid = self.mid_model.weight.to(orig_weight.device)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:

View File

@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.w_norm.shape
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
updown = self.w_norm.to(orig_weight.device)
if self.b_norm is not None:
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
ex_bias = self.b_norm.to(orig_weight.device)
else:
ex_bias = None

View File

@ -389,18 +389,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
if module is not None and hasattr(self, 'weight'):
try:
with torch.no_grad():
updown, ex_bias = module.calc_updown(self.weight)
if getattr(self, 'fp16_weight', None) is None:
weight = self.weight
bias = self.bias
else:
weight = self.fp16_weight.clone().to(self.weight.device)
bias = getattr(self, 'fp16_bias', None)
if bias is not None:
bias = bias.clone().to(self.bias.device)
updown, ex_bias = module.calc_updown(weight)
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
if len(weight.shape) == 4 and weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
self.weight += updown
self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
self.bias = torch.nn.Parameter(ex_bias)
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
self.bias += ex_bias
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

View File

@ -23,6 +23,23 @@ def has_mps() -> bool:
return mac_specific.has_mps
def cuda_no_autocast(device_id=None) -> bool:
if device_id is None:
device_id = get_cuda_device_id()
return (
torch.cuda.get_device_capability(device_id) == (7, 5)
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
)
def get_cuda_device_id():
return (
int(shared.cmd_opts.device_id)
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
else 0
) or torch.cuda.current_device()
def get_cuda_device_string():
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@ -73,8 +90,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
if cuda_no_autocast():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@ -84,6 +100,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
fp8: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@ -104,12 +121,51 @@ def cond_cast_float(input):
nv_rng = None
patch_module_list = [
torch.nn.Linear,
torch.nn.Conv2d,
torch.nn.MultiheadAttention,
torch.nn.GroupNorm,
torch.nn.LayerNorm,
]
def manual_cast_forward(self, *args, **kwargs):
org_dtype = next(self.parameters()).dtype
self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
result = self.org_forward(*args, **kwargs)
self.to(org_dtype)
return result
@contextlib.contextmanager
def manual_cast():
for module_type in patch_module_list:
org_forward = module_type.forward
module_type.forward = manual_cast_forward
module_type.org_forward = org_forward
try:
yield None
finally:
for module_type in patch_module_list:
module_type.forward = module_type.org_forward
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
if fp8 and device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
return manual_cast()
if has_mps() and shared.cmd_opts.precision != "full":
return manual_cast()
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()

View File

@ -314,6 +314,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "VAE Decoder" not in res:
res["VAE Decoder"] = "Full"
if "FP8 weight" not in res:
res["FP8 weight"] = "Disable"
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
res["Cache FP16 weight for LoRA"] = False
skip = set(shared.opts.infotext_skip_pasting)
res = {k: v for k, v in res.items() if k not in skip}

View File

@ -177,6 +177,8 @@ def configure_opts_onchange():
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
startup_timer.record("opts onchange")

View File

@ -688,6 +688,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}",
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
"FP8 weight": opts.fp8_storage if devices.fp8 else None,
"Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
"VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
"VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),

View File

@ -348,10 +348,28 @@ class SkipWritingToConfig:
SkipWritingToConfig.skip = self.previous
def check_fp8(model):
if model is None:
return None
if devices.get_optimal_device_name() == "mps":
enable_fp8 = False
elif shared.opts.fp8_storage == "Enable":
enable_fp8 = True
elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
enable_fp8 = True
else:
enable_fp8 = False
return enable_fp8
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
if devices.fp8:
# prevent model to load state dict in fp8
model.half()
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@ -404,6 +422,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
for module in model.modules():
if hasattr(module, 'fp16_weight'):
del module.fp16_weight
if hasattr(module, 'fp16_bias'):
del module.fp16_bias
if check_fp8(model):
devices.fp8 = True
first_stage = model.first_stage_model
model.first_stage_model = None
for module in model.modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
if shared.opts.cache_fp16_weight:
module.fp16_weight = module.weight.data.clone().cpu().half()
if module.bias is not None:
module.fp16_bias = module.bias.data.clone().cpu().half()
module.to(torch.float8_e4m3fn)
model.first_stage_model = first_stage
timer.record("apply fp8")
else:
devices.fp8 = False
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@ -746,7 +786,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
return None
def reload_model_weights(sd_model=None, info=None):
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
checkpoint_info = info or select_checkpoint()
timer = Timer()
@ -758,11 +798,14 @@ def reload_model_weights(sd_model=None, info=None):
current_checkpoint_info = None
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
if check_fp8(sd_model) != devices.fp8:
# load from state dict again to prevent extra numerical errors
forced_reload = True
elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
return sd_model
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
return sd_model
if sd_model is not None:

View File

@ -93,7 +93,7 @@ def extend_sdxl(model):
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
model.conditioner.wrapped = torch.nn.Module()

View File

@ -206,6 +206,8 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
"cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
}))
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {

View File

@ -270,6 +270,7 @@ axis_options = [
AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
]