Merge pull request #14031 from AUTOMATIC1111/test-fp8
A big improvement for dtype casting system with fp8 storage type and manual cast
This commit is contained in:
commit
c121f8c315
|
@ -137,7 +137,7 @@ class NetworkModule:
|
|||
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||
if self.bias is not None:
|
||||
updown = updown.reshape(self.bias.shape)
|
||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
|
||||
updown = updown.reshape(output_shape)
|
||||
|
||||
if len(output_shape) == 4:
|
||||
|
|
|
@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
|
|||
|
||||
def calc_updown(self, orig_weight):
|
||||
output_shape = self.weight.shape
|
||||
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown = self.weight.to(orig_weight.device)
|
||||
if self.ex_bias is not None:
|
||||
ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
ex_bias = self.ex_bias.to(orig_weight.device)
|
||||
else:
|
||||
ex_bias = None
|
||||
|
||||
|
|
|
@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
|
|||
self.w2b = weights.w["b2.weight"]
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1a = self.w1a.to(orig_weight.device)
|
||||
w1b = self.w1b.to(orig_weight.device)
|
||||
w2a = self.w2a.to(orig_weight.device)
|
||||
w2b = self.w2b.to(orig_weight.device)
|
||||
|
||||
output_shape = [w1a.size(0), w1b.size(1)]
|
||||
updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
|
||||
updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
|
||||
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
|
|
@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
|
|||
self.t2 = weights.w.get("hada_t2")
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1a = self.w1a.to(orig_weight.device)
|
||||
w1b = self.w1b.to(orig_weight.device)
|
||||
w2a = self.w2a.to(orig_weight.device)
|
||||
w2b = self.w2b.to(orig_weight.device)
|
||||
|
||||
output_shape = [w1a.size(0), w1b.size(1)]
|
||||
|
||||
if self.t1 is not None:
|
||||
output_shape = [w1a.size(1), w1b.size(1)]
|
||||
t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
t1 = self.t1.to(orig_weight.device)
|
||||
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
|
||||
output_shape += t1.shape[2:]
|
||||
else:
|
||||
|
@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
|
|||
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
|
||||
|
||||
if self.t2 is not None:
|
||||
t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
t2 = self.t2.to(orig_weight.device)
|
||||
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
|
||||
else:
|
||||
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
|
||||
|
|
|
@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
|
|||
self.on_input = weights.w["on_input"].item()
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w = self.w.to(orig_weight.device)
|
||||
|
||||
output_shape = [w.size(0), orig_weight.size(1)]
|
||||
if self.on_input:
|
||||
|
|
|
@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
|
|||
|
||||
def calc_updown(self, orig_weight):
|
||||
if self.w1 is not None:
|
||||
w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1 = self.w1.to(orig_weight.device)
|
||||
else:
|
||||
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1a = self.w1a.to(orig_weight.device)
|
||||
w1b = self.w1b.to(orig_weight.device)
|
||||
w1 = w1a @ w1b
|
||||
|
||||
if self.w2 is not None:
|
||||
w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2 = self.w2.to(orig_weight.device)
|
||||
elif self.t2 is None:
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device)
|
||||
w2b = self.w2b.to(orig_weight.device)
|
||||
w2 = w2a @ w2b
|
||||
else:
|
||||
t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
t2 = self.t2.to(orig_weight.device)
|
||||
w2a = self.w2a.to(orig_weight.device)
|
||||
w2b = self.w2b.to(orig_weight.device)
|
||||
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
|
||||
|
||||
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
|
||||
|
|
|
@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
|
|||
return module
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
up = self.up_model.weight.to(orig_weight.device)
|
||||
down = self.down_model.weight.to(orig_weight.device)
|
||||
|
||||
output_shape = [up.size(0), down.size(1)]
|
||||
if self.mid_model is not None:
|
||||
# cp-decomposition
|
||||
mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
mid = self.mid_model.weight.to(orig_weight.device)
|
||||
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
|
||||
output_shape += mid.shape[2:]
|
||||
else:
|
||||
|
|
|
@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
|
|||
|
||||
def calc_updown(self, orig_weight):
|
||||
output_shape = self.w_norm.shape
|
||||
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown = self.w_norm.to(orig_weight.device)
|
||||
|
||||
if self.b_norm is not None:
|
||||
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
ex_bias = self.b_norm.to(orig_weight.device)
|
||||
else:
|
||||
ex_bias = None
|
||||
|
||||
|
|
|
@ -389,18 +389,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||
if module is not None and hasattr(self, 'weight'):
|
||||
try:
|
||||
with torch.no_grad():
|
||||
updown, ex_bias = module.calc_updown(self.weight)
|
||||
if getattr(self, 'fp16_weight', None) is None:
|
||||
weight = self.weight
|
||||
bias = self.bias
|
||||
else:
|
||||
weight = self.fp16_weight.clone().to(self.weight.device)
|
||||
bias = getattr(self, 'fp16_bias', None)
|
||||
if bias is not None:
|
||||
bias = bias.clone().to(self.bias.device)
|
||||
updown, ex_bias = module.calc_updown(weight)
|
||||
|
||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||
if len(weight.shape) == 4 and weight.shape[1] == 9:
|
||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||
|
||||
self.weight += updown
|
||||
self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
|
||||
if ex_bias is not None and hasattr(self, 'bias'):
|
||||
if self.bias is None:
|
||||
self.bias = torch.nn.Parameter(ex_bias)
|
||||
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
|
||||
else:
|
||||
self.bias += ex_bias
|
||||
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
|
||||
except RuntimeError as e:
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||
|
|
|
@ -23,6 +23,23 @@ def has_mps() -> bool:
|
|||
return mac_specific.has_mps
|
||||
|
||||
|
||||
def cuda_no_autocast(device_id=None) -> bool:
|
||||
if device_id is None:
|
||||
device_id = get_cuda_device_id()
|
||||
return (
|
||||
torch.cuda.get_device_capability(device_id) == (7, 5)
|
||||
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
|
||||
)
|
||||
|
||||
|
||||
def get_cuda_device_id():
|
||||
return (
|
||||
int(shared.cmd_opts.device_id)
|
||||
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
|
||||
else 0
|
||||
) or torch.cuda.current_device()
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"cuda:{shared.cmd_opts.device_id}"
|
||||
|
@ -73,8 +90,7 @@ def enable_tf32():
|
|||
|
||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||
device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
|
||||
if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
|
||||
if cuda_no_autocast():
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
@ -84,6 +100,7 @@ def enable_tf32():
|
|||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
cpu: torch.device = torch.device("cpu")
|
||||
fp8: bool = False
|
||||
device: torch.device = None
|
||||
device_interrogate: torch.device = None
|
||||
device_gfpgan: torch.device = None
|
||||
|
@ -104,12 +121,51 @@ def cond_cast_float(input):
|
|||
|
||||
|
||||
nv_rng = None
|
||||
patch_module_list = [
|
||||
torch.nn.Linear,
|
||||
torch.nn.Conv2d,
|
||||
torch.nn.MultiheadAttention,
|
||||
torch.nn.GroupNorm,
|
||||
torch.nn.LayerNorm,
|
||||
]
|
||||
|
||||
|
||||
def manual_cast_forward(self, *args, **kwargs):
|
||||
org_dtype = next(self.parameters()).dtype
|
||||
self.to(dtype)
|
||||
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
result = self.org_forward(*args, **kwargs)
|
||||
self.to(org_dtype)
|
||||
return result
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manual_cast():
|
||||
for module_type in patch_module_list:
|
||||
org_forward = module_type.forward
|
||||
module_type.forward = manual_cast_forward
|
||||
module_type.org_forward = org_forward
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
for module_type in patch_module_list:
|
||||
module_type.forward = module_type.org_forward
|
||||
|
||||
|
||||
def autocast(disable=False):
|
||||
if disable:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if fp8 and device==cpu:
|
||||
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
||||
|
||||
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
|
||||
return manual_cast()
|
||||
|
||||
if has_mps() and shared.cmd_opts.precision != "full":
|
||||
return manual_cast()
|
||||
|
||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
|
|
@ -314,6 +314,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
if "VAE Decoder" not in res:
|
||||
res["VAE Decoder"] = "Full"
|
||||
|
||||
if "FP8 weight" not in res:
|
||||
res["FP8 weight"] = "Disable"
|
||||
|
||||
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
|
||||
res["Cache FP16 weight for LoRA"] = False
|
||||
|
||||
skip = set(shared.opts.infotext_skip_pasting)
|
||||
res = {k: v for k, v in res.items() if k not in skip}
|
||||
|
||||
|
|
|
@ -177,6 +177,8 @@ def configure_opts_onchange():
|
|||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||
shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
||||
shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
|
||||
startup_timer.record("opts onchange")
|
||||
|
||||
|
||||
|
|
|
@ -688,6 +688,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
|
||||
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
|
||||
"FP8 weight": opts.fp8_storage if devices.fp8 else None,
|
||||
"Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
|
||||
"VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
|
||||
"VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
|
||||
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
||||
|
|
|
@ -348,10 +348,28 @@ class SkipWritingToConfig:
|
|||
SkipWritingToConfig.skip = self.previous
|
||||
|
||||
|
||||
def check_fp8(model):
|
||||
if model is None:
|
||||
return None
|
||||
if devices.get_optimal_device_name() == "mps":
|
||||
enable_fp8 = False
|
||||
elif shared.opts.fp8_storage == "Enable":
|
||||
enable_fp8 = True
|
||||
elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
|
||||
enable_fp8 = True
|
||||
else:
|
||||
enable_fp8 = False
|
||||
return enable_fp8
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if devices.fp8:
|
||||
# prevent model to load state dict in fp8
|
||||
model.half()
|
||||
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
|
@ -404,6 +422,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
devices.dtype_unet = torch.float16
|
||||
timer.record("apply half()")
|
||||
|
||||
for module in model.modules():
|
||||
if hasattr(module, 'fp16_weight'):
|
||||
del module.fp16_weight
|
||||
if hasattr(module, 'fp16_bias'):
|
||||
del module.fp16_bias
|
||||
|
||||
if check_fp8(model):
|
||||
devices.fp8 = True
|
||||
first_stage = model.first_stage_model
|
||||
model.first_stage_model = None
|
||||
for module in model.modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
|
||||
if shared.opts.cache_fp16_weight:
|
||||
module.fp16_weight = module.weight.data.clone().cpu().half()
|
||||
if module.bias is not None:
|
||||
module.fp16_bias = module.bias.data.clone().cpu().half()
|
||||
module.to(torch.float8_e4m3fn)
|
||||
model.first_stage_model = first_stage
|
||||
timer.record("apply fp8")
|
||||
else:
|
||||
devices.fp8 = False
|
||||
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
@ -746,7 +786,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
|||
return None
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None):
|
||||
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
timer = Timer()
|
||||
|
@ -758,11 +798,14 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
current_checkpoint_info = None
|
||||
else:
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
if check_fp8(sd_model) != devices.fp8:
|
||||
# load from state dict again to prevent extra numerical errors
|
||||
forced_reload = True
|
||||
elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
|
||||
return sd_model
|
||||
|
||||
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
return sd_model
|
||||
|
||||
if sd_model is not None:
|
||||
|
|
|
@ -93,7 +93,7 @@ def extend_sdxl(model):
|
|||
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
||||
|
||||
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
|
||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
|
||||
|
||||
model.conditioner.wrapped = torch.nn.Module()
|
||||
|
||||
|
|
|
@ -206,6 +206,8 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
|
|||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
||||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
|
||||
"cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
|
||||
|
|
|
@ -270,6 +270,7 @@ axis_options = [
|
|||
AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
|
||||
AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
|
||||
AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
|
||||
AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
|
||||
]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue