handle device for randn in euler step (#1124)
* handle device for randn in euler step * convert device to str
This commit is contained in:
parent
42bb459457
commit
7b030a7d68
|
@ -217,7 +217,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
prev_sample = sample + derivative * dt
|
||||
|
||||
device = model_output.device if torch.is_tensor(model_output) else "cpu"
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
|
||||
if str(device) == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
|
||||
device
|
||||
)
|
||||
|
||||
prev_sample = prev_sample + noise * sigma_up
|
||||
|
||||
if not return_dict:
|
||||
|
|
|
@ -214,7 +214,16 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||
|
||||
device = model_output.device if torch.is_tensor(model_output) else "cpu"
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
|
||||
if str(device) == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
|
||||
device
|
||||
)
|
||||
|
||||
eps = noise * s_noise
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue