Update main.py

Simple solution to a common problem with state_dict
This commit is contained in:
tkgix 2022-12-10 22:10:41 +09:00 committed by GitHub
parent 25948bf770
commit aa24932c51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 2 deletions

View File

@ -29,7 +29,15 @@ def load_model_from_config(config, ckpt, verbose=False):
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"ckpt: {ckpt} has {pl_sd['global_step']} steps") print(f"ckpt: {ckpt} has {pl_sd['global_step']} steps")
## sd = pl_sd["state_dict"]
if "state_dict" in pl_sd:
print("load_state_dict from state_dict")
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
else:
print("load_state_dict from directly")
sd = pl_sd
config.model.params.ckpt_path = ckpt config.model.params.ckpt_path = ckpt
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)