Match hypernet name with filename in all cases.
This commit is contained in:
parent
51e3dc9cca
commit
19818f023c
|
@ -340,7 +340,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
pbar.set_description(f"loss: {mean_loss:.7f}")
|
pbar.set_description(f"loss: {mean_loss:.7f}")
|
||||||
|
|
||||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
temp = hypernetwork.name
|
||||||
|
# Before saving, change name to match current checkpoint.
|
||||||
|
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
|
||||||
|
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
|
||||||
hypernetwork.save(last_saved_file)
|
hypernetwork.save(last_saved_file)
|
||||||
|
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||||
|
@ -405,6 +408,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
|
|
||||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||||
|
# Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
|
||||||
|
hypernetwork.name = hypernetwork_name
|
||||||
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
|
||||||
hypernetwork.save(filename)
|
hypernetwork.save(filename)
|
||||||
|
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
Loading…
Reference in New Issue