228 lines
10 KiB
Python
228 lines
10 KiB
Python
from typing import Any, Callable, Dict, List, Optional, Union
|
|
|
|
import PIL.Image
|
|
import torch
|
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
|
|
from diffusers import (
|
|
AutoencoderKL,
|
|
DDIMScheduler,
|
|
DiffusionPipeline,
|
|
LMSDiscreteScheduler,
|
|
PNDMScheduler,
|
|
StableDiffusionImg2ImgPipeline,
|
|
StableDiffusionInpaintPipelineLegacy,
|
|
StableDiffusionPipeline,
|
|
UNet2DConditionModel,
|
|
)
|
|
from diffusers.configuration_utils import FrozenDict
|
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
from diffusers.utils import deprecate, logging
|
|
|
|
|
|
# Module logger obtained from diffusers' logging utilities; used by the pipeline below.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
|
|
|
|
class StableDiffusionMegaPipeline(DiffusionPipeline):
    r"""
    Pipeline exposing text-to-image, image-to-image and (legacy) inpainting generation with Stable Diffusion.

    One set of components backs all three tasks: `text2img`, `img2img` and `inpaint` each construct the matching
    single-task diffusers pipeline ([`StableDiffusionPipeline`], [`StableDiffusionImg2ImgPipeline`],
    [`StableDiffusionInpaintPipelineLegacy`]) from `self.components` and delegate the call to it.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        requires_safety_checker (`bool`, *optional*, defaults to `True`):
            Recorded in the pipeline config. Because `components` returns every public config entry, this value is
            also passed on to the single-task pipelines the task methods construct.
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        # Scheduler configs with a `steps_offset` other than 1 are outdated: emit a deprecation notice and
        # replace the scheduler's config with a patched copy that sets `steps_offset` to 1. The same scheduler
        # object is later handed to the task pipelines, so they see the patched config too.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            # The config is swapped wholesale for a new FrozenDict rather than edited in place.
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    @property
    def components(self) -> Dict[str, Any]:
        r"""
        Return the keyword arguments used to build the single-task pipelines.

        Maps every config key that does not start with `_` (the registered modules plus `requires_safety_checker`)
        to the pipeline attribute of the same name.
        """
        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                the UNet config stores one head dimension per block (a list/tuple), the smallest of them is used
                instead. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this
                case, `attention_head_dim` must be a multiple of `slice_size`. `None` disables slicing.
        """
        if slice_size == "auto":
            head_dim = self.unet.config.attention_head_dim
            if isinstance(head_dim, int):
                # half the attention head size is usually a good trade-off between
                # speed and memory
                slice_size = head_dim // 2
            else:
                # `attention_head_dim` can be a per-block sequence; `list // 2` would raise TypeError.
                # Take the smallest head size so the slice is no larger than any block's head size.
                slice_size = min(head_dim)
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def inpaint(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Inpaint the masked region of `image` guided by `prompt`.

        Builds a [`StableDiffusionInpaintPipelineLegacy`] from this pipeline's components and forwards every
        argument to its `__call__`; see that method's documentation for the meaning of each argument. Returns
        whatever that call returns.
        """
        return StableDiffusionInpaintPipelineLegacy(**self.components)(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            # Previously accepted but dropped, so the callback frequency silently ignored the caller's choice.
            callback_steps=callback_steps,
        )

    @torch.no_grad()
    def img2img(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Transform `image` guided by `prompt`.

        Builds a [`StableDiffusionImg2ImgPipeline`] from this pipeline's components and forwards the named
        arguments to its `__call__`; see that method's documentation for their meaning. Returns whatever that call
        returns.

        Extra keyword arguments (`**kwargs`) are accepted for signature compatibility but are NOT forwarded; a
        warning listing them is logged so they are not lost silently.
        """
        if kwargs:
            logger.warning(
                "`img2img` ignores the following unexpected keyword arguments: %s", sorted(kwargs)
            )
        return StableDiffusionImg2ImgPipeline(**self.components)(
            prompt=prompt,
            image=image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    @torch.no_grad()
    def text2img(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Generate images from `prompt`.

        Builds a [`StableDiffusionPipeline`] from this pipeline's components and forwards every argument to its
        `__call__`; see that method's documentation for the meaning of each argument. Returns whatever that call
        returns.
        """
        return StableDiffusionPipeline(**self.components)(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )
|