handle device for randn in euler step (#1124)

* handle device for randn in euler step

* convert device to str
This commit is contained in:
Suraj Patil 2022-11-03 18:13:18 +01:00 committed by GitHub
parent 42bb459457
commit 7b030a7d68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 2 deletions

View File

@ -217,7 +217,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample = sample + derivative * dt
device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
if str(device) == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)
prev_sample = prev_sample + noise * sigma_up
if not return_dict:

View File

@ -214,7 +214,16 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
if str(device) == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)