import collections
import importlib
import os
import sys
import threading
import enum

import torch
import re
import safetensors.torch
from omegaconf import OmegaConf, ListConfig
from urllib import request

import ldm.modules.midas as midas

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
from modules.shared import opts
import tomesd
import numpy as np

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

checkpoints_list = {}
checkpoint_aliases = {}
checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()


class ModelType(enum.Enum):
    SD1 = 1
    SD2 = 2
    SDXL = 3
    SSD = 4
    SD3 = 5


def replace_key(d, key, new_key, value):
    keys = list(d.keys())

    d[new_key] = value

    if key not in keys:
        return d

    index = keys.index(key)
    keys[index] = new_key

    new_d = {k: d[k] for k in keys}

    d.clear()
    d.update(new_d)
    return d
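
# Illustration (not from the original module): replace_key renames a key in
# place while preserving insertion order, which a plain
# `d[new_key] = d.pop(key)` would not do:
#   d = {'a': 1, 'b': 2, 'c': 3}
#   replace_key(d, 'b', 'B', 20)   # d is now {'a': 1, 'B': 20, 'c': 3}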


class CheckpointInfo:
    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)
        abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None

        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
            name = abspath.replace(abs_ckpt_dir, '')
        elif abspath.startswith(model_path):
            name = abspath.replace(model_path, '')
        else:
            name = os.path.basename(filename)

        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]

        def read_metadata():
            metadata = read_metadata_from_safetensors(filename)
            self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)

            return metadata

        self.metadata = {}
        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading metadata for {filename}")

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'

        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
        if self.shorthash:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

    def register(self):
        checkpoints_list[self.title] = self
        for id in self.ids:
            checkpoint_aliases[id] = self

    def calculate_shorthash(self):
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return

        shorthash = self.sha256[0:10]
        if self.shorthash == self.sha256[0:10]:
            return self.shorthash

        self.shorthash = shorthash

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

        old_title = self.title
        self.title = f'{self.name} [{self.shorthash}]'
        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'

        replace_key(checkpoints_list, old_title, self.title, self)
        self.register()

        return self.shorthash


try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
    pass


def setup_model():
    """called once at startup to do various one-time tasks related to SD models"""

    os.makedirs(model_path, exist_ok=True)

    enable_midas_autodownload()
    patch_given_betas()


def checkpoint_tiles(use_short=False):
    return [x.short_title if use_short else x.title for x in checkpoints_list.values()]


def list_models():
    checkpoints_list.clear()
    checkpoint_aliases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt
    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
        model_url = None
        expected_sha256 = None
    else:
        model_url = f"{shared.hf_endpoint}/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
        expected_sha256 = '6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa'

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"], hash_prefix=expected_sha256)

    if os.path.exists(cmd_ckpt):
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (possibly it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()


re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")


def get_closet_checkpoint_match(search_string):
    if not search_string:
        return None

    checkpoint_info = checkpoint_aliases.get(search_string, None)
    if checkpoint_info is not None:
        return checkpoint_info

    found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
    if found:
        return found[0]

    search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
    found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
    if found:
        return found[0]

    return None


def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions"""

    try:
        with open(filename, "rb") as file:
            import hashlib
            m = hashlib.sha256()

            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'
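
# Note: this legacy hash reads only the 64 KiB starting 1 MiB into the file,
# so distinct checkpoints that happen to share those bytes collide; the
# sha256-based shorthash computed in CheckpointInfo supersedes it.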


def select_checkpoint():
    """Raises `FileNotFoundError` if no checkpoints are found."""
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if len(checkpoints_list) == 0:
        error_message = "No checkpoints found. When searching for checkpoints, looked at:"
        if shared.cmd_opts.ckpt is not None:
            error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
        error_message += f"\n - directory {model_path}"
        if shared.cmd_opts.ckpt_dir is not None:
            error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
        error_message += "\nCan't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
        raise FileNotFoundError(error_message)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info


checkpoint_dict_replacements_sd1 = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}

checkpoint_dict_replacements_sd2_turbo = {  # Converts SD 2.1 Turbo from SGM to LDM format.
    'conditioner.embedders.0.': 'cond_stage_model.',
}


def transform_checkpoint_dict_key(k, replacements):
    for text, replacement in replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text):]

    return k
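
# Example: under checkpoint_dict_replacements_sd1 a key such as
#   'cond_stage_model.transformer.encoder.layers.0.mlp.fc1.weight'
# becomes
#   'cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight'
# while keys with no matching prefix pass through unchanged.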


def get_state_dict_from_checkpoint(pl_sd):
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024

    sd = {}
    for k, v in pl_sd.items():
        if is_sd2_turbo:
            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
        else:
            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)

        if new_key is not None:
            sd[new_key] = v

    pl_sd.clear()
    pl_sd.update(sd)

    return pl_sd


def read_metadata_from_safetensors(filename):
    import json

    with open(filename, mode="rb") as file:
        metadata_len = file.read(8)
        metadata_len = int.from_bytes(metadata_len, "little")
        json_start = file.read(2)

        assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"

        res = {}

        try:
            json_data = json_start + file.read(metadata_len - 2)
            json_obj = json.loads(json_data)
            for k, v in json_obj.get("__metadata__", {}).items():
                res[k] = v
                if isinstance(v, str) and v[0:1] == '{':
                    try:
                        res[k] = json.loads(v)
                    except Exception:
                        pass
        except Exception:
            errors.report(f"Error reading metadata from file: {filename}", exc_info=True)

        return res
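
# The header layout parsed above: bytes 0-7 hold a little-endian uint64 giving
# the length of the JSON header that immediately follows; user metadata lives
# under its "__metadata__" key. A minimal sketch of producing such a file with
# the safetensors API (the file name is illustrative):
#   import torch, safetensors.torch
#   safetensors.torch.save_file({"w": torch.zeros(1)}, "tiny.safetensors",
#                               metadata={"modelspec.title": "example"})
#   read_metadata_from_safetensors("tiny.safetensors")  # -> {'modelspec.title': 'example'}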


def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    sd = get_state_dict_from_checkpoint(pl_sd)
    return sd


def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info in checkpoints_loaded:
        # use checkpoint cache
        print(f"Loading weights [{sd_model_hash}] from cache")

        # move to end as latest
        checkpoints_loaded.move_to_end(checkpoint_info)

        return checkpoints_loaded[checkpoint_info]

    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
    res = read_state_dict(checkpoint_info.filename)
    timer.record("load weights from disk")

    return res


class SkipWritingToConfig:
    """This context manager prevents load_model_weights from writing the checkpoint name to the config when it loads weights."""

    skip = False
    previous = None

    def __enter__(self):
        self.previous = SkipWritingToConfig.skip
        SkipWritingToConfig.skip = True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        SkipWritingToConfig.skip = self.previous
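
# Typical usage sketch (some_checkpoint_info is hypothetical): load weights
# without persisting the checkpoint name to settings, e.g. for a temporary swap:
#   with SkipWritingToConfig():
#       reload_model_weights(info=some_checkpoint_info)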


def check_fp8(model):
    if model is None:
        return None
    if devices.get_optimal_device_name() == "mps":
        enable_fp8 = False
    elif shared.opts.fp8_storage == "Enable":
        enable_fp8 = True
    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
        enable_fp8 = True
    else:
        enable_fp8 = False
    return enable_fp8


def set_model_type(model, state_dict):
    model.is_sd1 = False
    model.is_sd2 = False
    model.is_sdxl = False
    model.is_ssd = False
    model.is_sd3 = False

    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
        model.is_sd3 = True
        model.model_type = ModelType.SD3
    elif hasattr(model, 'conditioner'):
        model.is_sdxl = True

        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
            model.is_ssd = True
            model.model_type = ModelType.SSD
        else:
            model.model_type = ModelType.SDXL
    elif hasattr(model.cond_stage_model, 'model'):
        model.is_sd2 = True
        model.model_type = ModelType.SD2
    else:
        model.is_sd1 = True
        model.model_type = ModelType.SD1


def set_model_fields(model):
    if not hasattr(model, 'latent_channels'):
        model.latent_channels = 4


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if devices.fp8:
        # prevent the model from loading its state dict in fp8
        model.half()

    if not SkipWritingToConfig.skip:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    set_model_type(model, state_dict)
    set_model_fields(model)

    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)

    if model.is_ssd:
        sd_hijack.model_hijack.convert_sdxl_to_ssd(model)

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict.copy()

    if hasattr(model, "before_load_weights"):
        model.before_load_weights(state_dict)

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")

    if hasattr(model, "after_load_weights"):
        model.after_load_weights(state_dict)

    del state_dict

    # Set is_sdxl_inpaint flag.
    # Check the UNet structure to detect an inpainting model: its first input
    # conv ('diffusion_model.input_blocks.0.0.weight') takes 9 channels
    # instead of the usual 4.
    diffusion_model_input = model.model.state_dict().get('diffusion_model.input_blocks.0.0.weight')
    model.is_sdxl_inpaint = (
        model.is_sdxl and
        diffusion_model_input is not None and
        diffusion_model_input.shape[1] == 9
    )

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if shared.cmd_opts.no_half:
        model.float()
        model.alphas_cumprod_original = model.alphas_cumprod
        devices.dtype_unet = torch.float32
        assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half"
        timer.record("apply float()")
    else:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
            model.depth_model = None

        alphas_cumprod = model.alphas_cumprod
        model.alphas_cumprod = None
        model.half()
        model.alphas_cumprod = alphas_cumprod
        model.alphas_cumprod_original = alphas_cumprod
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

        devices.dtype_unet = torch.float16
        timer.record("apply half()")

    apply_alpha_schedule_override(model)

    for module in model.modules():
        if hasattr(module, 'fp16_weight'):
            del module.fp16_weight
        if hasattr(module, 'fp16_bias'):
            del module.fp16_bias

    if check_fp8(model):
        devices.fp8 = True
        first_stage = model.first_stage_model
        model.first_stage_model = None
        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                if shared.opts.cache_fp16_weight:
                    module.fp16_weight = module.weight.data.clone().cpu().half()
                    if module.bias is not None:
                        module.fp16_bias = module.bias.data.clone().cpu().half()
                module.to(torch.float8_e4m3fn)
        model.first_stage_model = first_stage
        timer.record("apply fp8")
    else:
        devices.fp8 = False

    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    if hasattr(model, 'logvar'):
        model.logvar = model.logvar.to(devices.device)  # fix for training

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")


def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.
    """
    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            if not os.path.exists(midas_path):
                os.mkdir(midas_path)

            print(f"Downloading midas model weights for {model_type} to {path}")
            request.urlretrieve(midas_urls[model_type], path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper


def patch_given_betas():
    import ldm.models.diffusion.ddpm

    def patched_register_schedule(*args, **kwargs):
        """a modified version of register_schedule that converts a plain list from OmegaConf into a numpy array"""

        if isinstance(args[1], ListConfig):
            args = (args[0], np.array(args[1]), *args[2:])

        original_register_schedule(*args, **kwargs)

    original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)


def repair_config(sd_config, state_dict=None):

    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False

    if hasattr(sd_config.model.params, 'unet_config'):
        if shared.cmd_opts.no_half:
            sd_config.model.params.unet_config.params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
            sd_config.model.params.unet_config.params.use_fp16 = True

    if hasattr(sd_config.model.params, 'first_stage_config'):
        if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
            sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
        karlo_path = os.path.join(paths.models_path, 'karlo')
        sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)

    # Disable gradient checkpointing for inference; this avoids the overhead of
    # checking parameters on every forward pass (about 100ms/it on a 4090 for SDXL).
    if hasattr(sd_config.model.params, "network_config"):
        sd_config.model.params.network_config.params.use_checkpoint = False
    if hasattr(sd_config.model.params, "unet_config"):
        sd_config.model.params.unet_config.params.use_checkpoint = False


def rescale_zero_terminal_snr_abar(alphas_cumprod):
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= (alphas_bar_sqrt_T)

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt back to alphas_bar
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas_bar[-1] = 4.8973451890853435e-08
    return alphas_bar
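
# Sanity-check sketch: after rescaling, the first alpha-bar is preserved and
# the terminal one is (near) zero, i.e. SNR(T) = abar / (1 - abar) -> 0:
#   abar = torch.linspace(0.9991, 0.0047, 1000)  # a made-up schedule
#   fixed = rescale_zero_terminal_snr_abar(abar.clone())
#   assert torch.isclose(fixed[0], abar[0], atol=1e-4)
#   assert fixed[-1] < 1e-6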


def apply_alpha_schedule_override(sd_model, p=None):
    """
    Applies an override to the alpha schedule of the model according to settings.
    - downcasts the alpha schedule to half precision
    - rescales the alpha schedule to have zero terminal SNR
    """

    if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
        return

    sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)

    if opts.use_downcasted_alpha_bar:
        if p is not None:
            p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
        sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)

    if opts.sd_noise_schedule == "Zero Terminal SNR":
        if p is not None:
            p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
        sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)


sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
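
# Each key above is unique to one family of text encoder; load_model checks a
# state dict for them to decide whether CLIP initialization (and a possible
# download) can be skipped because the checkpoint already carries the weights.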


class SdModelData:
    def __init__(self):
        self.sd_model = None
        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

    def get_sd_model(self):
        if self.was_loaded_at_least_once:
            return self.sd_model

        if self.sd_model is None:
            with self.lock:
                if self.sd_model is not None or self.was_loaded_at_least_once:
                    return self.sd_model

                try:
                    load_model()

                except Exception as e:
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    print("", file=sys.stderr)
                    print("Stable diffusion model failed to load", file=sys.stderr)
                    self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v, already_loaded=False):
        self.sd_model = v
        if already_loaded:
            sd_vae.base_vae = getattr(v, "base_vae", None)
            sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
            sd_vae.checkpoint_info = v.sd_checkpoint_info

        try:
            self.loaded_sd_models.remove(v)
        except ValueError:
            pass

        if v is not None:
            self.loaded_sd_models.insert(0, v)


model_data = SdModelData()


def get_empty_cond(sd_model):
    p = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(p, {})

    if hasattr(sd_model, 'get_learned_conditioning'):
        d = sd_model.get_learned_conditioning([""])
    else:
        d = sd_model.cond_stage_model([""])

    if isinstance(d, dict):
        d = d['crossattn']

    return d


def send_model_to_cpu(m):
    if m is not None:
        if m.lowvram:
            lowvram.send_everything_to_cpu()
        else:
            m.to(devices.cpu)

    devices.torch_gc()


def model_target_device(m):
    if lowvram.is_needed(m):
        return devices.cpu
    else:
        return devices.device


def send_model_to_device(m):
    lowvram.apply(m)

    if not m.lowvram:
        m.to(shared.device)


def send_model_to_trash(m):
    m.to(device="meta")
    devices.torch_gc()


def instantiate_from_config(config, state_dict=None):
    constructor = get_obj_from_str(config["target"])

    params = {**config.get("params", {})}

    if state_dict and "state_dict" in params and params["state_dict"] is None:
        params["state_dict"] = state_dict

    return constructor(**params)


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)

    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)

    return getattr(importlib.import_module(module, package=None), cls)
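
# Example: instantiate_from_config resolves the dotted "target" path and calls
# it with "params", optionally injecting the loaded state dict (the target
# below is contrived, purely for illustration):
#   cfg = {"target": "collections.OrderedDict", "params": {}}
#   obj = instantiate_from_config(cfg)  # same as collections.OrderedDict()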


def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    timer = Timer()

    if model_data.sd_model:
        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        devices.torch_gc()

    timer.record("unload existing model")

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config, state_dict)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
            with sd_disable_initialization.InitializeOnMeta():
                sd_model = instantiate_from_config(sd_config.model, state_dict)

    except Exception as e:
        errors.display(e, "creating model quickly", full_traceback=True)

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)

        with sd_disable_initialization.InitializeOnMeta():
            sd_model = instantiate_from_config(sd_config.model, state_dict)

    sd_model.used_config = checkpoint_config

    timer.record("create model")

    if shared.cmd_opts.no_half:
        weight_dtype_conversion = None
    else:
        weight_dtype_conversion = {
            'first_stage_model': None,
            'alphas_cumprod': None,
            '': torch.float16,
        }

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")

    send_model_to_device(sd_model)
    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)

    timer.record("hijack")

    sd_model.eval()
    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)
    timer.record("scripts callbacks")

    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model


def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """
    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary).
    If not, returns the model that can be used to load weights from checkpoint_info's file.
    If no such model exists, returns None.
    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
    """

    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if shared.opts.sd_checkpoints_keep_in_cpu:
        send_model_to_cpu(sd_model)
        timer.record("send model to cpu")

    already_loaded = None
    for i in reversed(range(len(model_data.loaded_sd_models))):
        loaded_model = model_data.loaded_sd_models[i]
        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
            already_loaded = loaded_model
            continue

        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
            del model_data.loaded_sd_models[i]
            send_model_to_trash(loaded_model)
            timer.record("send model to trash")

    if already_loaded is not None:
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded, already_loaded=True)

        if not SkipWritingToConfig.skip:
            shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
            shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256

        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
        sd_vae.reload_vae_weights(already_loaded)
        return model_data.sd_model
    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")

        model_data.sd_model = None
        load_model(checkpoint_info)
        return model_data.sd_model
    elif len(model_data.loaded_sd_models) > 0:
        sd_model = model_data.loaded_sd_models.pop()
        model_data.sd_model = sd_model

        sd_vae.base_vae = getattr(sd_model, "base_vae", None)
        sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
        sd_vae.checkpoint_info = sd_model.sd_checkpoint_info

        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return sd_model
    else:
        return None


def reload_model_weights(sd_model=None, info=None, forced_reload=False):
    checkpoint_info = info or select_checkpoint()

    timer = Timer()

    if not sd_model:
        sd_model = model_data.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if check_fp8(sd_model) != devices.fp8:
            # load from state dict again to prevent extra numerical errors
            forced_reload = True
        elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
            return sd_model

    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if sd_model is not None:
        sd_unet.apply_unet("None")
        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    if sd_model is None or checkpoint_config != sd_model.used_config:
        if sd_model is not None:
            send_model_to_trash(sd_model)

        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        if not sd_model.lowvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

    print(f"Weights loaded in {timer.summary()}.")

    model_data.set_sd_model(sd_model)
    sd_unet.apply_unet()

    return sd_model


def unload_model_weights(sd_model=None, info=None):
    send_model_to_cpu(sd_model or shared.sd_model)

    return sd_model


def apply_token_merging(sd_model, token_merging_ratio):
    """
    Applies speed and memory optimizations from tomesd.
    """

    current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)

    if current_token_merging_ratio == token_merging_ratio:
        return

    if current_token_merging_ratio > 0:
        tomesd.remove_patch(sd_model)

    if token_merging_ratio > 0:
        tomesd.apply_patch(
            sd_model,
            ratio=token_merging_ratio,
            use_rand=False,  # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False
        )

    sd_model.applied_token_merged_ratio = token_merging_ratio
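
# Usage sketch: callers pass the ratio from settings; a ratio of 0 removes any
# existing ToMe patch, so toggling the option on and off is safe, e.g.:
#   apply_token_merging(shared.sd_model, shared.opts.token_merging_ratio)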