no inference moed doesn't always work

This commit is contained in:
Patrick von Platen 2022-06-28 23:05:08 +00:00
parent 740326d2a2
commit e47c97a451
1 changed files with 1 additions and 1 deletions

View File

@ -161,7 +161,7 @@ for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_ste
# 1. predict noise residual
orig_t = len(noise_scheduler) // num_inference_steps * t
with torch.inference_mode():
with torch.no_grad():
residual = unet(image, orig_t)
# 2. predict previous mean of image x_t-1