MPS schedulers: don't use float64 (#1169)

* Schedulers: don't use float64 on mps

* Test set_timesteps() on device (float schedulers).

* SD pipeline: use device in set_timesteps.

* SD in-painting pipeline: use device in set_timesteps.

* Tests: fix mps crashes.

* Skip test_load_pipeline_from_git on mps.

Not compatible with float16.

* Use device.type instead of str in Euler schedulers.
This commit is contained in:
Pedro Cuenca 2022-11-08 13:11:33 +01:00 committed by GitHub
parent 5a8b356922
commit 813744e5f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 117 additions and 22 deletions

View File

@ -360,12 +360,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps_tensor = self.scheduler.timesteps.to(self.device)
# set timesteps and move to the correct device
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps_tensor = self.scheduler.timesteps
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

View File

@ -416,12 +416,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
" `pipeline.unet` or your `mask_image` or `image` input."
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps_tensor = self.scheduler.timesteps.to(self.device)
# set timesteps and move to the correct device
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps_tensor = self.scheduler.timesteps
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

View File

@ -151,6 +151,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device)
def step(
@ -217,8 +221,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample = sample + derivative * dt
device = model_output.device if torch.is_tensor(model_output) else "cpu"
if str(device) == "mps":
device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device

View File

@ -152,6 +152,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device)
def step(
@ -214,8 +218,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
device = model_output.device if torch.is_tensor(model_output) else "cpu"
if str(device) == "mps":
device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device

View File

@ -173,7 +173,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.derivatives = []

View File

@ -456,6 +456,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
latents = self.get_latents(seed)
@ -507,6 +508,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
latents = self.get_latents(seed)
@ -558,6 +560,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
latents = self.get_latents(seed, shape=(4, 9, 64, 64))

View File

@ -41,7 +41,7 @@ from diffusers import (
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
from parameterized import parameterized
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@ -124,7 +124,7 @@ class CustomPipelineTests(unittest.TestCase):
assert output_str == "This is a local test"
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
@require_torch_gpu
def test_load_pipeline_from_git(self):
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

View File

@ -83,8 +83,8 @@ class SchedulerCommonTest(unittest.TestCase):
num_inference_steps = kwargs.pop("num_inference_steps", None)
# TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
for scheduler_class in self.scheduler_classes:
# TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
time_step = float(time_step)
@ -1192,6 +1192,31 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3
def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
for i, t in enumerate(scheduler.timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3
class EulerDiscreteSchedulerTest(SchedulerCommonTest):
scheduler_classes = (EulerDiscreteScheduler,)
@ -1248,6 +1273,34 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 10.0807) < 1e-2
assert abs(result_mean.item() - 0.0131) < 1e-3
def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
generator = torch.Generator().manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
for t in scheduler.timesteps:
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample, generator=generator)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
print(result_sum, result_mean)
assert abs(result_sum.item() - 10.0807) < 1e-2
assert abs(result_mean.item() - 0.0131) < 1e-3
class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
scheduler_classes = (EulerAncestralDiscreteScheduler,)
@ -1303,6 +1356,38 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 152.3192) < 1e-2
assert abs(result_mean.item() - 0.1983) < 1e-3
def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
generator = torch.Generator().manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
for t in scheduler.timesteps:
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample, generator=generator)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
print(result_sum, result_mean)
if not str(torch_device).startswith("mps"):
# The following sum varies between 148 and 156 on mps. Why?
assert abs(result_sum.item() - 152.3192) < 1e-2
assert abs(result_mean.item() - 0.1983) < 1e-3
else:
# Larger tolerance on mps
assert abs(result_mean.item() - 0.1983) < 1e-2
class IPNDMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (IPNDMScheduler,)