406 lines
17 KiB
Python
406 lines
17 KiB
Python
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from diffusers import (
|
||
|
AutoencoderKL,
|
||
|
DDIMScheduler,
|
||
|
DiffusionPipeline,
|
||
|
LMSDiscreteScheduler,
|
||
|
PNDMScheduler,
|
||
|
StableDiffusionPipeline,
|
||
|
UNet2DConditionModel,
|
||
|
)
|
||
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||
|
|
||
|
|
||
|
# Hugging Face Hub repository IDs of the Stable Diffusion v1.x checkpoints being compared, in release order.
pipe1_model_id, pipe2_model_id, pipe3_model_id, pipe4_model_id = (
    "CompVis/stable-diffusion-v1-1",
    "CompVis/stable-diffusion-v1-2",
    "CompVis/stable-diffusion-v1-3",
    "CompVis/stable-diffusion-v1-4",
)
|
||
|
|
||
|
|
||
|
class StableDiffusionComparisonPipeline(DiffusionPipeline):
    r"""
    Pipeline for side-by-side comparison of the Stable Diffusion v1.1, v1.2, v1.3 and v1.4 checkpoints.

    This pipeline inherits from [`DiffusionPipeline`]. On construction it downloads the v1.1, v1.2 and v1.3
    checkpoints from the Hugging Face Hub (which may require an auth token) and assembles the v1.4 pipeline from the
    components passed in. If using the Hugging Face Hub, pass the model ID for Stable Diffusion v1.4; the three
    earlier checkpoints are loaded automatically.

    The components below are used only for the v1.4 pipeline; v1.1-v1.3 use the components of their own checkpoints.

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        requires_safety_checker (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to the v1.4 [`StableDiffusionPipeline`].
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        # v1.1 - v1.3 are pulled from the Hub; v1.4 is built from the components handed to this pipeline.
        self.pipe1 = StableDiffusionPipeline.from_pretrained(pipe1_model_id)
        self.pipe2 = StableDiffusionPipeline.from_pretrained(pipe2_model_id)
        self.pipe3 = StableDiffusionPipeline.from_pretrained(pipe3_model_id)
        self.pipe4 = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

        self.register_modules(pipeline1=self.pipe1, pipeline2=self.pipe2, pipeline3=self.pipe3, pipeline4=self.pipe4)

    @property
    def layers(self) -> Dict[str, Any]:
        """Map every public config key of this pipeline to the attribute of the same name."""
        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}

    def _comparison_pipes(self) -> List[StableDiffusionPipeline]:
        """Return the wrapped v1.1, v1.2, v1.3 and v1.4 pipelines, in checkpoint order."""
        return [self.pipe1, self.pipe2, self.pipe3, self.pipe4]

    @staticmethod
    def _unpack_result(result):
        """Split one sub-pipeline result into ``(images, nsfw_flags)``.

        Accepts either the ``(images, nsfw)`` tuple returned with ``return_dict=False`` or an output object exposing
        ``images`` / ``nsfw_content_detected`` attributes (``return_dict=True``).
        """
        if isinstance(result, tuple):
            return result[0], (result[1] if len(result) > 1 else None)
        return result.images, result.nsfw_content_detected

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation in all four wrapped pipelines.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`. `None` disables slicing.
        """
        # This pipeline owns no `unet` of its own; each wrapped pipeline has one, so configure them all.
        for pipe in self._comparison_pipes():
            unet_slice_size = slice_size
            if unet_slice_size == "auto":
                # half the attention head size is usually a good trade-off between
                # speed and memory
                unet_slice_size = pipe.unet.config.attention_head_dim // 2
            pipe.unet.set_attention_slice(unet_slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def text2img_sd1_1(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        """Run text-to-image generation with the Stable Diffusion v1.1 checkpoint only; arguments are forwarded as-is."""
        return self.pipe1(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_2(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        """Run text-to-image generation with the Stable Diffusion v1.2 checkpoint only; arguments are forwarded as-is."""
        return self.pipe2(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_3(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        """Run text-to-image generation with the Stable Diffusion v1.3 checkpoint only; arguments are forwarded as-is."""
        return self.pipe3(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_4(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        """Run text-to-image generation with the Stable Diffusion v1.4 checkpoint only; arguments are forwarded as-is."""
        return self.pipe4(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation. It produces 4 results by running the SD1.1-1.4
        pipelines one after another with the same arguments.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, optional, defaults to 512):
                The height in pixels of the generated image. Must be divisible by 8.
            width (`int`, optional, defaults to 512):
                The width in pixels of the generated image. Must be divisible by 8.
            num_inference_steps (`int`, optional, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, optional, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, optional):
                Forwarded unchanged to each checkpoint's pipeline.
            num_images_per_prompt (`int`, optional, defaults to 1):
                Forwarded unchanged to each checkpoint's pipeline.
            eta (`float`, optional, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, optional):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic. The same generator object is passed to all four checkpoints in turn.
            latents (`torch.FloatTensor`, optional):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, optional, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, optional, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable[[int, int, torch.FloatTensor], None]`, optional):
                Forwarded unchanged to each checkpoint's pipeline.
            callback_steps (`int`, optional, defaults to 1):
                Forwarded unchanged to each checkpoint's pipeline together with `callback`.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. In both cases the images are a list with four entries (v1.1, v1.2, v1.3, v1.4), each holding
            that checkpoint's generated images. The NSFW flags are a list with four entries, each holding that
            checkpoint's "not-safe-for-work" flags according to its `safety_checker` (or `None` if it reported none).

        Raises:
            ValueError: If `height` or `width` is not divisible by 8.
        """
        # Validate cheaply before any device transfer or denoising work.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` must be divisible by 8 but are {height} and {width}.")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # The checkpoints are nested pipelines, not torch modules, so move each one explicitly rather than
        # relying on `self.to` to reach them.
        for pipe in self._comparison_pipes():
            pipe.to(device)

        generation_kwargs = dict(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # NOTE(review): the same `generator` is consumed by each run in turn; pass `latents` if identical starting
        # noise across checkpoints is required.
        results = [
            self.text2img_sd1_1(**generation_kwargs),  # Stable Diffusion Checkpoint v1.1
            self.text2img_sd1_2(**generation_kwargs),  # Stable Diffusion Checkpoint v1.2
            self.text2img_sd1_3(**generation_kwargs),  # Stable Diffusion Checkpoint v1.3
            self.text2img_sd1_4(**generation_kwargs),  # Stable Diffusion Checkpoint v1.4
        ]

        images = []
        nsfw_content_detected = []
        for result in results:
            result_images, result_nsfw = self._unpack_result(result)
            images.append(result_images)
            nsfw_content_detected.append(result_nsfw)

        if not return_dict:
            return images, nsfw_content_detected
        return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=nsfw_content_detected)
|