2022-09-03 03:08:45 -06:00
import torch
2022-10-02 15:31:19 -06:00
from torch . nn . functional import silu
2023-01-12 07:03:46 -07:00
from types import MethodType
2022-09-03 03:08:45 -06:00
2022-10-02 06:03:39 -06:00
import modules . textual_inversion . textual_inversion
2022-12-09 23:17:39 -07:00
from modules import devices , sd_hijack_optimizations , shared , sd_hijack_checkpoint
2022-11-26 06:45:57 -07:00
from modules . hypernetworks import hypernetwork
2022-12-09 23:17:39 -07:00
from modules . shared import cmd_opts
2022-12-31 08:06:35 -07:00
from modules import sd_hijack_clip , sd_hijack_open_clip , sd_hijack_unet , sd_hijack_xlmr , xlmr
2022-11-26 06:10:46 -07:00
2022-09-04 16:41:20 -06:00
import ldm . modules . attention
2022-09-13 05:29:56 -06:00
import ldm . modules . diffusionmodules . model
2022-12-02 05:47:02 -07:00
import ldm . modules . diffusionmodules . openaimodel
2022-11-11 08:20:18 -07:00
import ldm . models . diffusion . ddim
import ldm . models . diffusion . plms
2022-11-26 06:10:46 -07:00
import ldm . modules . encoders . modules
2022-09-13 05:29:56 -06:00
2022-10-02 06:03:39 -06:00
# Stock ldm implementations, captured when this module is imported.
# undo_optimizations() restores the nonlinearity and AttnBlock.forward from these.
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2: a module-level `print` shadows the builtin for code
# in that module. Accept keyword arguments too, so calls such as print(..., end="") or
# print(..., flush=True) are silenced instead of raising TypeError.
ldm.modules.attention.print = lambda *args, **kwargs: None
ldm.modules.diffusionmodules.model.print = lambda *args, **kwargs: None
2022-10-15 07:59:37 -06:00
2022-12-09 23:14:30 -07:00
2022-10-02 06:03:39 -06:00
def apply_optimizations():
    """Install the cross-attention implementation selected by ``cmd_opts``.

    Starts from the stock implementations (via ``undo_optimizations``), sets
    ``ldm.modules.diffusionmodules.model.nonlinearity`` to ``silu`` and
    ``ldm.modules.diffusionmodules.openaimodel.th`` to ``sd_hijack_unet.th``.
    It then installs the first optimization whose flag/hardware checks pass, in
    priority order: xformers, sdp-no-mem, sdp, sub-quadratic, V1, InvokeAI,
    Doggettx.

    Returns:
        The name of the applied optimization, or None if none was applied.
    """
    undo_optimizations()

    attention = ldm.modules.attention
    model_module = ldm.modules.diffusionmodules.model

    model_module.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    # not everyone has torch 2.x to use sdp; callable(None) is False when it is missing
    can_use_sdp = callable(getattr(torch.nn.functional, "scaled_dot_product_attention", None))

    if cmd_opts.force_enable_xformers or (
        cmd_opts.xformers
        and shared.xformers_available
        and torch.version.cuda
        and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)
    ):
        print("Applying xformers cross attention optimization.")
        attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        model_module.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
        return 'xformers'

    if cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
        print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
        attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
        return 'sdp-no-mem'

    if cmd_opts.opt_sdp_attention and can_use_sdp:
        print("Applying scaled dot product cross attention optimization.")
        attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
        return 'sdp'

    if cmd_opts.opt_sub_quad_attention:
        print("Applying sub-quadratic cross attention optimization.")
        attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
        model_module.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
        return 'sub-quadratic'

    if cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        return 'V1'

    # Both remaining split-attention variants are disabled by this flag.
    if cmd_opts.disable_opt_split_attention:
        return None

    if cmd_opts.opt_split_attention_invokeai or (not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
        print("Applying cross attention optimization (InvokeAI).")
        attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
        return 'InvokeAI'

    if cmd_opts.opt_split_attention or torch.cuda.is_available():
        print("Applying cross attention optimization (Doggettx).")
        attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        model_module.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
        return 'Doggettx'

    return None
2022-09-13 05:29:56 -06:00
2022-10-02 06:03:39 -06:00
def undo_optimizations():
    """Put back the stock CrossAttention/AttnBlock forwards and the original nonlinearity."""
    model_module = ldm.modules.diffusionmodules.model

    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    model_module.nonlinearity = diffusionmodules_model_nonlinearity
    model_module.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
2022-09-13 05:29:56 -06:00
2022-09-03 03:08:45 -06:00
2023-01-19 10:39:03 -07:00
def fix_checkpoint():
    """No-op kept so existing callers keep working.

    Checkpoints are now added and removed by the embedding/hypernetwork training
    code. Torch warns when checkpoints are added while not training, so nothing
    is done here.
    """
2023-01-12 07:03:46 -07:00
def weighted_loss(sd_model, pred, target, mean=True):
    """Compute the model's loss scaled element-wise by ``sd_model._custom_loss_weight``.

    Args:
        sd_model: object exposing the saved original loss as ``_old_get_loss``.
        pred, target: forwarded unchanged to ``_old_get_loss``.
        mean: when true, reduce the weighted loss with ``.mean()``.

    Returns:
        The weighted loss, reduced or unreduced depending on ``mean``. If no
        weight is set, this is the plain unreduced loss (or its mean).
    """
    # Calculate the loss normally, but keep it unreduced so it can be weighted.
    loss = sd_model._old_get_loss(pred, target, mean=False)

    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        # Out-of-place multiply. It lets the weight broadcast to a larger shape
        # (in-place `*=` cannot grow the tensor) and does not mutate a tensor that
        # belongs to the autograd graph.
        loss = loss * weight

    return loss.mean() if mean else loss
def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    """Run ``sd_model.forward(x, c, *args, **kwargs)`` with the loss weighted by ``w``.

    While forward runs, ``sd_model.get_loss`` is replaced with ``weighted_loss``,
    and ``w`` is stored on ``sd_model._custom_loss_weight`` so the loss can read
    it. Both are removed or restored afterwards, even if forward raises.

    Returns:
        Whatever ``sd_model.forward`` returns.
    """
    try:
        # Temporarily attach the weights where the loss calculation can find them.
        sd_model._custom_loss_weight = w
        # Replace 'get_loss' with a weight-aware one instead of reimplementing 'forward'.
        # Don't overwrite an already-saved _old_get_loss (e.g. a nested call), or the
        # patched version would be saved as the "original".
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)
        # Run the standard forward function, but with the patched 'get_loss'.
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            del sd_model._custom_loss_weight
        except AttributeError:
            pass  # not attached; nothing to remove

        if hasattr(sd_model, '_old_get_loss'):
            old_get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss

            class_get_loss = getattr(type(sd_model), 'get_loss', None)
            if (class_get_loss is not None
                    and getattr(old_get_loss, '__self__', None) is sd_model
                    and getattr(old_get_loss, '__func__', None) is class_get_loss):
                # The original was the class's own method. Drop the instance-level
                # override instead of storing a bound method of sd_model on itself.
                # Storing it would shadow the class attribute and leave a reference
                # cycle that keeps the model alive until GC.
                try:
                    del sd_model.get_loss
                except AttributeError:
                    pass  # no override present; the class method is already used
            else:
                # The original came from an instance-level attribute; put it back as-is.
                sd_model.get_loss = old_get_loss
def apply_weighted_forward(sd_model):
    """Attach ``weighted_forward`` to *sd_model* as a bound method for weighted-loss training."""
    bound_method = MethodType(weighted_forward, sd_model)
    sd_model.weighted_forward = bound_method
def undo_weighted_forward(sd_model):
    """Remove the ``weighted_forward`` attribute from *sd_model*; no-op if it is absent."""
    try:
        delattr(sd_model, "weighted_forward")
    except AttributeError:
        pass
2022-09-03 03:08:45 -06:00
class StableDiffusionModelHijack:
    """Installs and removes this module's patches on a loaded Stable Diffusion model.

    ``hijack(m)`` does the following:
    - wraps the text encoder's (``m.cond_stage_model``) token embedding in
      ``EmbeddingsWithFixes``;
    - wraps the encoder itself in the matching ``*WithCustomWords`` class;
    - attaches ``weighted_forward``;
    - applies the selected attention optimization.

    ``undo_hijack(m)`` reverses these steps.
    """

    # Per-batch embedding fixes; read and reset to None by EmbeddingsWithFixes.forward.
    fixes = None
    # NOTE(review): class-level list, shared across instances until clear_comments() rebinds it.
    comments = []

    # Flattened list of all modules of the hijacked model; used by apply_circular().
    layers = None
    circular_enabled = False

    # The wrapped cond_stage_model installed by hijack().
    clip = None
    # Name returned by apply_optimizations() during the last hijack().
    optimization_method = None

    # Created once, when the class body executes; shared by all instances.
    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

    def __init__(self):
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def hijack(self, m):
        """Patch model *m* in place (see class docstring)."""
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.optimization_method = apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            # Depth-first list of el and all of its descendant modules.
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

    def undo_hijack(self, m):
        """Reverse hijack(): unwrap the text encoder and its token embedding, undo the other patches."""
        # After hijack() the XLMR encoder is a FrozenXLMREmbedderWithCustomWords, so test for
        # that wrapper; the raw BertSeriesModelWithTransformation type never matches here.
        # The exact-type checks keep the XLMR wrapper out of the CLIP branch.
        if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped
            # NOTE(review): hijack() added roberta.embeddings.token_embedding; it is left in place here.

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            open_clip_model = m.cond_stage_model.wrapped.model
            # Same guard as the CLIP branch: only unwrap what hijack() actually wrapped.
            if type(open_clip_model.token_embedding) == EmbeddingsWithFixes:
                open_clip_model.token_embedding = open_clip_model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        undo_optimizations()
        undo_weighted_forward(m)

        self.apply_circular(False)
        self.layers = None
        self.clip = None

    def apply_circular(self, enable):
        """Set padding_mode of every Conv2d in self.layers to 'circular' (enable) or 'zeros'."""
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        self.comments = []

    def get_prompt_lengths(self, text):
        """Return (token_count, target token count) for *text*, as computed by self.clip."""
        _, token_count = self.clip.process_texts([text])

        return token_count, self.clip.get_target_prompt_token_count(token_count)
2022-09-03 03:08:45 -06:00
class EmbeddingsWithFixes(torch.nn.Module):
    """Wrap a token-embedding module and splice extra vectors into its output.

    ``embeddings`` is an object whose ``fixes`` attribute holds, per batch row, a
    list of ``(offset, embedding)`` pairs. The pairs are consumed once per
    forward call (``fixes`` is reset to None). Each ``embedding.vec`` is written
    into the row right after position ``offset``, truncated to fit the row.
    """

    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        pending = self.embeddings.fixes
        # Fixes apply to exactly one call; clear them before running the wrapped module.
        self.embeddings.fixes = None

        base = self.wrapped(input_ids)

        if pending is None or not any(len(row_fixes) > 0 for row_fixes in pending):
            return base

        rows = []
        for row_fixes, row in zip(pending, base):
            for offset, embedding in row_fixes:
                vec = devices.cond_cast_unet(embedding.vec)
                # How many of vec's vectors fit between offset+1 and the end of the row.
                count = min(row.shape[0] - offset - 1, vec.shape[0])
                row = torch.cat([row[:offset + 1], vec[:count], row[offset + 1 + count:]])
            rows.append(row)

        return torch.stack(rows)
2022-09-03 03:08:45 -06:00
2022-09-04 17:16:36 -06:00
def add_circular_option_to_conv_2d():
    """Monkey-patch ``torch.nn.Conv2d.__init__`` so every new Conv2d uses circular padding.

    The replacement forces ``padding_mode='circular'``. A ``padding_mode`` keyword
    passed by the caller is overridden; previously it raised "got multiple values
    for keyword argument 'padding_mode'".
    """
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        # kwargs is a fresh dict per call, so updating it here is safe.
        kwargs['padding_mode'] = 'circular'
        return conv2d_constructor(self, *args, **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular
2022-09-04 16:41:20 -06:00
2022-09-03 03:08:45 -06:00
# Module-level singleton. Constructing it registers cmd_opts.embeddings_dir with the
# class-level embedding database.
model_hijack = StableDiffusionModelHijack()
2022-11-11 08:20:18 -07:00
def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.

    Replacement ``register_buffer`` for the DDIM/PLMS samplers. Plain tensors not
    already on ``devices.device`` are moved there before being stored with
    ``setattr``.

    On MPS, floating-point tensors are also cast to float32 (presumably because
    MPS lacks float64 support — TODO confirm). Integer/bool tensors keep their
    dtype instead of being silently converted to float32.

    Non-tensor values are stored unchanged.
    """
    if type(attr) == torch.Tensor and attr.device != devices.device:
        cast_to_fp32 = devices.device.type == 'mps' and attr.is_floating_point()
        attr = attr.to(device=devices.device, dtype=torch.float32 if cast_to_fp32 else None)

    setattr(self, name, attr)
# Route both samplers' register_buffer through the device-aware replacement defined above.
for _sampler_cls in (ldm.models.diffusion.ddim.DDIMSampler, ldm.models.diffusion.plms.PLMSSampler):
    _sampler_cls.register_buffer = register_buffer
del _sampler_cls