From 18ad704fd6fde7ea4ae70e2c035e91605b1a57c1 Mon Sep 17 00:00:00 2001 From: harubaru Date: Sat, 15 Oct 2022 11:20:32 -0700 Subject: [PATCH] fix wandb logging --- main.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index fe1701d..3774110 100644 --- a/main.py +++ b/main.py @@ -313,9 +313,10 @@ class ImageLogger(Callback): grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w tag = f"{split}/{k}" - pl_module.logger.experiment.add_image( - tag, grid, - global_step=pl_module.global_step) + pl_module.logger.experiment.log( + {'tag': tag, 'examples': grid}, + step=pl_module.global_step + ) @rank_zero_only def log_local(self, save_dir, split, images, @@ -357,7 +358,7 @@ class ImageLogger(Callback): N = min(images[k].shape[0], self.max_images) images[k] = images[k][:N] if isinstance(images[k], torch.Tensor): - images[k] = images[k].detach().cpu() + images[k] = images[k].detach().cpu().to(torch.float32) if self.clamp: images[k] = torch.clamp(images[k], -1., 1.)