ddpm changes

This commit is contained in:
JoePenna 2022-09-25 22:23:08 -07:00
parent 225bec8b3f
commit 3d97f68047
2 changed files with 21 additions and 20 deletions

View File

@ -1318,7 +1318,7 @@ class LatentDiffusion(DDPM):
return samples, intermediates
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
plot_diffusion_rows=False, **kwargs):
@ -1333,22 +1333,22 @@ class LatentDiffusion(DDPM):
bs=N)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
log["conditioning"] = xc
elif self.cond_stage_key == 'class_label':
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
# log["inputs"] = x
# log["reconstruction"] = xrec
# if self.model.conditioning_key is not None:
# if hasattr(self.cond_stage_model, "decode"):
# xc = self.cond_stage_model.decode(c)
# log["conditioning"] = xc
# elif self.cond_stage_key in ["caption"]:
# xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
# log["conditioning"] = xc
# elif self.cond_stage_key == 'class_label':
# xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
# log['conditioning'] = xc
# elif isimage(xc):
# log["conditioning"] = xc
# if ismap(xc):
# log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
@ -1388,7 +1388,7 @@ class LatentDiffusion(DDPM):
eta=ddim_eta,
unconditional_guidance_scale=5.0,
unconditional_conditioning=uc)
log["samples_scaled"] = self.decode_first_stage(sample_scaled)
log["samples_subject"] = self.decode_first_stage(sample_scaled)
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
self.first_stage_model, IdentityFirstStage):

View File

@ -399,7 +399,7 @@ class ImageLogger(Callback):
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.jpg".format(
filename = "{}_globalstep-{:05}_epoch-{:01}_batch-{:04}.jpg".format(
k,
global_step,
current_epoch,
@ -806,7 +806,7 @@ if __name__ == "__main__":
def melk(*args, **kwargs):
# run all checkpoint hooks
if trainer.global_rank == 0:
print("Summoning checkpoint.")
print("Here comes the checkpoint...")
ckpt_path = os.path.join(ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
@ -851,4 +851,5 @@ if __name__ == "__main__":
os.makedirs(os.path.split(dst)[0], exist_ok=True)
os.rename(logdir, dst)
if trainer.global_rank == 0:
print("Another one bites the dust...")
print(trainer.profiler.summary())