diff --git a/train.py b/train.py
index af0d1ee..13de14f 100644
--- a/train.py
+++ b/train.py
@@ -329,7 +329,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list
 
     return image_train_items
 
-def write_batch_schedule(log_folder, train_batch, epoch):
+def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int):
     if args.write_schedule:
         with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
             for i in range(len(train_batch.image_train_items)):
@@ -783,7 +783,7 @@ def main(args):
     # # discard the grads, just want to pin memory
     # optimizer.zero_grad(set_to_none=True)
 
-    write_batch_schedule(log_folder, train_batch, 0)
+    write_batch_schedule(args, log_folder, train_batch, 0)
 
     for epoch in range(args.max_epochs):
         loss_epoch = []
@@ -935,7 +935,7 @@ def main(args):
            epoch_pbar.update(1)
            if epoch < args.max_epochs - 1:
                train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
-               write_batch_schedule(log_folder, train_batch, epoch + 1)
+               write_batch_schedule(args, log_folder, train_batch, epoch + 1)
 
            loss_local = sum(loss_epoch) / len(loss_epoch)
            log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)