Use underscore naming for "private" functions in sd_vae
This commit is contained in:
parent
9fdc343dca
commit
028b67b635
|
@ -55,7 +55,7 @@ def restore_base_vae(model):
|
||||||
global loaded_vae_file
|
global loaded_vae_file
|
||||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
||||||
print("Restoring base VAE")
|
print("Restoring base VAE")
|
||||||
load_vae_dict(model, base_vae)
|
_load_vae_dict(model, base_vae)
|
||||||
loaded_vae_file = None
|
loaded_vae_file = None
|
||||||
delete_base_vae()
|
delete_base_vae()
|
||||||
|
|
||||||
|
@ -147,7 +147,7 @@ def load_vae(model, vae_file=None):
|
||||||
store_base_vae(model)
|
store_base_vae(model)
|
||||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||||
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||||
load_vae_dict(model, vae_dict_1)
|
_load_vae_dict(model, vae_dict_1)
|
||||||
|
|
||||||
# If vae used is not in dict, update it
|
# If vae used is not in dict, update it
|
||||||
# It will be removed on refresh though
|
# It will be removed on refresh though
|
||||||
|
@ -164,7 +164,7 @@ def load_vae(model, vae_file=None):
|
||||||
|
|
||||||
|
|
||||||
# don't call this from outside
|
# don't call this from outside
|
||||||
def load_vae_dict(model, vae_dict_1):
|
def _load_vae_dict(model, vae_dict_1):
|
||||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||||
model.first_stage_model.to(devices.dtype_vae)
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue