From eaba3d7349c6f0e151be66ade3fdc848d693a10d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 16 Aug 2023 12:11:01 +0300 Subject: [PATCH] send weights to target device instead of CPU memory --- modules/sd_disable_initialization.py | 24 +++++++++++++++--------- modules/sd_models.py | 17 ++++++++++++++++- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 719eeb93f..8863107ae 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper): ``` """ - def __init__(self, state_dict, device): + def __init__(self, state_dict, device, weight_dtype_conversion=None): super().__init__() self.state_dict = state_dict self.device = device + self.weight_dtype_conversion = weight_dtype_conversion or {} + self.default_dtype = self.weight_dtype_conversion.get('') + + def get_weight_dtype(self, key): + key_first_term, _ = key.split('.', 1) + return self.weight_dtype_conversion.get(key_first_term, self.default_dtype) def __enter__(self): if shared.cmd_opts.disable_model_loading_ram_optimization: @@ -167,24 +173,24 @@ class LoadStateDictOnMeta(ReplaceHelper): sd = self.state_dict device = self.device - def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs): + def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): used_param_keys = [] - for name, param in self._parameters.items(): + for name, param in module._parameters.items(): if param is None: continue key = prefix + name sd_param = sd.pop(key, None) if sd_param is not None: - state_dict[key] = sd_param + state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) used_param_keys.append(key) if param.is_meta: dtype = sd_param.dtype if sd_param is not None else param.dtype - self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) + module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) - for name in self._buffers: + for name in module._buffers: key = prefix + name sd_param = sd.pop(key, None) @@ -192,12 +198,12 @@ class LoadStateDictOnMeta(ReplaceHelper): state_dict[key] = sd_param used_param_keys.append(key) - original(self, state_dict, prefix, *args, **kwargs) + original(module, state_dict, prefix, *args, **kwargs) for key in used_param_keys: state_dict.pop(key, None) - def load_state_dict(original, self, state_dict, strict=True): + def load_state_dict(original, module, state_dict, strict=True): """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes. @@ -212,7 +218,7 @@ class LoadStateDictOnMeta(ReplaceHelper): if state_dict == sd: state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} - original(self, state_dict, strict=strict) + original(module, state_dict, strict=strict) module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs)) module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs)) diff --git a/modules/sd_models.py b/modules/sd_models.py index f6fbdcd60..f912fe164 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -518,6 +518,13 @@ def send_model_to_cpu(m): devices.torch_gc() +def model_target_device(): + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + return devices.cpu + else: + return devices.device + + def send_model_to_device(m): if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) @@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("create model") - with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): + if shared.cmd_opts.no_half: + weight_dtype_conversion = None + else: + weight_dtype_conversion = { + 'first_stage_model': None, + '': torch.float16, + } + + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion): load_model_weights(sd_model, checkpoint_info, state_dict, timer) timer.record("load weights from state dict")