476 lines
22 KiB
Python
Executable File
476 lines
22 KiB
Python
Executable File
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import importlib
|
|
import warnings
|
|
from typing import Callable, List, Optional, Union
|
|
|
|
import torch
|
|
|
|
from diffusers import DiffusionPipeline, LMSDiscreteScheduler
|
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
|
from diffusers.utils import is_accelerate_available, logging
|
|
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
|
|
|
|
|
|
# Module-level logger from diffusers' logging utilities; used by the pipeline below for
# safety-checker and prompt-truncation warnings.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
|
|
|
|
class ModelWrapper:
    """Expose a diffusers UNet through the ``apply_model`` interface used by k-diffusion.

    The pipeline wraps the UNet in this class and hands the result to k-diffusion's
    ``CompVisDenoiser`` / ``CompVisVDenoiser``.

    Args:
        model: Callable invoked as ``model(*args, encoder_hidden_states=..., **kwargs)``. Its
            return value must have a ``.sample`` attribute (as diffusers UNet outputs do).
        alphas_cumprod: Cumulative alpha products of the noise schedule. Kept as an attribute,
            presumably because the k-diffusion CompVis denoisers read it from the wrapped
            model -- TODO confirm against the installed k-diffusion version.
    """

    def __init__(self, model, alphas_cumprod):
        self.model = model
        self.alphas_cumprod = alphas_cumprod

    def apply_model(self, *args, **kwargs):
        """Run the wrapped model and return its ``.sample`` output.

        Conditioning (text embeddings) may be supplied either as a third positional argument
        or as the ``cond`` keyword. A non-``None`` ``cond`` takes precedence. The conditioning
        is forwarded to the model as ``encoder_hidden_states``. If neither form is supplied,
        ``encoder_hidden_states=None`` is forwarded.

        Returns:
            The ``.sample`` attribute of the model output.
        """
        # Always remove `cond` so it is never forwarded to the model as an unknown keyword,
        # even when the caller passed it explicitly as None.
        encoder_hidden_states = kwargs.pop("cond", None)
        if len(args) == 3:
            if encoder_hidden_states is None:
                encoder_hidden_states = args[-1]
            args = args[:2]
        return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
|
|
|
|
|
|
class StableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion, sampled with a k-diffusion sampler.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    A sampler must be selected with `set_scheduler(name)` before calling the pipeline, where `name` is the name of
    a function in `k_diffusion.sampling`.

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            Only its config is used: it is rebuilt as an [`LMSDiscreteScheduler`], whose sigmas drive the
            k-diffusion sampler. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
    ):
        super().__init__()

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # get correct sigmas from LMS: whatever scheduler was passed in, rebuild it as LMS from its config
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

        model = ModelWrapper(unet, scheduler.alphas_cumprod)
        # `prediction_type` is a config entry, not a scheduler attribute. Configs that predate it
        # describe epsilon-prediction models, hence the default.
        if scheduler.config.get("prediction_type", "epsilon") == "v_prediction":
            self.k_diffusion_model = CompVisVDenoiser(model)
        else:
            self.k_diffusion_model = CompVisDenoiser(model)

    def set_sampler(self, scheduler_type: str):
        """Deprecated alias of `set_scheduler`; emits a warning and delegates to it."""
        warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
        return self.set_scheduler(scheduler_type)

    def set_scheduler(self, scheduler_type: str):
        """Select the sampling function `k_diffusion.sampling.<scheduler_type>` used by `__call__`.

        Raises:
            AttributeError: if `k_diffusion.sampling` has no attribute named `scheduler_type`.
        """
        library = importlib.import_module("k_diffusion")
        sampling = getattr(library, "sampling")
        self.sampler = getattr(sampling, scheduler_type)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.

        Args:
            gpu_id (`int`, *optional*, defaults to 0): index of the CUDA device used for execution.

        Raises:
            ImportError: if `accelerate` is not installed.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Only used when
                `do_classifier_free_guidance` is True.

        Returns:
            `torch.Tensor`: the prompt embeddings repeated `num_images_per_prompt` times. With classifier free
            guidance, the unconditional embeddings are concatenated in front of them along the batch dimension.

        Raises:
            TypeError: if `negative_prompt` and `prompt` have different types.
            ValueError: if a list `negative_prompt` does not match the batch size of `prompt`.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Pad to the longest prompt rather than to `max_length`: without truncation, "max_length" padding leaves
        # over-long prompts longer than the others, so a mixed batch cannot be stacked into one tensor.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # The shape guard avoids a spurious warning when every prompt is shorter than the limit
        # (the untruncated ids are then narrower than the max_length-padded ones).
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def run_safety_checker(self, image, device, dtype):
        """Run the safety checker on `image` (NumPy images) if one is configured.

        Returns:
            `(image, has_nsfw_concept)`. `image` is the checker's returned images, or the input unchanged when no
            checker is configured, in which case `has_nsfw_concept` is `None`.
        """
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        """Decode latents with the VAE into a float32 NumPy array with values clamped to [0, 1].

        The latents are first divided by the SD latent scaling factor 0.18215. The decoded tensor is permuted with
        `(0, 2, 3, 1)`; assuming the VAE returns (batch, channels, height, width), the result is channels-last.
        """
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(self, prompt, height, width, callback_steps):
        """Validate user-facing `__call__` arguments.

        Raises:
            ValueError: if `prompt` is not a `str` or `list`, if `height`/`width` are not multiples of 8, or if
                `callback_steps` is not a positive integer.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # `None` fails the isinstance check, so it is rejected here as well.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """Return initial latents of shape `(batch_size, num_channels_latents, height // 8, width // 8)`.

        Random normal latents are drawn when `latents` is None; otherwise the given tensor is shape-checked and
        moved to `device`. No sigma scaling is applied here; the caller multiplies by the first sigma.

        Raises:
            ValueError: if a provided `latents` tensor has the wrong shape.
        """
        shape = (batch_size, num_channels_latents, height // 8, width // 8)
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). This pipeline always uses classifier free guidance, so
                `guidance_scale` must be greater than 1. Higher guidance scale encourages to generate images that
                are closely linked to the text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Must have the same type as `prompt`; a
                list must have the same length as `prompt`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Accepted for API compatibility with other Stable Diffusion pipelines; ignored by this pipeline.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` sampler steps. It is called as
                `callback(step, sigma, latents)`, where `sigma` is the noise level of that step as reported by the
                k-diffusion sampler (this pipeline has no integer timesteps).
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.

        Raises:
            ValueError: on invalid inputs (see `check_inputs`) or if `guidance_scale <= 1`.
            AttributeError: if no sampler was selected with `set_scheduler`.
        """

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance, which this pipeline does not support.
        do_classifier_free_guidance = True
        if guidance_scale <= 1.0:
            raise ValueError("has to use guidance_scale")

        # Fail fast (before any model work) with an actionable message instead of a bare attribute error.
        if getattr(self, "sampler", None) is None:
            raise AttributeError(
                "No k-diffusion sampler is set. Call `set_scheduler(<name of a function in k_diffusion.sampling>)`"
                " before running the pipeline."
            )

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
        sigmas = self.scheduler.sigmas
        sigmas = sigmas.to(text_embeddings.dtype)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        # scale the initial noise to the first (largest) sigma of the schedule
        latents = latents * sigmas[0]
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

        # 6. Denoising function with classifier free guidance
        def model_fn(x, t):
            # Run the unconditional and text-conditioned branches as one doubled batch, matching the
            # [uncond, text] order of `text_embeddings`. `t` is assumed to hold one sigma per sample of
            # `x`, so it must be doubled too; otherwise any batch larger than one cannot be matched
            # against the doubled input.
            latent_model_input = torch.cat([x] * 2)
            t = torch.cat([t] * 2)

            noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # 7. Sample. The user callback is adapted to the sampler's progress callback only when one was
        # given, so samplers are invoked exactly as before otherwise.
        sampler_kwargs = {}
        if callback is not None:

            def sampler_callback(info):
                # NOTE(review): assumes k-diffusion's progress dict with keys "i", "sigma" and "x";
                # confirm against the installed k-diffusion version.
                step = info["i"]
                if step % callback_steps == 0:
                    callback(step, info["sigma"], info["x"])

            sampler_kwargs["callback"] = sampler_callback

        latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|