Forgot to pass args to write_batch_schedule
This commit is contained in:
parent
f96d44ddb4
commit
c8c658d181
6
train.py
6
train.py
|
@ -329,7 +329,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list
|
||||||
|
|
||||||
return image_train_items
|
return image_train_items
|
||||||
|
|
||||||
def write_batch_schedule(log_folder, train_batch, epoch):
|
def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int):
|
||||||
if args.write_schedule:
|
if args.write_schedule:
|
||||||
with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
|
with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
|
||||||
for i in range(len(train_batch.image_train_items)):
|
for i in range(len(train_batch.image_train_items)):
|
||||||
|
@ -783,7 +783,7 @@ def main(args):
|
||||||
# # discard the grads, just want to pin memory
|
# # discard the grads, just want to pin memory
|
||||||
# optimizer.zero_grad(set_to_none=True)
|
# optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
write_batch_schedule(log_folder, train_batch, 0)
|
write_batch_schedule(args, log_folder, train_batch, 0)
|
||||||
|
|
||||||
for epoch in range(args.max_epochs):
|
for epoch in range(args.max_epochs):
|
||||||
loss_epoch = []
|
loss_epoch = []
|
||||||
|
@ -935,7 +935,7 @@ def main(args):
|
||||||
epoch_pbar.update(1)
|
epoch_pbar.update(1)
|
||||||
if epoch < args.max_epochs - 1:
|
if epoch < args.max_epochs - 1:
|
||||||
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
|
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
|
||||||
write_batch_schedule(log_folder, train_batch, epoch + 1)
|
write_batch_schedule(args, log_folder, train_batch, epoch + 1)
|
||||||
|
|
||||||
loss_local = sum(loss_epoch) / len(loss_epoch)
|
loss_local = sum(loss_epoch) / len(loss_epoch)
|
||||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||||
|
|
Loading…
Reference in New Issue