Add CPU offloading to UnCLIP (#1761)

* Add CPU offloading to UnCLIP

* use fp32 for testing the offload
Anton Lozhkov 2022-12-19 14:44:08 +01:00 committed by GitHub
parent be38b2d711
commit c7b4acfb37
2 changed files with 85 additions and 15 deletions
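
For reference, the new entry point is meant to be used the way the integration test below exercises it: move the pipeline to the GPU, enable offloading, then call the pipeline as usual while Accelerate loads each hooked submodule onto the device only for its `forward` pass. A minimal usage sketch (the checkpoint and call pattern mirror the test added in this commit; the prompt and seed are illustrative):

    import torch
    from diffusers import UnCLIPPipeline

    pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
    pipe = pipe.to("cuda")
    # Offload the hooked submodules: their weights are kept in CPU RAM and
    # loaded onto cuda:0 only while their forward() runs (the prior stays on
    # the GPU, see the TODO in the diff below).
    pipe.enable_sequential_cpu_offload(gpu_id=0)

    generator = torch.Generator(device="cuda").manual_seed(0)
    image = pipe("a horse", generator=generator).images[0]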


@@ -23,7 +23,7 @@ from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import UnCLIPScheduler
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...utils import logging
from ...utils import is_accelerate_available, logging
from .text_proj import UnCLIPTextProjModel
@@ -115,7 +115,7 @@ class UnCLIPPipeline(DiffusionPipeline):
latents = latents * scheduler.init_noise_sigma
return latents
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
@@ -126,7 +126,7 @@ class UnCLIPPipeline(DiffusionPipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(self.device)
text_mask = text_inputs.attention_mask.bool().to(device)
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
@@ -136,7 +136,7 @@ class UnCLIPPipeline(DiffusionPipeline):
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(text_input_ids.to(self.device))
text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_embeddings = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state
@@ -156,8 +156,8 @@ class UnCLIPPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
uncond_text_mask = uncond_input.attention_mask.bool().to(self.device)
uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(self.device))
uncond_text_mask = uncond_input.attention_mask.bool().to(device)
uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
@@ -187,6 +187,49 @@ class UnCLIPPipeline(DiffusionPipeline):
return text_embeddings, text_encoder_hidden_states, text_mask
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
# TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
models = [
self.decoder,
self.text_proj,
self.text_encoder,
self.super_res_first,
self.super_res_last,
]
for cpu_offloaded_model in models:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
return self.device
for module in self.decoder.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
@torch.no_grad()
def __call__(
self,
@@ -254,25 +297,26 @@ class UnCLIPPipeline(DiffusionPipeline):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
device = self._execution_device
batch_size = batch_size * num_images_per_prompt
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance
prompt, device, num_images_per_prompt, do_classifier_free_guidance
)
# prior
self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=self.device)
self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
prior_timesteps_tensor = self.prior_scheduler.timesteps
embedding_dim = self.prior.config.embedding_dim
prior_latents = self.prepare_latents(
(batch_size, embedding_dim),
text_embeddings.dtype,
self.device,
device,
generator,
prior_latents,
self.prior_scheduler,
@@ -326,7 +370,7 @@ class UnCLIPPipeline(DiffusionPipeline):
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=self.device)
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
num_channels_latents = self.decoder.in_channels
@@ -335,7 +379,7 @@ class UnCLIPPipeline(DiffusionPipeline):
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
text_encoder_hidden_states.dtype,
self.device,
device,
generator,
decoder_latents,
self.decoder_scheduler,
@@ -378,7 +422,7 @@ class UnCLIPPipeline(DiffusionPipeline):
# super res
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=self.device)
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
channels = self.super_res_first.in_channels // 2
@@ -387,7 +431,7 @@ class UnCLIPPipeline(DiffusionPipeline):
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),
image_small.dtype,
self.device,
device,
generator,
super_res_latents,
self.super_res_scheduler,
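
The `_execution_device` property added above works because `accelerate.cpu_offload` attaches a hook to every offloaded submodule as `_hf_hook`, and that hook records the device the weights will be loaded onto. A minimal standalone sketch of that mechanism (the `torch.nn.Linear` stand-in is hypothetical; `cpu_offload` and `_hf_hook.execution_device` are the same APIs the pipeline code uses):

    import torch
    from accelerate import cpu_offload

    # Hypothetical stand-in for one of the pipeline's offloaded submodules.
    model = torch.nn.Linear(4, 4)

    # As in enable_sequential_cpu_offload: the state dict is saved to CPU, the
    # module is put on the meta device, and a hook reloads it per forward().
    cpu_offload(model, torch.device("cuda:0"))

    # This is the attribute UnCLIPPipeline._execution_device walks the decoder's
    # submodules to find.
    print(model._hf_hook.execution_device)  # the device passed to cpu_offload above

    # Inputs are moved to the execution device automatically, and the weights
    # are released from the GPU again after forward() returns.
    out = model(torch.randn(1, 4))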


@@ -261,10 +261,10 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
def test_unclip_karlo(self):
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/unclip/karlo_v1_alpha_horse.npy"
"/unclip/karlo_v1_alpha_horse_fp16.npy"
)
pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
pipeline = pipeline.to(torch_device)
pipeline.set_progress_bar_config(disable=None)
@@ -280,3 +280,29 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).max() < 1e-2
def test_unclip_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
pipe.enable_sequential_cpu_offload()
generator = torch.Generator(device=torch_device).manual_seed(0)
_ = pipe(
"horse",
num_images_per_prompt=1,
generator=generator,
prior_num_inference_steps=2,
decoder_num_inference_steps=2,
super_res_num_inference_steps=2,
output_type="np",
)
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 1.5 GB is allocated
assert mem_bytes < 1.5 * 10**9
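
The 1.5 GB ceiling documents the point of the feature: with sequential offloading, only the weights needed for the step currently executing are resident on the GPU, so peak CUDA allocation stays well below what loading the whole fp32 pipeline onto the device would require, and the assertion guards against regressions that would silently keep more of the model on the GPU.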