Use underscore naming for "private" functions in sd_vae
This commit is contained in:
parent
9fdc343dca
commit
028b67b635
|
@ -55,7 +55,7 @@ def restore_base_vae(model):
|
|||
global loaded_vae_file
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
||||
print("Restoring base VAE")
|
||||
load_vae_dict(model, base_vae)
|
||||
_load_vae_dict(model, base_vae)
|
||||
loaded_vae_file = None
|
||||
delete_base_vae()
|
||||
|
||||
|
@ -147,7 +147,7 @@ def load_vae(model, vae_file=None):
|
|||
store_base_vae(model)
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
load_vae_dict(model, vae_dict_1)
|
||||
_load_vae_dict(model, vae_dict_1)
|
||||
|
||||
# If vae used is not in dict, update it
|
||||
# It will be removed on refresh though
|
||||
|
@ -164,7 +164,7 @@ def load_vae(model, vae_file=None):
|
|||
|
||||
|
||||
# don't call this from outside
|
||||
def load_vae_dict(model, vae_dict_1):
|
||||
def _load_vae_dict(model, vae_dict_1):
|
||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
|
|
Loading…
Reference in New Issue