use selected device instead of always cuda for UniPC sampler
This commit is contained in:
parent
a11ce2b96c
commit
f261a4a53c
|
@ -3,7 +3,8 @@
|
|||
import torch
|
||||
|
||||
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
||||
from modules import shared
|
||||
from modules import shared, devices
|
||||
|
||||
|
||||
class UniPCSampler(object):
|
||||
def __init__(self, model, **kwargs):
|
||||
|
@ -16,8 +17,8 @@ class UniPCSampler(object):
|
|||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
if attr.device != devices.device:
|
||||
attr = attr.to(devices.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def set_hooks(self, before_sample, after_sample, after_update):
|
||||
|
|
Loading…
Reference in New Issue