Forgot to pass args to write_batch_schedule
This commit is contained in:
parent
f96d44ddb4
commit
c8c658d181
6
train.py
6
train.py
|
@ -329,7 +329,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list
|
|||
|
||||
return image_train_items
|
||||
|
||||
def write_batch_schedule(log_folder, train_batch, epoch):
|
||||
def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int):
|
||||
if args.write_schedule:
|
||||
with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
|
||||
for i in range(len(train_batch.image_train_items)):
|
||||
|
@ -783,7 +783,7 @@ def main(args):
|
|||
# # discard the grads, just want to pin memory
|
||||
# optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
write_batch_schedule(log_folder, train_batch, 0)
|
||||
write_batch_schedule(args, log_folder, train_batch, 0)
|
||||
|
||||
for epoch in range(args.max_epochs):
|
||||
loss_epoch = []
|
||||
|
@ -935,7 +935,7 @@ def main(args):
|
|||
epoch_pbar.update(1)
|
||||
if epoch < args.max_epochs - 1:
|
||||
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
|
||||
write_batch_schedule(log_folder, train_batch, epoch + 1)
|
||||
write_batch_schedule(args, log_folder, train_batch, epoch + 1)
|
||||
|
||||
loss_local = sum(loss_epoch) / len(loss_epoch)
|
||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||
|
|
Loading…
Reference in New Issue