From 1172c9634b4a32d6e82301e3d59ce17005e13e85 Mon Sep 17 00:00:00 2001
From: Pi Esposito
Date: Fri, 4 Nov 2022 15:25:28 -0300
Subject: [PATCH] add enable sequential cpu offloading to other stable
 diffusion pipelines (#1085)

* add enable sequential cpu offloading to other stable diffusion pipelines

* trigger ci

* fix styling

* interpolate before converting to device to avoid breaking when cpu_offload
  is enabled with fp16

Co-authored-by: Pedro Gengo

* style again, I need to stop forgetting this thing

* fix inpainting bug that could cause device misalignment

Co-authored-by: Pedro Gengo

* Apply suggestions from code review

Co-authored-by: Pedro Gengo
Co-authored-by: Patrick von Platen
---
 .../pipeline_stable_diffusion_img2img.py      | 18 +++++++
 .../pipeline_stable_diffusion_inpaint.py      | 28 ++++++++++-
 .../stable_diffusion/test_stable_diffusion.py |  2 +-
 .../test_stable_diffusion_img2img.py          | 45 ++++++++++++++++++
 .../test_stable_diffusion_inpaint.py          | 47 ++++++++++++++++++-
 5 files changed, 136 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 8284bac8..08b14b36 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -5,6 +5,7 @@ import numpy as np
 import torch
 
 import PIL
+from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -151,6 +152,23 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         # set slice_size = `None` to disable `set_attention_slice`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
+        text_encoder, vae, and safety checker have their state dicts saved to CPU and are then moved to
+        `torch.device('meta')`, loaded to the GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
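For reference, a minimal usage sketch of the method this file gains, loosely mirroring the
integration tests added further down (not part of the patch; the model id, prompt, and
placeholder image are illustrative, and accelerate must be installed):

    import torch
    from PIL import Image
    from diffusers import StableDiffusionImg2ImgPipeline

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe.to("cuda")
    # After this call, weights live on the CPU; each submodule is moved to the GPU
    # only while its forward pass runs, trading speed for a much smaller peak footprint.
    pipe.enable_sequential_cpu_offload()

    init_image = Image.new("RGB", (768, 512))  # stand-in for a real input image
    image = pipe(
        prompt="A fantasy landscape, trending on artstation",
        init_image=init_image,
        strength=0.75,
    ).images[0]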
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index c200892e..34e8231c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -5,6 +5,7 @@ import numpy as np
 import torch
 
 import PIL
+from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -151,6 +152,23 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
+        text_encoder, vae, and safety checker have their state dicts saved to CPU and are then moved to
+        `torch.device('meta')`, loaded to the GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -361,11 +379,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         # prepare mask and masked_image
         mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
-        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
-        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
 
         # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
         mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
+        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
+
+        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
 
         # encode the mask image into latents space so we can concatenate it to the latents
         masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
@@ -380,6 +401,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
         )
 
+        # align the device to prevent device errors when concatenating it with the latent model input
+        masked_image_latents = masked_image_latents.to(device=self.device, dtype=text_embeddings.dtype)
+
         num_channels_mask = mask.shape[1]
         num_channels_masked_image = masked_image_latents.shape[1]
 
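A note on the mask hunk above: the resize now happens before the cast to the text-embedding
dtype. A minimal sketch of the failure this ordering avoids, assuming a PyTorch build from
around this patch's time in which nearest-neighbor interpolation was not implemented for
half-precision tensors on CPU:

    import torch
    import torch.nn.functional as F

    mask = torch.rand(1, 1, 512, 512)  # stand-in for an inpainting mask

    # Order used by the patch: resize in fp32 first, then cast.
    resized = F.interpolate(mask, size=(64, 64)).to(dtype=torch.float16)

    # The reversed order could raise on CPU, e.g.
    # RuntimeError: "upsample_nearest2d" not implemented for 'Half'
    # resized = F.interpolate(mask.to(dtype=torch.float16), size=(64, 64))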
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 89fac46e..a83299ea 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -840,7 +840,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         assert 2 * low_cpu_mem_usage_time < normal_load_time
 
     @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
-    def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
 
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index ca8bc191..2d29e1b8 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -599,3 +599,48 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
         )
         assert test_callback_fn.has_been_called
         assert number_of_steps == 38
+
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/fantasy_landscape_k_lms.png"
+        )
+        init_image = init_image.resize((768, 512))
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+            model_id,
+            scheduler=lms,
+            safety_checker=None,
+            device_map="auto",
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing(1)
+        pipe.enable_sequential_cpu_offload()
+
+        prompt = "A fantasy landscape, trending on artstation"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        _ = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+            num_inference_steps=5,
+        )
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        # make sure that less than 1.5 GB is allocated
+        assert mem_bytes < 1.5 * 10**9
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index 44a7a324..e8dcb431 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -378,4 +378,49 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
         image = output.images[0]
 
         assert image.shape == (512, 512, 3)
-        assert np.abs(expected_image - image).max() < 1e-3
+        assert np.abs(expected_image - image).max() < 1e-2
+
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+        )
+        mask_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/yellow_cat_sitting_on_a_park_bench_pndm.png"
+        )
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+        model_id = "runwayml/stable-diffusion-inpainting"
+        pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(
+            model_id, safety_checker=None, scheduler=pndm, device_map="auto"
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing(1)
+        pipe.enable_sequential_cpu_offload()
+
+        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        _ = pipe(
+            prompt=prompt,
+            image=init_image,
+            mask_image=mask_image,
+            generator=generator,
+            num_inference_steps=5,
+            output_type="np",
+        )
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        # make sure that less than 1.5 GB is allocated
+        assert mem_bytes < 1.5 * 10**9
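The new integration tests share one verification pattern: clear the CUDA cache, reset the
peak-allocation counter, run a short generation with offloading enabled, then assert on the
recorded peak. A condensed sketch of that pattern as a hypothetical helper (the helper name
is illustrative; the 1.5 GB budget is taken from the tests):

    import torch

    def assert_peak_gpu_memory(run_pipeline, budget_bytes=1.5 * 10**9):
        """Run `run_pipeline` and check its peak CUDA allocation stays under budget."""
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()  # start peak tracking from zero
        run_pipeline()
        mem_bytes = torch.cuda.max_memory_allocated()
        assert mem_bytes < budget_bytes, f"peak allocation {mem_bytes} bytes exceeded budget"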