diff --git a/main.py b/main.py index 55b6ddc..5c64b7a 100644 --- a/main.py +++ b/main.py @@ -29,7 +29,15 @@ def load_model_from_config(config, ckpt, verbose=False): pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"ckpt: {ckpt} has {pl_sd['global_step']} steps") - sd = pl_sd["state_dict"] + + ## sd = pl_sd["state_dict"] + if "state_dict" in pl_sd: + print("load_state_dict from state_dict") + sd = pl_sd["state_dict"] + else: + print("load_state_dict from directly") + sd = pl_sd + config.model.params.ckpt_path = ckpt model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) @@ -774,4 +782,4 @@ if __name__ == "__main__": os.rename(logdir, dst) if trainer.global_rank == 0: print("Training complete. max_steps or max_epochs reached, or we blew up.") - print(trainer.profiler.summary()) \ No newline at end of file + print(trainer.profiler.summary())