Add torch_device to the VE pipeline
This commit is contained in:
parent
a73ae3e5b0
commit
7c0a861894
|
@ -11,22 +11,23 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
|||
self.register_modules(model=model, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
img_size = self.model.config.sample_size
|
||||
shape = (1, 3, img_size, img_size)
|
||||
shape = (batch_size, 3, img_size, img_size)
|
||||
|
||||
model = self.model.to(device)
|
||||
model = self.model.to(torch_device)
|
||||
|
||||
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
|
||||
sample = sample.to(device)
|
||||
sample = sample.to(torch_device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.set_sigmas(num_inference_steps)
|
||||
|
||||
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
||||
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
|
||||
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
|
||||
|
||||
# correction step
|
||||
for _ in range(self.scheduler.correct_steps):
|
||||
|
|
Loading…
Reference in New Issue