Forgot to pass args to write_batch_schedule

Joel Holdbrooks 2023-01-29 18:28:07 -08:00
parent f96d44ddb4
commit c8c658d181
1 changed file with 3 additions and 3 deletions


@@ -329,7 +329,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list
     return image_train_items


-def write_batch_schedule(log_folder, train_batch, epoch):
+def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int):
     if args.write_schedule:
         with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
             for i in range(len(train_batch.image_train_items)):
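For readers outside the repo, a minimal runnable sketch of the fixed function and the bug it closes: the body reads args.write_schedule, so call sites that did not pass args failed at runtime once the schedule branch was reached. The stand-in batch object and the per-item output line are illustrative assumptions, not the project's actual EveryDreamBatch type or file format.

    import argparse
    from types import SimpleNamespace

    def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch, epoch: int) -> None:
        # args must be passed in: the whole function is gated on args.write_schedule.
        if args.write_schedule:
            with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding="utf-8") as f:
                for i in range(len(train_batch.image_train_items)):
                    # The real per-item format is elided in the diff; a plain
                    # index line stands in for it here.
                    f.write(f"{i}\n")

    # Hypothetical usage with stand-ins for the parsed args and the batch object:
    args = argparse.Namespace(write_schedule=True)
    batch = SimpleNamespace(image_train_items=["a.jpg", "b.jpg"])
    write_batch_schedule(args, log_folder=".", train_batch=batch, epoch=0)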
@@ -783,7 +783,7 @@ def main(args):
     # # discard the grads, just want to pin memory
     # optimizer.zero_grad(set_to_none=True)

-    write_batch_schedule(log_folder, train_batch, 0)
+    write_batch_schedule(args, log_folder, train_batch, 0)

     for epoch in range(args.max_epochs):
         loss_epoch = []
@@ -935,7 +935,7 @@ def main(args):
             epoch_pbar.update(1)
             if epoch < args.max_epochs - 1:
                 train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
-                write_batch_schedule(log_folder, train_batch, epoch + 1)
+                write_batch_schedule(args, log_folder, train_batch, epoch + 1)

             loss_local = sum(loss_epoch) / len(loss_epoch)
             log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
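Taken together, the two call-site hunks follow one pattern: write the schedule once before epoch 0, then, at the end of every epoch except the last, reshuffle and write the schedule the next epoch will consume. A sketch of that loop shape, reusing the stand-ins from the sketch above and eliding the training steps:

    write_batch_schedule(args, log_folder, train_batch, 0)  # schedule for epoch 0

    for epoch in range(args.max_epochs):
        loss_epoch = []
        # ... training steps append per-step losses to loss_epoch ...
        if epoch < args.max_epochs - 1:
            # Reshuffle, then record the schedule the next epoch will see.
            train_batch.shuffle(epoch_n=epoch, max_epochs=args.max_epochs)
            write_batch_schedule(args, log_folder, train_batch, epoch + 1)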