from __future__ import annotations
import gradio as gr
import logging
import os
import re

import lora_patches
import network
import network_lora
import network_glora
import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm
import network_oft
import torch
from typing import Union

from modules import shared, devices, sd_models, errors, scripts, sd_hijack
import modules.textual_inversion.textual_inversion as textual_inversion

from lora_logger import logger

module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
    network_ia3.ModuleTypeIa3(),
    network_lokr.ModuleTypeLokr(),
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
    network_glora.ModuleTypeGLora(),
    network_oft.ModuleTypeOFT(),
]
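

# Regexes and layer-name suffix conversions used by convert_diffusers_name_to_compvis()
# when translating diffusers/kohya-style LoRA keys into compvis module names.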
re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}
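

# Translates a diffusers/kohya-style LoRA key (e.g. "lora_unet_down_blocks_...") into the
# compvis module name used by the loaded checkpoint, so it can be looked up in
# sd_model.network_layer_mapping. Keys that match no pattern are returned unchanged.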
def convert_diffusers_name_to_compvis(key, is_sd2):
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key
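

# Walks the text encoder(s) and UNet of the loaded checkpoint, gives every torch module a
# flattened network_layer_name attribute, and stores the name -> module mapping on the model.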
def assign_network_names_to_compvis_modules(sd_model):
    network_layer_mapping = {}

    if shared.sd_model.is_sdxl:
        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                network_name = f'{i}_{name.replace(".", "_")}'
                network_layer_mapping[network_name] = module
                module.network_layer_name = network_name
    else:
        cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)

        for name, module in cond_stage_model.named_modules():
            network_name = name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name

    for name, module in shared.sd_model.model.named_modules():
        network_name = name.replace(".", "_")
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    sd_model.network_layer_mapping = network_layer_mapping
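

# String subclass whose str() only reveals the stored hash when the
# lora_bundled_ti_to_infotext option is enabled.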
class BundledTIHash(str):
    def __init__(self, hash_str):
        self.hash = hash_str

    def __str__(self):
        return self.hash if shared.opts.lora_bundled_ti_to_infotext else ''
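

# Reads a network file from disk, matches each state-dict key to a module of the current
# checkpoint, builds per-layer network modules via module_types, and collects any textual
# inversion embeddings bundled under "bundle_emb" keys.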
def load_network(name, network_on_disk):
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

    matched_networks = {}
    bundle_embeddings = {}

    for key_network, weight in sd.items():
        key_network_without_network_parts, _, network_part = key_network.partition(".")

        if key_network_without_network_parts == "bundle_emb":
            emb_name, vec_name = network_part.split(".", 1)
            emb_dict = bundle_embeddings.get(emb_name, {})
            if vec_name.split('.')[0] == 'string_to_param':
                _, k2 = vec_name.split('.', 1)
                emb_dict['string_to_param'] = {k2: weight}
            else:
                emb_dict[vec_name] = weight
            bundle_embeddings[emb_name] = emb_dict

        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

            # some SD1 Loras also have correct compvis keys
            if sd_module is None:
                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
                sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # kohya_ss OFT module
        elif sd_module is None and "oft_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # KohakuBlueLeaf OFT module
        if sd_module is None and "oft_diag" in key:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    embeddings = {}
    for emb_name, data in bundle_embeddings.items():
        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
        embedding.loaded = None
        embedding.shorthash = BundledTIHash(name)
        embeddings[emb_name] = embedding

    net.bundle_embeddings = embeddings

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

    return net
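

# Evicts the oldest entries from the networks_in_memory cache until it fits within
# shared.opts.lora_in_memory_limit, then triggers a torch garbage collection.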
def purge_networks_from_memory():
    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
        name = next(iter(networks_in_memory))
        networks_in_memory.pop(name, None)

    devices.torch_gc()
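

# Activates the requested networks: reuses already-loaded ones, loads or reloads the rest from
# disk (keeping an in-memory cache), applies per-network multipliers, and registers any bundled
# embeddings with the textual inversion database.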
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    emb_db = sd_hijack.model_hijack.embedding_db
    already_loaded = {}

    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net
        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded:
                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    loaded_networks.clear()

    unavailable_networks = []
    for name in names:
        if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
            unavailable_networks.append(name)
        elif available_network_aliases.get(name) is None:
            unavailable_networks.append(name)

    if unavailable_networks:
        update_available_networks_by_names(unavailable_networks)

    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()

        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]

    failed_to_load_networks = []

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = already_loaded.get(name, None)

        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)

            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                try:
                    net = load_network(name, network_on_disk)

                    networks_in_memory.pop(name, None)
                    networks_in_memory[name] = net
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name

            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            logging.info(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)

        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
                logger.warning(
                    f'Skip bundle embedding: "{emb_name}"'
                    ' as it was already loaded from embeddings folder'
                )
                continue

            embedding.loaded = False
            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
                embedding.loaded = True
                emb_db.register_embedding(embedding, shared.sd_model)
            else:
                emb_db.skipped_embeddings[name] = embedding

    if failed_to_load_networks:
        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
        sd_hijack.model_hijack.comments.append(lora_not_found_message)
        if shared.opts.lora_not_found_warning_console:
            print(f'\n{lora_not_found_message}\n')
        if shared.opts.lora_not_found_gradio_warning:
            gr.Warning(lora_not_found_message)

    purge_networks_from_memory()
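

# Restores a layer's original weight/bias from the CPU-side backups stored on the module,
# undoing any previously applied networks.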
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

    if weights_backup is not None:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.in_proj_weight.copy_(weights_backup[0])
            self.out_proj.weight.copy_(weights_backup[1])
        else:
            self.weight.copy_(weights_backup)

    if bias_backup is not None:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.out_proj.bias.copy_(bias_backup)
        else:
            self.bias.copy_(bias_backup)
    else:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.out_proj.bias = None
        else:
            self.bias = None


def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.
    """

    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != ():
        if current_names != ():
            raise RuntimeError("no backup weights found and current weights are not unchanged")

        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.network_weights_backup = weights_backup

    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None and wanted_names != ():
        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = self.bias.to(devices.cpu, copy=True)
        else:
            bias_backup = None

        # Unlike weight which always has value, some modules don't have bias.
        # Only report if bias is not None and current bias are not unchanged.
        if bias_backup is not None and current_names != ():
            raise RuntimeError("no backup bias found and current bias are not unchanged")

        self.network_bias_backup = bias_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                try:
                    with torch.no_grad():
                        if getattr(self, 'fp16_weight', None) is None:
                            weight = self.weight
                            bias = self.bias
                        else:
                            weight = self.fp16_weight.clone().to(self.weight.device)
                            bias = getattr(self, 'fp16_bias', None)
                            if bias is not None:
                                bias = bias.clone().to(self.bias.device)
                        updown, ex_bias = module.calc_updown(weight)

                        if len(weight.shape) == 4 and weight.shape[1] == 9:
                            # inpainting model. zero pad updown to make channel[1] 4 to 9
                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                        self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                        if ex_bias is not None and hasattr(self, 'bias'):
                            if self.bias is None:
                                self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                            else:
                                self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                try:
                    with torch.no_grad():
                        # Send "real" orig_weight into MHA's lora module
                        qw, kw, vw = self.in_proj_weight.chunk(3, 0)
                        updown_q, _ = module_q.calc_updown(qw)
                        updown_k, _ = module_k.calc_updown(kw)
                        updown_v, _ = module_v.calc_updown(vw)
                        del qw, kw, vw

                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

                        self.in_proj_weight += updown_qkv
                        self.out_proj.weight += updown_out

                        if ex_bias is not None:
                            if self.out_proj.bias is None:
                                self.out_proj.bias = torch.nn.Parameter(ex_bias)
                            else:
                                self.out_proj.bias += ex_bias

                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if module is None:
                continue

            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

        self.network_current_names = wanted_names


def network_forward(org_module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.
    """

    if len(loaded_networks) == 0:
        return original_forward(org_module, input)

    input = devices.cond_cast_unet(input)

    network_restore_weights_from_backup(org_module)
    network_reset_cached_weight(org_module)

    y = original_forward(org_module, input)

    network_layer_name = getattr(org_module, 'network_layer_name', None)
    for lora in loaded_networks:
        module = lora.modules.get(network_layer_name, None)
        if module is None:
            continue

        y = module.forward(input, y)

    return y
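

# Drops the stored backups and applied-network names for a layer, forcing
# network_apply_weights to rebuild its state on the next call.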
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None
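

# The wrappers below stand in for the original torch layer methods (kept in `originals`):
# the forward wrappers apply the active networks (or take the functional path when
# lora_functional is enabled), and the load_state_dict wrappers reset cached backups.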
def network_Linear_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Linear_forward)

    network_apply_weights(self)

    return originals.Linear_forward(self, input)


def network_Linear_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.Linear_load_state_dict(self, *args, **kwargs)


def network_Conv2d_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Conv2d_forward)

    network_apply_weights(self)

    return originals.Conv2d_forward(self, input)


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.Conv2d_load_state_dict(self, *args, **kwargs)


def network_GroupNorm_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.GroupNorm_forward)

    network_apply_weights(self)

    return originals.GroupNorm_forward(self, input)


def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)


def network_LayerNorm_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.LayerNorm_forward)

    network_apply_weights(self)

    return originals.LayerNorm_forward(self, input)


def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)


def network_MultiheadAttention_forward(self, *args, **kwargs):
    network_apply_weights(self)

    return originals.MultiheadAttention_forward(self, *args, **kwargs)


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
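

# Scans the LoRA directories for network files and registers them and their aliases in the
# module-level lookup tables. When names is given, only matching files are (re)processed.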
def process_network_files(names: list[str] | None = None):
    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        # if names is provided, only load networks with names in the list
        if names and name not in names:
            continue
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry
        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


def update_available_networks_by_names(names: list[str]):
    process_network_files(names)


def list_available_networks():
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    process_network_files()


re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
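

# Converts AddNet Model/Weight fields from pasted infotext into <lora:name:multiplier> tags
# appended to the prompt, unless the AddNet extension itself is active to handle them.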
def infotext_pasted(infotext, params):
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[13:]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            name = m.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + " ".join(added)
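

# Module-level state: `originals` and `extra_network_lora` are assigned elsewhere in the
# extension; the lookup tables below are populated by list_available_networks() on import.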
originals: lora_patches.LoraPatches = None

extra_network_lora = None

available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}

list_available_networks()