Add CPU offloading to UnCLIP (#1761)
* Add CPU offloading to UnCLIP
* use fp32 for testing the offload
This commit is contained in:
parent
be38b2d711
commit
c7b4acfb37
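Usage context: the new `enable_sequential_cpu_offload()` API is exercised by the integration test at the bottom of this diff. A minimal usage sketch mirroring that test (requires `accelerate` to be installed and a CUDA device, since offloading targets `cuda:{gpu_id}`; the reduced step counts are copied from the test, not the pipeline defaults):

import torch
from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
# Submodules stay on CPU and are moved to cuda:0 only while their forward() runs.
pipe.enable_sequential_cpu_offload()

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    "horse",
    generator=generator,
    prior_num_inference_steps=2,
    decoder_num_inference_steps=2,
    super_res_num_inference_steps=2,
    output_type="np",
).images[0]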
@@ -23,7 +23,7 @@ from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from diffusers.schedulers import UnCLIPScheduler
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer
 
-from ...utils import logging
+from ...utils import is_accelerate_available, logging
 from .text_proj import UnCLIPTextProjModel
 
 
@@ -115,7 +115,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         latents = latents * scheduler.init_noise_sigma
         return latents
 
-    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
 
         # get prompt text embeddings
@@ -126,7 +126,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        text_mask = text_inputs.attention_mask.bool().to(self.device)
+        text_mask = text_inputs.attention_mask.bool().to(device)
 
         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
             removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
@@ -136,7 +136,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
 
-        text_encoder_output = self.text_encoder(text_input_ids.to(self.device))
+        text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
         text_embeddings = text_encoder_output.text_embeds
         text_encoder_hidden_states = text_encoder_output.last_hidden_state
@@ -156,8 +156,8 @@ class UnCLIPPipeline(DiffusionPipeline):
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_text_mask = uncond_input.attention_mask.bool().to(self.device)
-            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(self.device))
+            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
 
             uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
             uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
@@ -187,6 +187,49 @@ class UnCLIPPipeline(DiffusionPipeline):
 
         return text_embeddings, text_encoder_hidden_states, text_mask
 
+    def enable_sequential_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
+        when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
+        models = [
+            self.decoder,
+            self.text_proj,
+            self.text_encoder,
+            self.super_res_first,
+            self.super_res_last,
+        ]
+        for cpu_offloaded_model in models:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
+            return self.device
+        for module in self.decoder.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
     @torch.no_grad()
     def __call__(
         self,
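The offload and `_execution_device` logic above relies on Accelerate's module hooks. A minimal sketch of the assumed hook behavior, using a toy `nn.Linear` in place of a pipeline submodule (needs `accelerate` installed and a CUDA device):

import torch
from torch import nn
from accelerate import cpu_offload

layer = nn.Linear(8, 8)
# Assumed behavior: parameters are kept off the GPU and an accelerate hook copies
# them to the execution device only for the duration of each forward() call.
cpu_offload(layer, execution_device=torch.device("cuda:0"))

out = layer(torch.randn(1, 8))          # input and weights are moved to cuda:0 for this call
print(layer._hf_hook.execution_device)  # the attribute that _execution_device inspects above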
@@ -254,25 +297,26 @@ class UnCLIPPipeline(DiffusionPipeline):
             batch_size = len(prompt)
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        device = self._execution_device
 
         batch_size = batch_size * num_images_per_prompt
 
         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
 
         text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
-            prompt, num_images_per_prompt, do_classifier_free_guidance
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance
         )
 
         # prior
 
-        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=self.device)
+        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
         prior_timesteps_tensor = self.prior_scheduler.timesteps
 
         embedding_dim = self.prior.config.embedding_dim
         prior_latents = self.prepare_latents(
             (batch_size, embedding_dim),
             text_embeddings.dtype,
-            self.device,
+            device,
             generator,
             prior_latents,
             self.prior_scheduler,
@@ -326,7 +370,7 @@ class UnCLIPPipeline(DiffusionPipeline):
 
         decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
 
-        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=self.device)
+        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps
 
         num_channels_latents = self.decoder.in_channels
@@ -335,7 +379,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
             text_encoder_hidden_states.dtype,
-            self.device,
+            device,
             generator,
             decoder_latents,
             self.decoder_scheduler,
@@ -378,7 +422,7 @@ class UnCLIPPipeline(DiffusionPipeline):
 
         # super res
 
-        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=self.device)
+        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps
 
         channels = self.super_res_first.in_channels // 2
@@ -387,7 +431,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
             image_small.dtype,
-            self.device,
+            device,
             generator,
             super_res_latents,
             self.super_res_scheduler,

@@ -261,10 +261,10 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
     def test_unclip_karlo(self):
         expected_image = load_numpy(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/unclip/karlo_v1_alpha_horse.npy"
+            "/unclip/karlo_v1_alpha_horse_fp16.npy"
         )
 
-        pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
+        pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
         pipeline = pipeline.to(torch_device)
         pipeline.set_progress_bar_config(disable=None)
 
@@ -280,3 +280,29 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
 
         assert image.shape == (256, 256, 3)
         assert np.abs(expected_image - image).max() < 1e-2
+
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+        pipe.enable_sequential_cpu_offload()
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        _ = pipe(
+            "horse",
+            num_images_per_prompt=1,
+            generator=generator,
+            prior_num_inference_steps=2,
+            decoder_num_inference_steps=2,
+            super_res_num_inference_steps=2,
+            output_type="np",
+        )
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        # make sure that less than 1.5 GB is allocated
+        assert mem_bytes < 1.5 * 10**9
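The memory ceiling in the new test uses PyTorch's standard peak-allocation counters. A standalone sketch of the same measurement pattern (the 1.5 GB threshold is copied from the test above; substitute any offloaded pipeline call for the placeholder):

import torch

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# ... run the offloaded pipeline here, e.g. pipe("horse", ...)

mem_bytes = torch.cuda.max_memory_allocated()
print(f"peak CUDA allocation: {mem_bytes / 10**9:.2f} GB")
assert mem_bytes < 1.5 * 10**9  # same ceiling the integration test asserts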