Fixing log_img blowing up.

David Bielejeski 2022-09-29 18:35:40 -05:00
parent 6158df3142
commit 40eb005c3d
2 changed files with 13 additions and 13 deletions

main.py (20 lines changed)

@@ -462,15 +462,17 @@ class ImageLogger(Callback):
         return False

     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
-            self.log_img(pl_module, batch, batch_idx, split="train")
+        pass
+        #if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
+            #self.log_img(pl_module, batch, batch_idx, split="train")

     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        if not self.disabled and pl_module.global_step > 0:
-            self.log_img(pl_module, batch, batch_idx, split="val")
-        if hasattr(pl_module, 'calibrate_grad_norm'):
-            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
-                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
+        pass
+        #if not self.disabled and pl_module.global_step > 0:
+            #self.log_img(pl_module, batch, batch_idx, split="val")
+        #if hasattr(pl_module, 'calibrate_grad_norm'):
+            #if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
+                #self.log_gradients(trainer, pl_module, batch_idx=batch_idx)

 class CUDACallback(Callback):
@@ -866,5 +868,5 @@ if __name__ == "__main__":
         os.makedirs(os.path.split(dst)[0], exist_ok=True)
         os.rename(logdir, dst)
     if trainer.global_rank == 0:
-        print("Another one bites the dust...")
-        print(trainer.profiler.summary())
+        print("Training complete. max_training_steps reached or we blew up.")
+        # print(trainer.profiler.summary())
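
Aside: this commit fixes the crash by stubbing the ImageLogger hooks out entirely, so no image logging happens at all. A less drastic alternative would be to keep logging enabled but catch failures, so one bad log_img call cannot take down the run. Below is a minimal sketch of that approach, assuming the ImageLogger class from main.py above is importable; SafeImageLogger is a hypothetical name and is not part of this commit:

import logging

logger = logging.getLogger(__name__)

# Hypothetical subclass of the ImageLogger defined in main.py above.
# Same hook signatures, but log_img failures are reported and skipped
# instead of aborting training.
class SafeImageLogger(ImageLogger):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            try:
                self.log_img(pl_module, batch, batch_idx, split="train")
            except Exception:
                # One bad logging step should not kill the whole run.
                logger.exception("log_img failed at train batch %d; skipping", batch_idx)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and pl_module.global_step > 0:
            try:
                self.log_img(pl_module, batch, batch_idx, split="val")
            except Exception:
                logger.exception("log_img failed at val batch %d; skipping", batch_idx)

Either way the training loop survives; the try/except variant just preserves whatever image logs do succeed, at the cost of silently degraded logging when something is wrong.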


@@ -416,9 +416,7 @@
     },
     {
       "cell_type": "code",
-      "source": [
-        ""
-      ],
+      "source": [],
       "metadata": {
         "id": "92QkRfm0e6K0"
       },