fix wandb logging
This commit is contained in:
parent
f980137430
commit
18ad704fd6
9
main.py
9
main.py
|
@ -313,9 +313,10 @@ class ImageLogger(Callback):
|
|||
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
||||
|
||||
tag = f"{split}/{k}"
|
||||
pl_module.logger.experiment.add_image(
|
||||
tag, grid,
|
||||
global_step=pl_module.global_step)
|
||||
pl_module.logger.experiment.log(
|
||||
{'tag': tag, 'examples': grid},
|
||||
step=pl_module.global_step
|
||||
)
|
||||
|
||||
@rank_zero_only
|
||||
def log_local(self, save_dir, split, images,
|
||||
|
@ -357,7 +358,7 @@ class ImageLogger(Callback):
|
|||
N = min(images[k].shape[0], self.max_images)
|
||||
images[k] = images[k][:N]
|
||||
if isinstance(images[k], torch.Tensor):
|
||||
images[k] = images[k].detach().cpu()
|
||||
images[k] = images[k].detach().cpu().to(torch.float32)
|
||||
if self.clamp:
|
||||
images[k] = torch.clamp(images[k], -1., 1.)
|
||||
|
||||
|
|
Loading…
Reference in New Issue