[PNDM in LDM pipeline] use inspect in pipeline instead of unused kwargs (#167)
use inspect instead of unused kwargs
This commit is contained in:
parent
3228eb1609
commit
c72e343085
|
@ -1,3 +1,4 @@
|
|||
import inspect
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
@ -59,6 +60,12 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
|||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_kwrags = {}
|
||||
if not accepts_eta:
|
||||
extra_kwrags["eta"] = eta
|
||||
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
if guidance_scale == 1.0:
|
||||
# guidance_scale of 1 means no guidance
|
||||
|
@ -79,7 +86,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
|||
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, eta=eta)["prev_sample"]
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import inspect
|
||||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
@ -31,11 +33,17 @@ class LDMPipeline(DiffusionPipeline):
|
|||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_kwrags = {}
|
||||
if not accepts_eta:
|
||||
extra_kwrags["eta"] = eta
|
||||
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
# predict the noise residual
|
||||
noise_prediction = self.unet(latents, t)["sample"]
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
|
||||
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwrags)["prev_sample"]
|
||||
|
||||
# decode the image latents with the VAE
|
||||
image = self.vqvae.decode(latents)
|
||||
|
|
|
@ -116,7 +116,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
**kwargs,
|
||||
):
|
||||
if self.counter < len(self.prk_timesteps):
|
||||
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
|
||||
|
|
Loading…
Reference in New Issue