initial work on SD2 Lora support
This commit is contained in:
commit
7cb31a278e
|
@ -12,7 +12,7 @@ re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)
|
||||||
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
|
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
|
||||||
|
|
||||||
|
|
||||||
def convert_diffusers_name_to_compvis(key):
|
def convert_diffusers_name_to_compvis(key, is_sd2):
|
||||||
def match(match_list, regex):
|
def match(match_list, regex):
|
||||||
r = re.match(regex, key)
|
r = re.match(regex, key)
|
||||||
if not r:
|
if not r:
|
||||||
|
@ -34,6 +34,14 @@ def convert_diffusers_name_to_compvis(key):
|
||||||
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
|
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
|
||||||
|
|
||||||
if match(m, re_text_block):
|
if match(m, re_text_block):
|
||||||
|
if is_sd2:
|
||||||
|
if 'mlp_fc1' in m[1]:
|
||||||
|
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||||
|
elif 'mlp_fc2' in m[1]:
|
||||||
|
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||||
|
else:
|
||||||
|
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||||
|
|
||||||
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
||||||
|
|
||||||
return key
|
return key
|
||||||
|
@ -83,9 +91,10 @@ def load_lora(name, filename):
|
||||||
sd = sd_models.read_state_dict(filename)
|
sd = sd_models.read_state_dict(filename)
|
||||||
|
|
||||||
keys_failed_to_match = []
|
keys_failed_to_match = []
|
||||||
|
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
||||||
|
|
||||||
for key_diffusers, weight in sd.items():
|
for key_diffusers, weight in sd.items():
|
||||||
fullkey = convert_diffusers_name_to_compvis(key_diffusers)
|
fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2)
|
||||||
key, lora_key = fullkey.split(".", 1)
|
key, lora_key = fullkey.split(".", 1)
|
||||||
|
|
||||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||||
|
@ -104,9 +113,13 @@ def load_lora(name, filename):
|
||||||
|
|
||||||
if type(sd_module) == torch.nn.Linear:
|
if type(sd_module) == torch.nn.Linear:
|
||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
|
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
|
||||||
|
module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.Conv2d:
|
elif type(sd_module) == torch.nn.Conv2d:
|
||||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||||
else:
|
else:
|
||||||
|
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||||
|
continue
|
||||||
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
|
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -182,6 +195,10 @@ def lora_Conv2d_forward(self, input):
|
||||||
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
|
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
|
||||||
|
|
||||||
|
|
||||||
|
def lora_NonDynamicallyQuantizableLinear_forward(self, input):
|
||||||
|
return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input))
|
||||||
|
|
||||||
|
|
||||||
def list_available_loras():
|
def list_available_loras():
|
||||||
available_loras.clear()
|
available_loras.clear()
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||||
def unload():
|
def unload():
|
||||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
||||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
||||||
|
torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward = torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora
|
||||||
|
|
||||||
|
|
||||||
def before_ui():
|
def before_ui():
|
||||||
|
@ -23,8 +24,12 @@ if not hasattr(torch.nn, 'Linear_forward_before_lora'):
|
||||||
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
||||||
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'NonDynamicallyQuantizableLinear_forward_before_lora'):
|
||||||
|
torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora = torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward
|
||||||
|
|
||||||
torch.nn.Linear.forward = lora.lora_Linear_forward
|
torch.nn.Linear.forward = lora.lora_Linear_forward
|
||||||
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
||||||
|
torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward = lora.lora_NonDynamicallyQuantizableLinear_forward
|
||||||
|
|
||||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||||
script_callbacks.on_script_unloaded(unload)
|
script_callbacks.on_script_unloaded(unload)
|
||||||
|
|
Loading…
Reference in New Issue