fix wandb logging

This commit is contained in:
harubaru 2022-10-15 11:20:32 -07:00
parent f980137430
commit 18ad704fd6
1 changed files with 5 additions and 4 deletions

View File

@ -313,9 +313,10 @@ class ImageLogger(Callback):
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
tag = f"{split}/{k}" tag = f"{split}/{k}"
pl_module.logger.experiment.add_image( pl_module.logger.experiment.log(
tag, grid, {'tag': tag, 'examples': grid},
global_step=pl_module.global_step) step=pl_module.global_step
)
@rank_zero_only @rank_zero_only
def log_local(self, save_dir, split, images, def log_local(self, save_dir, split, images,
@ -357,7 +358,7 @@ class ImageLogger(Callback):
N = min(images[k].shape[0], self.max_images) N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N] images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor): if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu() images[k] = images[k].detach().cpu().to(torch.float32)
if self.clamp: if self.clamp:
images[k] = torch.clamp(images[k], -1., 1.) images[k] = torch.clamp(images[k], -1., 1.)