From 2abd89acc66419abf2eee9b03fd093f2737670de Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 28 Jan 2023 20:04:35 +0300
Subject: [PATCH 1/3] index on master: 91c8d0d Merge pull request #7231 from EllangoK/master

---
 extensions-builtin/Lora/lora.py                 | 21 +++++++++++++++++--
 .../Lora/scripts/lora_script.py                 |  5 +++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index cb8f1d36a..568a76755 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -12,7 +12,7 @@ re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)
 re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
 
 
-def convert_diffusers_name_to_compvis(key):
+def convert_diffusers_name_to_compvis(key, is_sd2):
     def match(match_list, regex):
         r = re.match(regex, key)
         if not r:
@@ -34,6 +34,14 @@ def convert_diffusers_name_to_compvis(key):
         return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
 
     if match(m, re_text_block):
+        if is_sd2:
+            if 'mlp_fc1' in m[1]:
+                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
+            elif 'mlp_fc2' in m[1]:
+                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
+            elif 'self_attn' in m[1]:
+                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
+
         return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
 
     return key
@@ -83,9 +91,10 @@ def load_lora(name, filename):
     sd = sd_models.read_state_dict(filename)
 
     keys_failed_to_match = []
+    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
 
     for key_diffusers, weight in sd.items():
-        fullkey = convert_diffusers_name_to_compvis(key_diffusers)
+        fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2)
         key, lora_key = fullkey.split(".", 1)
 
         sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
@@ -104,9 +113,13 @@ def load_lora(name, filename):
 
         if type(sd_module) == torch.nn.Linear:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+        elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
+            module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.Conv2d:
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
         else:
+            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
+            continue
             assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
 
         with torch.no_grad():
@@ -182,6 +195,10 @@ def lora_Conv2d_forward(self, input):
     return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
 
 
+def lora_NonDynamicallyQuantizableLinear_forward(self, input):
+    return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input))
+
+
 def list_available_loras():
     available_loras.clear()
 
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 2e860160e..a385ae948 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -10,6 +10,7 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
 def unload():
     torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
     torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+    torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward = torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora
 
 
 def before_ui():
@@ -23,8 +24,12 @@ if not hasattr(torch.nn, 'Linear_forward_before_lora'):
 if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
     torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
 
+if not hasattr(torch.nn, 'NonDynamicallyQuantizableLinear_forward_before_lora'):
+    torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora = torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward
+
 torch.nn.Linear.forward = lora.lora_Linear_forward
 torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward = lora.lora_NonDynamicallyQuantizableLinear_forward
 
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
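
Note (illustrative, not part of the patch series): patch 1 teaches convert_diffusers_name_to_compvis() to map diffusers-style text-encoder LoRA keys onto the open_clip module names SD2 uses (mlp_fc1 -> mlp_c_fc, mlp_fc2 -> mlp_c_proj, self_attn -> attn), and hooks NonDynamicallyQuantizableLinear, the Linear subclass that MultiheadAttention uses for its out_proj. The sketch below isolates just that SD2 renaming branch; sd2_text_encoder_key and the example key are made-up names for illustration.

import re

# Hypothetical helper mirroring only the SD2 branch of the key conversion above.
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")


def sd2_text_encoder_key(key):
    m = re.match(re_text_block, key)
    if not m:
        return key

    layer, rest = m.group(1), m.group(2)
    for old, new in (("mlp_fc1", "mlp_c_fc"), ("mlp_fc2", "mlp_c_proj"), ("self_attn", "attn")):
        if old in rest:
            return f"model_transformer_resblocks_{layer}_{rest.replace(old, new)}"

    return f"transformer_text_model_encoder_layers_{layer}_{rest}"


print(sd2_text_encoder_key("lora_te_text_model_encoder_layers_0_self_attn_q_proj"))
# -> model_transformer_resblocks_0_attn_q_proj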
+ """ - lora_layer_name = getattr(module, 'lora_layer_name', None) - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is not None: - if shared.opts.lora_apply_to_outputs and res.shape == input.shape: - res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - else: - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + current_names = getattr(self, "lora_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) - return res + weights_backup = getattr(self, "lora_weights_backup", None) + if weights_backup is None: + weights_backup = self.weight.to(devices.cpu, copy=True) + self.lora_weights_backup = weights_backup + + if current_names != wanted_names: + if weights_backup is not None: + self.weight.copy_(weights_backup) + + lora_layer_name = getattr(self, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is None: + continue + + with torch.no_grad(): + up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype) + down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype) + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + updown = up @ down + + self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + setattr(self, "lora_current_names", wanted_names) def lora_Linear_forward(self, input): - return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Linear_forward_before_lora(self, input) + + +def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) def lora_Conv2d_forward(self, input): - return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Conv2d_forward_before_lora(self, input) + + +def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) def list_available_loras(): diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 2e860160e..dc329e812 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -9,7 +9,9 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora def before_ui(): @@ -20,11 +22,19 @@ def before_ui(): if not hasattr(torch.nn, 'Linear_forward_before_lora'): torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward +if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): + torch.nn.Linear_load_state_dict_before_lora = 
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 2e860160e..dc329e812 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -9,7 +9,9 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
 
 def unload():
     torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
     torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
 
 
 def before_ui():
@@ -20,11 +22,19 @@
 if not hasattr(torch.nn, 'Linear_forward_before_lora'):
     torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
 
+if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
+    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+
 if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
     torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
 
+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
+    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+
 torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
 torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
 
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
@@ -33,6 +43,4 @@ script_callbacks.on_before_ui(before_ui)
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
     "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
-    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
-
 }))
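
Note (illustrative, not part of the patch series): patch 2 changes the strategy from adding LoRA terms inside every forward() call to baking the delta into the layer weight once: the original weight is backed up on the CPU, restored whenever the wanted set of loras changes, and then up @ down, scaled by the lora multiplier and by alpha divided by the rank, is added to the weight in place, so forward() itself needs no extra work. A rough sketch of that merge on a plain Linear layer, with hypothetical names (merge_lora_weight, weight_backup) and a single up/down pair assumed:

import torch


def merge_lora_weight(layer, up, down, multiplier=1.0, alpha=None):
    # back the original weight up once, like lora_weights_backup in the patch
    if not hasattr(layer, "weight_backup"):
        layer.weight_backup = layer.weight.detach().cpu().clone()

    with torch.no_grad():
        scale = multiplier * ((alpha / up.shape[1]) if alpha else 1.0)
        layer.weight.copy_(layer.weight_backup.to(layer.weight))  # restore the original
        layer.weight += (up @ down) * scale                       # bake the LoRA delta in


layer = torch.nn.Linear(16, 16, bias=False)
rank = 4
up, down = torch.randn(16, rank), torch.randn(rank, 16)
merge_lora_weight(layer, up, down, multiplier=0.8, alpha=rank)
# forward() now runs at its normal cost; switching loras is one restore-and-merge pass.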
From 650ddc9dd3c1d126221682be8270f7fba1b5b6ce Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 26 Mar 2023 10:44:20 +0300
Subject: [PATCH 3/3] Lora support for SD2

---
 extensions-builtin/Lora/lora.py                 | 157 +++++++++++++-----
 .../Lora/scripts/lora_script.py                 |  10 ++
 2 files changed, 127 insertions(+), 40 deletions(-)

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index d4345ada6..edd95f786 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -8,14 +8,27 @@ from modules import shared, devices, sd_models, errors
 metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
 
 re_digits = re.compile(r"\d+")
-re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
-re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
-re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
-re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+suffix_conversion = {
+    "attentions": {},
+    "resnets": {
+        "conv1": "in_layers_2",
+        "conv2": "out_layers_3",
+        "time_emb_proj": "emb_layers_1",
+        "conv_shortcut": "skip_connection",
+    }
+}
 
 
 def convert_diffusers_name_to_compvis(key, is_sd2):
-    def match(match_list, regex):
+    def match(match_list, regex_text):
+        regex = re_compiled.get(regex_text)
+        if regex is None:
+            regex = re.compile(regex_text)
+            re_compiled[regex_text] = regex
+
         r = re.match(regex, key)
         if not r:
             return False
@@ -26,16 +39,25 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
     m = []
 
-    if match(m, re_unet_down_blocks):
-        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
+    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
 
-    if match(m, re_unet_mid_blocks):
-        return f"diffusion_model_middle_block_1_{m[1]}"
+    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
+        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
 
-    if match(m, re_unet_up_blocks):
-        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
+    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
 
-    if match(m, re_text_block):
+    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
+        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
+
+    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
+        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
+
+    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
         if is_sd2:
             if 'mlp_fc1' in m[1]:
                 return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
             elif 'mlp_fc2' in m[1]:
                 return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
@@ -109,16 +131,22 @@ def load_lora(name, filename):
 
     sd = sd_models.read_state_dict(filename)
 
-    keys_failed_to_match = []
+    keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
 
     for key_diffusers, weight in sd.items():
-        fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2)
-        key, lora_key = fullkey.split(".", 1)
+        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
+        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
 
         sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+
         if sd_module is None:
-            keys_failed_to_match.append(key_diffusers)
+            m = re_x_proj.match(key)
+            if m:
+                sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
+
+        if sd_module is None:
+            keys_failed_to_match[key_diffusers] = key
             continue
 
         lora_module = lora.modules.get(key, None)
@@ -133,7 +161,9 @@ def load_lora(name, filename):
         if type(sd_module) == torch.nn.Linear:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
-            module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False)
+            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+        elif type(sd_module) == torch.nn.MultiheadAttention:
+            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.Conv2d:
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
         else:
@@ -190,54 +220,94 @@ def load_loras(names, multipliers=None):
         loaded_loras.append(lora)
 
 
-def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
+def lora_calc_updown(lora, module, target):
+    with torch.no_grad():
+        up = module.up.weight.to(target.device, dtype=target.dtype)
+        down = module.down.weight.to(target.device, dtype=target.dtype)
+
+        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        else:
+            updown = up @ down
+
+        updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+        return updown
+
+
+def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.MultiheadAttention):
     """
-    Applies the currently selected set of Loras to the weight of torch layer self.
+    Applies the currently selected set of Loras to the weights of torch layer self.
     If weights already have this particular set of loras applied, does nothing.
     If not, restores original weights from backup and alters weights according to loras.
""" + lora_layer_name = getattr(self, 'lora_layer_name', None) + if lora_layer_name is None: + return + current_names = getattr(self, "lora_current_names", ()) wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) weights_backup = getattr(self, "lora_weights_backup", None) if weights_backup is None: - weights_backup = self.weight.to(devices.cpu, copy=True) + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + self.lora_weights_backup = weights_backup if current_names != wanted_names: if weights_backup is not None: - self.weight.copy_(weights_backup) + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) - lora_layer_name = getattr(self, 'lora_layer_name', None) for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) + if module is not None and hasattr(self, 'weight'): + self.weight += lora_calc_updown(lora, module, self.weight) + continue + + module_q = lora.modules.get(lora_layer_name + "_q_proj", None) + module_k = lora.modules.get(lora_layer_name + "_k_proj", None) + module_v = lora.modules.get(lora_layer_name + "_v_proj", None) + module_out = lora.modules.get(lora_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight) + updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight) + updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight) + continue + if module is None: continue - with torch.no_grad(): - up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype) - down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype) - - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - else: - updown = up @ down - - self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + print(f'failed to calculate lora weights for layer {lora_layer_name}') setattr(self, "lora_current_names", wanted_names) +def lora_reset_cached_weight(self: torch.nn.Conv2d | torch.nn.Linear): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + def lora_Linear_forward(self, input): lora_apply_weights(self) return torch.nn.Linear_forward_before_lora(self, input) -def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs): - setattr(self, "lora_current_names", ()) - setattr(self, "lora_weights_backup", None) +def lora_Linear_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) @@ -248,15 +318,22 @@ def lora_Conv2d_forward(self, input): return torch.nn.Conv2d_forward_before_lora(self, input) -def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs): - setattr(self, "lora_current_names", ()) - setattr(self, "lora_weights_backup", None) +def lora_Conv2d_load_state_dict(self, *args, **kwargs): + 
+    lora_reset_cached_weight(self)
 
     return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
 
 
-def lora_NonDynamicallyQuantizableLinear_forward(self, input):
-    return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input))
+def lora_MultiheadAttention_forward(self, *args, **kwargs):
+    lora_apply_weights(self)
+
+    return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
+
+
+def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
+    lora_reset_cached_weight(self)
+
+    return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
 
 
 def list_available_loras():
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index dc329e812..0adab2254 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -12,6 +12,8 @@ def unload():
     torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
     torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
     torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
+    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
+    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
 
 
 def before_ui():
@@ -31,10 +33,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
 if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
     torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
 
+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
+    torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
+
+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
+    torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
+
 torch.nn.Linear.forward = lora.lora_Linear_forward
 torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
 torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
 torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
+torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
+torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
 
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
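
Note (illustrative, not part of the patch series): patch 3 extends the weight-merge approach to SD2, whose text encoder is built on torch.nn.MultiheadAttention. The q/k/v LoRA modules are looked up under the _q_proj/_k_proj/_v_proj suffixes, their up @ down deltas are stacked with torch.vstack so they line up with the fused in_proj_weight, and out_proj receives its own delta. A small sketch of that shape logic with random stand-in tensors (calc_updown is a simplified stand-in for lora_calc_updown):

import torch

embed_dim, rank = 8, 2
attn = torch.nn.MultiheadAttention(embed_dim, num_heads=2)


def calc_updown(up, down):
    # same contraction lora_calc_updown performs for Linear-shaped weights
    return up @ down


with torch.no_grad():
    deltas = []
    for _ in ("q", "k", "v"):
        up, down = torch.randn(embed_dim, rank), torch.randn(rank, embed_dim)
        deltas.append(calc_updown(up, down))      # each delta is (embed_dim, embed_dim)

    qkv_delta = torch.vstack(deltas)              # (3 * embed_dim, embed_dim)
    assert qkv_delta.shape == attn.in_proj_weight.shape
    attn.in_proj_weight += qkv_delta              # fused q/k/v weight, as in the patch

    out_delta = calc_updown(torch.randn(embed_dim, rank), torch.randn(rank, embed_dim))
    attn.out_proj.weight += out_delta             # out projection handled separately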