make main model loading and model merger use the same code
This commit is contained in:
parent
050a6a798c
commit
c77c89cc83
|
@ -169,9 +169,9 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
|||
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
||||
|
||||
theta_0 = primary_model['state_dict']
|
||||
theta_1 = secondary_model['state_dict']
|
||||
|
||||
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
|
||||
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
|
||||
|
||||
theta_funcs = {
|
||||
"Weighted Sum": weighted_sum,
|
||||
|
|
|
@ -122,6 +122,13 @@ def select_checkpoint():
|
|||
return checkpoint_info
|
||||
|
||||
|
||||
def get_state_dict_from_checkpoint(pl_sd):
|
||||
if "state_dict" in pl_sd:
|
||||
return pl_sd["state_dict"]
|
||||
|
||||
return pl_sd
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
sd_model_hash = checkpoint_info.hash
|
||||
|
@ -131,11 +138,8 @@ def load_model_weights(model, checkpoint_info):
|
|||
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
sd = pl_sd
|
||||
|
||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||
|
||||
model.load_state_dict(sd, strict=False)
|
||||
|
||||
|
|
Loading…
Reference in New Issue