fix wandb logging
This commit is contained in:
parent
f980137430
commit
18ad704fd6
9
main.py
9
main.py
|
@ -313,9 +313,10 @@ class ImageLogger(Callback):
|
||||||
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
||||||
|
|
||||||
tag = f"{split}/{k}"
|
tag = f"{split}/{k}"
|
||||||
pl_module.logger.experiment.add_image(
|
pl_module.logger.experiment.log(
|
||||||
tag, grid,
|
{'tag': tag, 'examples': grid},
|
||||||
global_step=pl_module.global_step)
|
step=pl_module.global_step
|
||||||
|
)
|
||||||
|
|
||||||
@rank_zero_only
|
@rank_zero_only
|
||||||
def log_local(self, save_dir, split, images,
|
def log_local(self, save_dir, split, images,
|
||||||
|
@ -357,7 +358,7 @@ class ImageLogger(Callback):
|
||||||
N = min(images[k].shape[0], self.max_images)
|
N = min(images[k].shape[0], self.max_images)
|
||||||
images[k] = images[k][:N]
|
images[k] = images[k][:N]
|
||||||
if isinstance(images[k], torch.Tensor):
|
if isinstance(images[k], torch.Tensor):
|
||||||
images[k] = images[k].detach().cpu()
|
images[k] = images[k].detach().cpu().to(torch.float32)
|
||||||
if self.clamp:
|
if self.clamp:
|
||||||
images[k] = torch.clamp(images[k], -1., 1.)
|
images[k] = torch.clamp(images[k], -1., 1.)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue