From 86221269f98ef9b21a6e6c9d04b86e2fb5cb33d3 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 16 Aug 2023 09:55:35 +0300
Subject: [PATCH 1/5] RAM optimization round 2

---
 extensions-builtin/Lora/networks.py  |  5 ++-
 modules/sd_disable_initialization.py | 51 ++++++++++++++++++++++++----
 2 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 9fca36b6a..96f935b23 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -304,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
 
     weights_backup = getattr(self, "network_weights_backup", None)
-    if weights_backup is None:
+    if weights_backup is None and wanted_names != ():
+        if current_names != ():
+            raise RuntimeError("no backup weights found and current weights have already been changed")
+
         if isinstance(self, torch.nn.MultiheadAttention):
             weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
         else:
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 695c57362..719eeb93f 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -168,22 +168,59 @@ class LoadStateDictOnMeta(ReplaceHelper):
         device = self.device
 
         def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
-            params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
+            used_param_keys = []
+
+            for name, param in self._parameters.items():
+                if param is None:
+                    continue
+
+                key = prefix + name
+                sd_param = sd.pop(key, None)
+                if sd_param is not None:
+                    state_dict[key] = sd_param
+                    used_param_keys.append(key)
 
-            for name, param in params:
                 if param.is_meta:
-                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
+                    dtype = sd_param.dtype if sd_param is not None else param.dtype
+                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+
+            for name in self._buffers:
+                key = prefix + name
+
+                sd_param = sd.pop(key, None)
+                if sd_param is not None:
+                    state_dict[key] = sd_param
+                    used_param_keys.append(key)
 
             original(self, state_dict, prefix, *args, **kwargs)
 
-            for name, _ in params:
-                key = prefix + name
-                if key in sd:
-                    del sd[key]
+            for key in used_param_keys:
+                state_dict.pop(key, None)
 
+        def load_state_dict(original, self, state_dict, strict=True):
+            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
+            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
+            all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
+
+            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
+
+            The dangerous thing about this is that if _load_from_state_dict is not called (because some exotic module
+            overloads it and does not call the original), the state dict will simply fail to load, since the weights
+            would be on the meta device.
+            """
+
+            if state_dict == sd:
+                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
+
+            original(self, state_dict, strict=strict)
+
+        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
         module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
         linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
         conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
         mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+        layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
+        group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.restore()
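The heart of PATCH 1/5 is the meta-device trick described in the load_state_dict docstring: give torch a state dict whose tensors carry no data, and fill each parameter from the one authoritative dict (sd) inside _load_from_state_dict. Below is a minimal, self-contained sketch of that technique. It is illustrative only, not code from the patch: the layer, keys, and target device are made up, and it assumes a PyTorch version (1.10+) whose modules accept a device= keyword.

    import torch

    # Create a module whose parameters live on the meta device: shapes and
    # dtypes are tracked, but no memory is allocated for the values.
    layer = torch.nn.Linear(4, 4, device="meta")
    assert layer.weight.is_meta

    # All real weights live in exactly one dict, analogous to `sd` above.
    sd = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

    # Materialize each meta parameter on the target device, filling it from
    # the single dict and popping entries so no duplicate copy stays alive.
    device = torch.device("cpu")
    for name, param in list(layer.named_parameters()):
        value = sd.pop(name)
        setattr(layer, name, torch.nn.Parameter(value.to(device), requires_grad=param.requires_grad))

    print(layer.weight.is_meta)  # False: the weight is now a real CPU tensor

Because each weight exists only in sd and is consumed as it is loaded, peak RAM stays near one full copy of the checkpoint no matter how many copies of the (meta) state dict torch makes internally.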
From 0815c45bcdec0a2e5c60bdd5b33d95813d799c01 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 16 Aug 2023 10:44:17 +0300
Subject: [PATCH 2/5] send weights to target device instead of CPU memory

---
 modules/sd_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index f6fbdcd60..b01d44c5d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -579,7 +579,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     timer.record("create model")
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.device):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
     timer.record("load weights from state dict")
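In isolation, PATCH 2/5 means the tensors that replace the meta parameters are allocated directly on the compute device, so system RAM never holds a staging copy of the model. A rough standalone sketch of the two allocation paths (not patch code; assumes a CUDA device is available):

    import torch

    checkpoint_weight = torch.randn(1024, 1024)  # weight as read from the checkpoint file

    # Old path: materialize in CPU RAM first, move to the GPU in a later pass,
    # briefly holding an extra full copy in system memory.
    staged = torch.zeros_like(checkpoint_weight, device="cpu")
    staged.copy_(checkpoint_weight)
    on_gpu = staged.to("cuda")

    # New path: materialize straight on the target device; system RAM never
    # needs a second copy of the weight.
    direct = torch.zeros_like(checkpoint_weight, device="cuda")
    direct.copy_(checkpoint_weight)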
From 57e59c14c8a13a99d6422597d27d92ad10a51ca1 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 16 Aug 2023 11:28:00 +0300
Subject: [PATCH 3/5] Revert "send weights to target device instead of CPU
 memory"

This reverts commit 0815c45bcdec0a2e5c60bdd5b33d95813d799c01.
---
 modules/sd_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index b01d44c5d..f6fbdcd60 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -579,7 +579,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     timer.record("create model")
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.device):
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
     timer.record("load weights from state dict")

From eaba3d7349c6f0e151be66ade3fdc848d693a10d Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 16 Aug 2023 12:11:01 +0300
Subject: [PATCH 4/5] send weights to target device instead of CPU memory

---
 modules/sd_disable_initialization.py | 24 +++++++++++++++---------
 modules/sd_models.py                 | 17 ++++++++++++++++-
 2 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 719eeb93f..8863107ae 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
     ```
     """
 
-    def __init__(self, state_dict, device):
+    def __init__(self, state_dict, device, weight_dtype_conversion=None):
         super().__init__()
         self.state_dict = state_dict
         self.device = device
+        self.weight_dtype_conversion = weight_dtype_conversion or {}
+        self.default_dtype = self.weight_dtype_conversion.get('')
+
+    def get_weight_dtype(self, key):
+        key_first_term = key.split('.', 1)[0]
+        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
 
     def __enter__(self):
         if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -167,24 +173,24 @@ class LoadStateDictOnMeta(ReplaceHelper):
         sd = self.state_dict
         device = self.device
 
-        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
             used_param_keys = []
 
-            for name, param in self._parameters.items():
+            for name, param in module._parameters.items():
                 if param is None:
                     continue
 
                 key = prefix + name
                 sd_param = sd.pop(key, None)
                 if sd_param is not None:
-                    state_dict[key] = sd_param
+                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                     used_param_keys.append(key)
 
                 if param.is_meta:
                     dtype = sd_param.dtype if sd_param is not None else param.dtype
-                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
 
-            for name in self._buffers:
+            for name in module._buffers:
                 key = prefix + name
 
                 sd_param = sd.pop(key, None)
@@ -192,12 +198,12 @@ class LoadStateDictOnMeta(ReplaceHelper):
                 state_dict[key] = sd_param
                 used_param_keys.append(key)
 
-            original(self, state_dict, prefix, *args, **kwargs)
+            original(module, state_dict, prefix, *args, **kwargs)
 
             for key in used_param_keys:
                 state_dict.pop(key, None)
 
-        def load_state_dict(original, self, state_dict, strict=True):
+        def load_state_dict(original, module, state_dict, strict=True):
             """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
             because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
             all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
@@ -212,7 +218,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
             if state_dict == sd:
                 state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
 
-            original(self, state_dict, strict=strict)
+            original(module, state_dict, strict=strict)
 
         module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
         module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f6fbdcd60..f912fe164 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -518,6 +518,13 @@ def send_model_to_cpu(m):
     devices.torch_gc()
 
 
+def model_target_device():
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        return devices.cpu
+    else:
+        return devices.device
+
+
 def send_model_to_device(m):
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
@@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     timer.record("create model")
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+    if shared.cmd_opts.no_half:
+        weight_dtype_conversion = None
+    else:
+        weight_dtype_conversion = {
+            'first_stage_model': None,
+            '': torch.float16,
+        }
+
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
     timer.record("load weights from state dict")
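How the weight_dtype_conversion table introduced in PATCH 4/5 is consulted: keys are matched by the first dot-separated segment of the state-dict key, with '' acting as the catch-all default. A standalone sketch of the lookup follows, using the same mapping load_model builds; the example key names are illustrative, not guaranteed to match any particular checkpoint.

    import torch

    weight_dtype_conversion = {
        'first_stage_model': None,   # leave VAE weights in their stored dtype
        '': torch.float16,           # default: convert everything else to fp16
    }
    default_dtype = weight_dtype_conversion.get('')

    def get_weight_dtype(key):
        # Match on the first dot-separated segment of the state-dict key.
        key_first_term = key.split('.', 1)[0]
        return weight_dtype_conversion.get(key_first_term, default_dtype)

    print(get_weight_dtype('first_stage_model.encoder.conv_in.weight'))  # None
    print(get_weight_dtype('model.diffusion_model.out.2.weight'))        # torch.float16

Since Tensor.to(dtype=None) returns the tensor unchanged, mapping 'first_stage_model' to None keeps the VAE weights in whatever dtype the checkpoint stores.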
From 0dc74545c0b5510911757ed9f2be703aab58f014 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Thu, 17 Aug 2023 07:54:07 +0300
Subject: [PATCH 5/5] resolve the issue with loading fp16 checkpoints while
 using --no-half

---
 modules/sd_models.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index f912fe164..685585b1c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -343,7 +343,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         model.to(memory_format=torch.channels_last)
         timer.record("apply channels_last")
 
-    if not shared.cmd_opts.no_half:
+    if shared.cmd_opts.no_half:
+        model.float()
+        timer.record("apply float()")
+    else:
         vae = model.first_stage_model
         depth_model = getattr(model, 'depth_model', None)
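Why PATCH 5/5 is needed: with the meta-device loader, each parameter is created with the dtype found in the checkpoint, so an fp16 checkpoint yields fp16 weights even when --no-half asks for a full-precision model; the new branch promotes them with model.float(). A minimal illustration on a toy module (not webui code):

    import torch

    # Parameters end up in fp16, as when the checkpoint stores fp16 weights.
    model = torch.nn.Linear(4, 4).half()
    print(model.weight.dtype)  # torch.float16

    model.float()              # what the --no-half branch now applies
    print(model.weight.dtype)  # torch.float32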