Add torch_device to the VE pipeline
This commit is contained in:
parent
a73ae3e5b0
commit
7c0a861894
|
@ -11,22 +11,23 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
||||||
self.register_modules(model=model, scheduler=scheduler)
|
self.register_modules(model=model, scheduler=scheduler)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
|
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
|
||||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
if torch_device is None:
|
||||||
|
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
img_size = self.model.config.sample_size
|
img_size = self.model.config.sample_size
|
||||||
shape = (1, 3, img_size, img_size)
|
shape = (batch_size, 3, img_size, img_size)
|
||||||
|
|
||||||
model = self.model.to(device)
|
model = self.model.to(torch_device)
|
||||||
|
|
||||||
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
|
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
|
||||||
sample = sample.to(device)
|
sample = sample.to(torch_device)
|
||||||
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
self.scheduler.set_sigmas(num_inference_steps)
|
self.scheduler.set_sigmas(num_inference_steps)
|
||||||
|
|
||||||
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
||||||
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
|
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
|
||||||
|
|
||||||
# correction step
|
# correction step
|
||||||
for _ in range(self.scheduler.correct_steps):
|
for _ in range(self.scheduler.correct_steps):
|
||||||
|
|
Loading…
Reference in New Issue