commit 3d97f68047 (parent 225bec8b3f)

    ddpm changes
@@ -1318,7 +1318,7 @@ class LatentDiffusion(DDPM):
         return samples, intermediates

     @torch.no_grad()
-    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=1., return_keys=None,
                    quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
                    plot_diffusion_rows=False, **kwargs):

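Note: dropping the default `ddim_steps` from 200 to 50 makes the periodic sample logging roughly 4x cheaper, since each DDIM step costs one U-Net evaluation. A minimal sketch of how DDIM subsamples the training timesteps (loosely following the repo's `make_ddim_timesteps` with uniform discretization; the helper name below is mine):

```python
import numpy as np

def ddim_timesteps(num_ddpm_steps=1000, num_ddim_steps=50):
    # uniform stride through the training timesteps; each kept
    # timestep costs one U-Net forward pass at sampling time
    stride = num_ddpm_steps // num_ddim_steps
    return np.arange(0, num_ddpm_steps, stride)

print(len(ddim_timesteps(num_ddim_steps=200)))  # 200 model calls per logged sample
print(len(ddim_timesteps(num_ddim_steps=50)))   # 50 model calls: ~4x faster logging
```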
@@ -1333,22 +1333,22 @@ class LatentDiffusion(DDPM):
                                            bs=N)
         N = min(x.shape[0], N)
         n_row = min(x.shape[0], n_row)
-        log["inputs"] = x
-        log["reconstruction"] = xrec
-        if self.model.conditioning_key is not None:
-            if hasattr(self.cond_stage_model, "decode"):
-                xc = self.cond_stage_model.decode(c)
-                log["conditioning"] = xc
-            elif self.cond_stage_key in ["caption"]:
-                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
-                log["conditioning"] = xc
-            elif self.cond_stage_key == 'class_label':
-                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
-                log['conditioning'] = xc
-            elif isimage(xc):
-                log["conditioning"] = xc
-            if ismap(xc):
-                log["original_conditioning"] = self.to_rgb(xc)
+        # log["inputs"] = x
+        # log["reconstruction"] = xrec
+        # if self.model.conditioning_key is not None:
+        #     if hasattr(self.cond_stage_model, "decode"):
+        #         xc = self.cond_stage_model.decode(c)
+        #         log["conditioning"] = xc
+        #     elif self.cond_stage_key in ["caption"]:
+        #         xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
+        #         log["conditioning"] = xc
+        #     elif self.cond_stage_key == 'class_label':
+        #         xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+        #         log['conditioning'] = xc
+        #     elif isimage(xc):
+        #         log["conditioning"] = xc
+        #     if ismap(xc):
+        #         log["original_conditioning"] = self.to_rgb(xc)

         if plot_diffusion_rows:
             # get diffusion row
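Note: this hunk disables logging of inputs, reconstructions, and conditioning images by commenting the block out rather than deleting it. For what it's worth, `log_images` already takes a `return_keys` argument that filters the returned dict, so a similar effect could be had without touching these code paths; a sketch of that filtering idea (helper name hypothetical):

```python
def filter_log(log: dict, return_keys=None) -> dict:
    # keep only the requested entries, mirroring what a
    # return_keys-style filter does at the end of log_images
    if not return_keys:
        return log
    return {k: v for k, v in log.items() if k in return_keys}

log = {"inputs": "x", "reconstruction": "xrec", "samples": "s"}
print(filter_log(log, return_keys=["samples"]))  # {'samples': 's'}
```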
@@ -1388,7 +1388,7 @@ class LatentDiffusion(DDPM):
                                                  eta=ddim_eta,
                                                  unconditional_guidance_scale=5.0,
                                                  unconditional_conditioning=uc)
-            log["samples_scaled"] = self.decode_first_stage(sample_scaled)
+            log["samples_subject"] = self.decode_first_stage(sample_scaled)

         if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                 self.first_stage_model, IdentityFirstStage):
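Note: `samples_scaled` is renamed to `samples_subject`; the sample itself still comes from classifier-free guidance at `unconditional_guidance_scale=5.0`. As a reminder of what that scale does, the guided noise estimate extrapolates from the unconditional prediction toward the conditional one; a minimal sketch (tensor names illustrative):

```python
import torch

def guided_eps(eps_uncond: torch.Tensor, eps_cond: torch.Tensor,
               scale: float = 5.0) -> torch.Tensor:
    # classifier-free guidance: scale=1.0 recovers the conditional
    # prediction; larger values push further toward the condition
    return eps_uncond + scale * (eps_cond - eps_uncond)
```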
main.py (5 changed lines: +3, −2)
@@ -399,7 +399,7 @@ class ImageLogger(Callback):
             grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
             grid = grid.numpy()
             grid = (grid * 255).astype(np.uint8)
-            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.jpg".format(
+            filename = "{}_globalstep-{:05}_epoch-{:01}_batch-{:04}.jpg".format(
                 k,
                 global_step,
                 current_epoch,
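A quick check of what the old and new patterns produce (values illustrative); note the new `{:01}` epoch field barely pads, so filenames stop sorting lexicographically past epoch 9:

```python
old = "{}_gs-{:06}_e-{:06}_b-{:06}.jpg".format("samples", 1500, 3, 42)
new = "{}_globalstep-{:05}_epoch-{:01}_batch-{:04}.jpg".format("samples", 1500, 3, 42)
print(old)  # samples_gs-001500_e-000003_b-000042.jpg
print(new)  # samples_globalstep-01500_epoch-3_batch-0042.jpg
```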
@@ -806,7 +806,7 @@ if __name__ == "__main__":
         def melk(*args, **kwargs):
             # run all checkpoint hooks
             if trainer.global_rank == 0:
-                print("Summoning checkpoint.")
+                print("Here comes the checkpoint...")
                 ckpt_path = os.path.join(ckptdir, "last.ckpt")
                 trainer.save_checkpoint(ckpt_path)

@@ -851,4 +851,5 @@ if __name__ == "__main__":
             os.makedirs(os.path.split(dst)[0], exist_ok=True)
             os.rename(logdir, dst)
         if trainer.global_rank == 0:
             print("Another one bites the dust...")
+            print(trainer.profiler.summary())
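Note: `trainer.profiler.summary()` only yields a useful table when a profiler was actually configured on the Trainer; otherwise the passthrough profiler has nothing to report. A minimal sketch with PyTorch Lightning's built-in simple profiler (the API is version-dependent and this codebase pins an older pytorch-lightning, so treat this as an assumption):

```python
import pytorch_lightning as pl

# request the built-in simple profiler; without this,
# trainer.profiler is a passthrough that times nothing
trainer = pl.Trainer(profiler="simple", max_epochs=1)
# ... trainer.fit(model, datamodule) ...
print(trainer.profiler.summary())  # per-hook timing table
```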