store patches for Lora in a specialized module

parent 7327be97aa
commit f01682ee01
@@ -0,0 +1,31 @@
+import torch
+
+import networks
+from modules import patches
+
+
+class LoraPatches:
+    def __init__(self):
+        self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
+        self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
+        self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
+        self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
+        self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
+        self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
+        self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
+        self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
+        self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
+        self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
+
+    def undo(self):
+        self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
+        self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
+        self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
+        self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
+        self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
+        self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
+        self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
+        self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
+        self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
+        self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
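Not part of the commit, just a usage sketch of the new class (it assumes the extension's lora_patches and networks modules are importable, which is how the extension script further down wires it up): constructing LoraPatches swaps the listed torch.nn methods for the networks.network_* wrappers and keeps each pre-patch function, and undo() puts everything back.

    import lora_patches
    import networks

    networks.originals = lora_patches.LoraPatches()  # patch torch.nn layers, remembering the original methods
    # torch.nn.Linear.forward and friends now route through the networks.network_*_forward wrappers
    networks.originals.undo()                        # restore the unpatched torch.nn methods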
@@ -2,6 +2,7 @@ import logging
 import os
 import re
 
+import lora_patches
 import network
 import network_lora
 import network_hada
@@ -418,74 +419,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
 
 
 def network_Linear_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.Linear_forward_before_network)
+        return network_forward(self, input, originals.Linear_forward)
 
     network_apply_weights(self)
 
-    return torch.nn.Linear_forward_before_network(self, input)
+    return originals.Linear_forward(self, input)
 
 
 def network_Linear_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
 
-    return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.Linear_load_state_dict(self, *args, **kwargs)
 
 
 def network_Conv2d_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
+        return network_forward(self, input, originals.Conv2d_forward)
 
     network_apply_weights(self)
 
-    return torch.nn.Conv2d_forward_before_network(self, input)
+    return originals.Conv2d_forward(self, input)
 
 
 def network_Conv2d_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
 
-    return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.Conv2d_load_state_dict(self, *args, **kwargs)
 
 
 def network_GroupNorm_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+        return network_forward(self, input, originals.GroupNorm_forward)
 
     network_apply_weights(self)
 
-    return torch.nn.GroupNorm_forward_before_network(self, input)
+    return originals.GroupNorm_forward(self, input)
 
 
 def network_GroupNorm_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
 
-    return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
 
 
 def network_LayerNorm_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+        return network_forward(self, input, originals.LayerNorm_forward)
 
     network_apply_weights(self)
 
-    return torch.nn.LayerNorm_forward_before_network(self, input)
+    return originals.LayerNorm_forward(self, input)
 
 
 def network_LayerNorm_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
 
-    return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
 
 
 def network_MultiheadAttention_forward(self, *args, **kwargs):
     network_apply_weights(self)
 
-    return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
+    return originals.MultiheadAttention_forward(self, *args, **kwargs)
 
 
 def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
 
-    return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
 
 
 def list_available_networks():
@@ -552,6 +553,9 @@ def infotext_pasted(infotext, params):
     if added:
         params["Prompt"] += "\n" + "".join(added)
 
+
+originals: lora_patches.LoraPatches = None
+
 extra_network_lora = None
 
 available_networks = {}
@@ -7,17 +7,14 @@ from fastapi import FastAPI
 import network
 import networks
 import lora  # noqa:F401
+import lora_patches
 import extra_networks_lora
 import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared, patches
 
 
 def unload():
-    torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
-    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
-    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
-    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
-    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
-    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
+    networks.originals.undo()
 
 
 def before_ui():
@@ -28,46 +25,7 @@ def before_ui():
     extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
 
 
-if not hasattr(torch.nn, 'Linear_forward_before_network'):
-    torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
-
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
-    torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
-
-if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
-    torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
-
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
-    torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
-
-if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
-    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
-
-if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
-    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
-    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
-
-if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
-    torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
-    torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
-
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
-    torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
-
-torch.nn.Linear.forward = networks.network_Linear_forward
-torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
-torch.nn.Conv2d.forward = networks.network_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
-torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
-torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
-torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
-torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
-torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
+networks.originals = lora_patches.LoraPatches()
 
 script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
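Illustration only (the my_forward wrapper below is a made-up placeholder, not code from the commit): with the registry in modules.patches, applying the same patch twice under the same key raises RuntimeError instead of being silently skipped by a hasattr check, so the script applies its patches exactly once and has to undo() them before they can be applied again.

    import torch
    from modules import patches  # assumes the webui 'modules' package is importable

    def my_forward(self, input):  # placeholder wrapper that calls through to whatever it replaced
        return patches.original(__name__, torch.nn.Linear, 'forward')(self, input)

    patches.patch(__name__, torch.nn.Linear, 'forward', my_forward)   # first application succeeds
    # calling patches.patch(...) again with the same key/obj/field would raise RuntimeError
    patches.undo(__name__, torch.nn.Linear, 'forward')                # restore; patching is allowed again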
@@ -0,0 +1,64 @@
+from collections import defaultdict
+
+
+def patch(key, obj, field, replacement):
+    """Replaces a function in a module or a class.
+
+    Also stores the original function in this module, so it can be retrieved via original(key, obj, field).
+    If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
+
+    Arguments:
+        key: identifying information for who is doing the replacement. You can use __name__.
+        obj: the module or the class
+        field: name of the function as a string
+        replacement: the new function
+
+    Returns:
+        the original function
+    """
+
+    patch_key = (obj, field)
+    if patch_key in originals[key]:
+        raise RuntimeError(f"patch for {field} is already applied")
+
+    original_func = getattr(obj, field)
+    originals[key][patch_key] = original_func
+
+    setattr(obj, field, replacement)
+
+    return original_func
+
+
+def undo(key, obj, field):
+    """Undoes the replacement by the patch().
+
+    If the function is not replaced, raises an exception.
+
+    Arguments:
+        key: identifying information for who is doing the replacement. You can use __name__.
+        obj: the module or the class
+        field: name of the function as a string
+
+    Returns:
+        Always None
+    """
+
+    patch_key = (obj, field)
+
+    if patch_key not in originals[key]:
+        raise RuntimeError(f"there is no patch for {field} to undo")
+
+    original_func = originals[key].pop(patch_key)
+    setattr(obj, field, original_func)
+
+    return None
+
+
+def original(key, obj, field):
+    """Returns the original function for the patch created by the patch() function"""
+    patch_key = (obj, field)
+
+    return originals[key].get(patch_key, None)
+
+
+originals = defaultdict(dict)
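A short sketch of the registry semantics (the Greeter class and the "my_extension" key are invented for illustration; they are not in the commit): originals is keyed first by caller and then by (obj, field), each patch() records what it replaced, original() lets a replacement call through to that stored function, and undo() restores it.

    from modules import patches  # assumes the webui 'modules' package is importable

    class Greeter:
        def hello(self):
            return "hello"

    def loud_hello(self):
        # call whatever was in place before this patch, then shout it
        return patches.original("my_extension", Greeter, "hello")(self).upper()

    patches.patch("my_extension", Greeter, "hello", loud_hello)
    assert Greeter().hello() == "HELLO"

    patches.undo("my_extension", Greeter, "hello")
    assert Greeter().hello() == "hello"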