add enable sequential cpu offloading to other stable diffusion pipelines (#1085)
* add enable sequential cpu offloading to other stable diffusion pipelines
* trigger ci
* fix styling
* interpolate before converting to device to avoid breaking when cpu_offload is enabled with fp16
* style again, I need to stop forgetting this thing
* fix inpainting bug that could cause device misalignment
* Apply suggestions from code review

Co-authored-by: Pedro Gengo <pedro.gabriel.lourenco@hotmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2fcae69f2a
commit 1172c9634b
@@ -5,6 +5,7 @@ import numpy as np
 import torch
 
 import PIL
+from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -151,6 +152,23 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         # set slice_size = `None` to disable `set_attention_slice`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
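For context, the new method replaces moving the whole pipeline to the GPU. A minimal usage sketch, assuming a GPU is available (the checkpoint, image path, and prompt below are illustrative, not part of this commit):

    import torch
    from PIL import Image
    from diffusers import StableDiffusionImg2ImgPipeline

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
    )
    # Instead of pipe.to("cuda"): weights stay offloaded, and each submodule
    # is moved to the GPU only while its forward() is running.
    pipe.enable_sequential_cpu_offload()

    # hypothetical input image
    init_image = Image.open("sketch.jpg").convert("RGB").resize((768, 512))
    image = pipe(prompt="A fantasy landscape", init_image=init_image).images[0]

This trades throughput for memory: every denoising step pays the CPU-to-GPU transfer cost, but peak GPU allocation stays far below holding all weights resident.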
@@ -5,6 +5,7 @@ import numpy as np
 import torch
 
 import PIL
+from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -151,6 +152,23 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -361,11 +379,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
 
         # prepare mask and masked_image
         mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
-        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
-        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
 
         # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
         mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
+        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
+
+        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
 
         # encode the mask image into latents space so we can concatenate it to the latents
         masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
@@ -380,6 +401,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
         )
 
+        # aligning device to prevent device errors when concatenating it with the latent model input
+        masked_image_latents = masked_image_latents.to(device=self.device, dtype=text_embeddings.dtype)
+
         num_channels_mask = mask.shape[1]
         num_channels_masked_image = masked_image_latents.shape[1]
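The reordering above is the substance of the fp16 fix: with sequential offload enabled, the mask still lives on the CPU, so casting it to the pipeline dtype before the resize would ask `torch.nn.functional.interpolate` to operate on a half-precision CPU tensor, which some PyTorch versions reject (a "not implemented for 'Half'" RuntimeError). A minimal sketch of the safe ordering, with illustrative shapes rather than the pipeline's own variables:

    import torch

    height, width = 512, 512
    mask = torch.rand(1, 1, height, width)  # float32 mask, still on CPU

    # 1) resize first, while the tensor is float32 on CPU ...
    mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))

    # 2) ... then cast and move; doing the cast first could break under
    #    cpu_offload + fp16, since interpolate may not support Half on CPU
    mask = mask.to(device="cuda", dtype=torch.float16)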
@@ -840,7 +840,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         assert 2 * low_cpu_mem_usage_time < normal_load_time
 
     @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
-    def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
 
@@ -599,3 +599,48 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
         )
         assert test_callback_fn.has_been_called
         assert number_of_steps == 38
+
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/fantasy_landscape_k_lms.png"
+        )
+        init_image = init_image.resize((768, 512))
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+            model_id,
+            scheduler=lms,
+            safety_checker=None,
+            device_map="auto",
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing(1)
+        pipe.enable_sequential_cpu_offload()
+
+        prompt = "A fantasy landscape, trending on artstation"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        _ = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+            num_inference_steps=5,
+        )
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        # make sure that less than 1.5 GB is allocated
+        assert mem_bytes < 1.5 * 10**9
@@ -378,4 +378,49 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
         image = output.images[0]
 
         assert image.shape == (512, 512, 3)
-        assert np.abs(expected_image - image).max() < 1e-3
+        assert np.abs(expected_image - image).max() < 1e-2
+
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+        )
+        mask_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/yellow_cat_sitting_on_a_park_bench_pndm.png"
+        )
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+        model_id = "runwayml/stable-diffusion-inpainting"
+        pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(
+            model_id, safety_checker=None, scheduler=pndm, device_map="auto"
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing(1)
+        pipe.enable_sequential_cpu_offload()
+
+        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        _ = pipe(
+            prompt=prompt,
+            image=init_image,
+            mask_image=mask_image,
+            generator=generator,
+            num_inference_steps=5,
+            output_type="np",
+        )
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        # make sure that less than 1.5 GB is allocated
+        assert mem_bytes < 1.5 * 10**9
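Both new integration tests bound peak GPU usage at 1.5 GB. The same check can be reproduced outside the test suite with the CUDA counters the tests already use; a minimal sketch, where the measured pipeline call is whatever offloaded invocation you want to profile:

    import torch

    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()  # zero the peak-allocation counter

    # ... run the offloaded pipeline call here ...

    mem_bytes = torch.cuda.max_memory_allocated()
    print(f"peak GPU memory: {mem_bytes / 10**9:.2f} GB")
    assert mem_bytes < 1.5 * 10**9  # same bound the tests assert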