Add torch_device to the VE pipeline

This commit is contained in:
anton-l 2022-07-21 13:53:09 +02:00
parent a73ae3e5b0
commit 7c0a861894
1 changed files with 7 additions and 6 deletions

View File

@ -11,22 +11,23 @@ class ScoreSdeVePipeline(DiffusionPipeline):
self.register_modules(model=model, scheduler=scheduler) self.register_modules(model=model, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"): def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
img_size = self.model.config.sample_size img_size = self.model.config.sample_size
shape = (1, 3, img_size, img_size) shape = (batch_size, 3, img_size, img_size)
model = self.model.to(device) model = self.model.to(torch_device)
sample = torch.randn(*shape) * self.scheduler.config.sigma_max sample = torch.randn(*shape) * self.scheduler.config.sigma_max
sample = sample.to(device) sample = sample.to(torch_device)
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.set_sigmas(num_inference_steps) self.scheduler.set_sigmas(num_inference_steps)
for i, t in tqdm(enumerate(self.scheduler.timesteps)): for i, t in tqdm(enumerate(self.scheduler.timesteps)):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
# correction step # correction step
for _ in range(self.scheduler.correct_steps): for _ in range(self.scheduler.correct_steps):