Merge pull request #6782 from aria1th/fix-hypernetwork-loss

Fix tensorboard-hypernetwork integration correctly
This commit is contained in:
AUTOMATIC1111 2023-01-15 22:55:06 +03:00 committed by GitHub
commit d6fa8e92ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 6 deletions

View File

@ -561,6 +561,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step = 0 #internal _loss_step = 0 #internal
# size = len(ds.indexes) # size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024)) # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
# losses = torch.zeros((size,)) # losses = torch.zeros((size,))
# previous_mean_losses = [0] # previous_mean_losses = [0]
# previous_mean_loss = 0 # previous_mean_loss = 0
@ -610,7 +611,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
# go back until we reach gradient accumulation steps # go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0: if (j + 1) % gradient_step != 0:
continue continue
loss_logging.append(_loss_step)
if clip_grad: if clip_grad:
clip_grad(weights, clip_grad_sched.learn_rate) clip_grad(weights, clip_grad_sched.learn_rate)
@ -644,7 +645,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
if shared.opts.training_enable_tensorboard: if shared.opts.training_enable_tensorboard:
epoch_num = hypernetwork.step // len(ds) epoch_num = hypernetwork.step // len(ds)
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1 epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) mean_loss = sum(loss_logging) / len(loss_logging)
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num) textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, { textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
@ -688,9 +689,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None image = processed.images[0] if len(processed.images) > 0 else None
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
@ -701,7 +699,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
hypernetwork.train() hypernetwork.train()
if image is not None: if image is not None:
shared.state.assign_current_image(image) shared.state.assign_current_image(image)
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
textual_inversion.tensorboard_add_image(tensorboard_writer,
f"Validation at epoch {epoch_num}", image,
hypernetwork.step)
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}" last_saved_image += f", prompt: {preview_text}"