Update main.py
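Simple solution to a common problem with state_dict: some checkpoints have no top-level "state_dict" key, so fall back to using the loaded checkpoint object itself as the state dict.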
This commit is contained in:
parent
25948bf770
commit
aa24932c51
12
main.py
12
main.py
|
@@ -29,7 +29,15 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
print(f"ckpt: {ckpt} has {pl_sd['global_step']} steps")
|
print(f"ckpt: {ckpt} has {pl_sd['global_step']} steps")
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
|
## sd = pl_sd["state_dict"]
|
||||||
|
if "state_dict" in pl_sd:
|
||||||
|
print("load_state_dict from state_dict")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
else:
|
||||||
|
print("load_state_dict from directly")
|
||||||
|
sd = pl_sd
|
||||||
|
|
||||||
config.model.params.ckpt_path = ckpt
|
config.model.params.ckpt_path = ckpt
|
||||||
model = instantiate_from_config(config.model)
|
model = instantiate_from_config(config.model)
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
@@ -774,4 +782,4 @@ if __name__ == "__main__":
|
||||||
os.rename(logdir, dst)
|
os.rename(logdir, dst)
|
||||||
if trainer.global_rank == 0:
|
if trainer.global_rank == 0:
|
||||||
print("Training complete. max_steps or max_epochs reached, or we blew up.")
|
print("Training complete. max_steps or max_epochs reached, or we blew up.")
|
||||||
print(trainer.profiler.summary())
|
print(trainer.profiler.summary())
|
||||||
|
|
Loading…
Reference in New Issue