diff --git a/main.py b/main.py index fe1701d..3774110 100644 --- a/main.py +++ b/main.py @@ -313,9 +313,10 @@ class ImageLogger(Callback): grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w tag = f"{split}/{k}" - pl_module.logger.experiment.add_image( - tag, grid, - global_step=pl_module.global_step) + pl_module.logger.experiment.log( + {'tag': tag, 'examples': grid}, + step=pl_module.global_step + ) @rank_zero_only def log_local(self, save_dir, split, images, @@ -357,7 +358,7 @@ class ImageLogger(Callback): N = min(images[k].shape[0], self.max_images) images[k] = images[k][:N] if isinstance(images[k], torch.Tensor): - images[k] = images[k].detach().cpu() + images[k] = images[k].detach().cpu().to(torch.float32) if self.clamp: images[k] = torch.clamp(images[k], -1., 1.)