Fixing log_img blowing up.

This commit is contained in:
David Bielejeski 2022-09-29 18:35:40 -05:00
parent 6158df3142
commit 40eb005c3d
2 changed files with 13 additions and 13 deletions

20
main.py
View File

@ -462,15 +462,17 @@ class ImageLogger(Callback):
return False return False
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): pass
self.log_img(pl_module, batch, batch_idx, split="train") #if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
#self.log_img(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
if not self.disabled and pl_module.global_step > 0: pass
self.log_img(pl_module, batch, batch_idx, split="val") #if not self.disabled and pl_module.global_step > 0:
if hasattr(pl_module, 'calibrate_grad_norm'): #self.log_img(pl_module, batch, batch_idx, split="val")
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: #if hasattr(pl_module, 'calibrate_grad_norm'):
self.log_gradients(trainer, pl_module, batch_idx=batch_idx) #if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
#self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback): class CUDACallback(Callback):
@ -866,5 +868,5 @@ if __name__ == "__main__":
os.makedirs(os.path.split(dst)[0], exist_ok=True) os.makedirs(os.path.split(dst)[0], exist_ok=True)
os.rename(logdir, dst) os.rename(logdir, dst)
if trainer.global_rank == 0: if trainer.global_rank == 0:
print("Another one bites the dust...") print("Training complete. max_training_steps reached or we blew up.")
print(trainer.profiler.summary()) # print(trainer.profiler.summary())

View File

@ -416,9 +416,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [],
""
],
"metadata": { "metadata": {
"id": "92QkRfm0e6K0" "id": "92QkRfm0e6K0"
}, },