525 lines
26 KiB
Python
525 lines
26 KiB
Python
|
import inspect
|
||
|
import time
|
||
|
from pathlib import Path
|
||
|
from typing import Callable, List, Optional, Union
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
from diffusers.configuration_utils import FrozenDict
|
||
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||
|
from diffusers.utils import deprecate, logging
|
||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||
|
|
||
|
|
||
|
# Module-level logger obtained from diffusers' logging utilities, keyed by this module's name.
logger = logging.get_logger(name=__name__)  # pylint: disable=invalid-name
|
||
|
|
||
|
|
||
|
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherically interpolate between two arrays ``v0`` and ``v1``.

    Accepts either NumPy arrays or torch tensors (the kind is decided by ``v0``). Torch inputs are
    moved to the CPU, interpolated with NumPy, and the result is moved back to ``v0``'s device.

    Args:
        t (`float`): Interpolation weight; ``0`` yields ``v0`` and ``1`` yields ``v1``.
        v0 (`np.ndarray` or `torch.Tensor`): Start point.
        v1 (`np.ndarray` or `torch.Tensor`): End point; expected to be the same kind and shape as ``v0``.
        DOT_THRESHOLD (`float`, *optional*, defaults to 0.9995):
            If the absolute cosine similarity of the (flattened) inputs exceeds this value, the inputs are
            treated as (anti-)parallel and plain linear interpolation is used instead, because
            ``sin(theta)`` would be close to zero.

    Returns:
        The interpolated value: an `np.ndarray` for NumPy inputs, or a `torch.Tensor` on ``v0``'s
        original device for torch inputs.
    """
    # Initialised on every path: NumPy inputs skip the conversion branch below, and reading an
    # unassigned flag afterwards would raise NameError.
    inputs_are_torch = False
    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    # Cosine of the angle between the two inputs (np.linalg.norm flattens when no axis is given).
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        # Nearly parallel: the slerp formula is numerically unstable, fall back to lerp.
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2
|
||
|
|
||
|
|
||
|
class StableDiffusionWalkPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion, plus a `walk` helper that renders frames
    interpolating between consecutive (prompt, seed) pairs.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
            May be `None`, in which case a warning is logged and no NSFW filtering is performed.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()

        # Older scheduler configs may carry a `steps_offset` other than 1; emit a deprecation notice and
        # patch the (frozen) config in place so sampling uses `steps_offset=1`.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            # `Logger.warn` is a deprecated alias; use `warning`.
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`. `None` disables slicing.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        text_embeddings: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*, defaults to `None`):
                The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image. Must be divisible by 8.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image. Must be divisible by 8.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than or equal to `1`). May also be used together with `text_embeddings`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. Must be a positive integer.
            text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
                Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
                `prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
                the supplied `prompt`.
            **kwargs:
                Accepted and ignored.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.

        Raises:
            ValueError: if `height`/`width` are not divisible by 8, `callback_steps` is not a positive integer,
                `prompt` has an unsupported type (when `text_embeddings` is not given), a list `negative_prompt`
                does not match the batch size, or provided `latents` have an unexpected shape.
            TypeError: if both `prompt` and `negative_prompt` are given but have different types.
        """

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # `None` is not an int, so this single check also rejects a missing value.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if text_embeddings is None:
            if isinstance(prompt, str):
                batch_size = 1
            elif isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids

            # Tokenized without truncation so the dropped tail can be reported before cutting it off.
            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
            text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
        else:
            batch_size = text_embeddings.shape[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            # One unconditional prompt per batch row, so the duplication below mirrors the
            # conditional embeddings row-for-row.
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                # Only enforce matching types when a prompt was supplied; with `text_embeddings` alone
                # `prompt` is None and any negative prompt would otherwise be rejected.
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = self.tokenizer.model_max_length
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method.
            # uncond_embeddings already has `batch_size` rows, so only the per-prompt axis is repeated.
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # Undo the latent scaling factor before decoding with the VAE.
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        # Map decoder output from [-1, 1] to [0, 1].
        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def embed_text(self, text):
        """Tokenize `text` (padded/truncated to the tokenizer's max length) and return the text-encoder output.

        Args:
            text (`str` or `List[str]`): Prompt(s) to embed.

        Returns:
            `torch.FloatTensor`: First output of `self.text_encoder`, computed on `self.device` without gradients.
        """
        text_input = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
        return embed

    def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
        """Return a deterministic Gaussian latent of shape `(1, unet.in_channels, height // 8, width // 8)`.

        Args:
            seed (`int`): Seed for the random generator.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): dtype of the returned tensor.
            height (`int`, *optional*, defaults to 512): Image height in pixels.
            width (`int`, *optional*, defaults to 512): Image width in pixels.

        Returns:
            `torch.FloatTensor` on `self.device`.
        """
        # randn/Generator are not available on mps (see `__call__`), so sample on CPU there and move the result.
        rng_device = "cpu" if self.device.type == "mps" else self.device
        return torch.randn(
            (1, self.unet.in_channels, height // 8, width // 8),
            generator=torch.Generator(device=rng_device).manual_seed(seed),
            device=rng_device,
            dtype=dtype,
        ).to(self.device)

    def walk(
        self,
        prompts: List[str],
        seeds: List[int],
        num_interpolation_steps: Optional[int] = 6,
        output_dir: Optional[str] = "./dreams",
        name: Optional[str] = None,
        batch_size: Optional[int] = 1,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        guidance_scale: Optional[float] = 7.5,
        num_inference_steps: Optional[int] = 50,
        eta: Optional[float] = 0.0,
    ) -> List[str]:
        """
        Walks through a series of prompts and seeds, interpolating between them and saving the results to disk.

        For each consecutive pair of (prompt, seed) entries, `num_interpolation_steps` values of `t` evenly spaced in
        [0, 1] are used: the noise is spherically interpolated (`slerp`) and the text embeddings linearly
        interpolated (`torch.lerp`). Because both endpoints are included, the last frame of one segment and the first
        frame of the next use the same prompt and seed.

        Args:
            prompts (`List[str]`):
                List of prompts to generate images for.
            seeds (`List[int]`):
                List of seeds corresponding to provided prompts. Must be the same length as prompts.
            num_interpolation_steps (`int`, *optional*, defaults to 6):
                Number of interpolation steps to take between prompts.
            output_dir (`str`, *optional*, defaults to `./dreams`):
                Directory to save the generated images to.
            name (`str`, *optional*, defaults to `None`):
                Subdirectory of `output_dir` to save the generated images to. If `None`, the name will
                be the current time.
            batch_size (`int`, *optional*, defaults to 1):
                Number of images to generate at once.
            height (`int`, *optional*, defaults to 512):
                Height of the generated images.
            width (`int`, *optional*, defaults to 512):
                Width of the generated images.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.

        Returns:
            `List[str]`: List of paths to the generated images, in frame order.

        Raises:
            ValueError: if `prompts` and `seeds` differ in length.
        """
        if len(prompts) != len(seeds):
            raise ValueError(
                f"Number of prompts and seeds must be equal. Got {len(prompts)} prompts and {len(seeds)} seeds"
            )

        name = name or time.strftime("%Y%m%d-%H%M%S")
        save_path = Path(output_dir) / name
        save_path.mkdir(exist_ok=True, parents=True)

        frame_idx = 0
        frame_filepaths = []
        for prompt_a, prompt_b, seed_a, seed_b in zip(prompts, prompts[1:], seeds, seeds[1:]):
            # Embed Text
            embed_a = self.embed_text(prompt_a)
            embed_b = self.embed_text(prompt_b)

            # Get Noise (in the embeddings' dtype so fp16 pipelines stay consistent)
            noise_dtype = embed_a.dtype
            noise_a = self.get_noise(seed_a, noise_dtype, height, width)
            noise_b = self.get_noise(seed_b, noise_dtype, height, width)

            # Accumulate per-frame tensors and concatenate once per batch.
            noise_batch, embeds_batch = [], []
            T = np.linspace(0.0, 1.0, num_interpolation_steps)
            for i, t in enumerate(T):
                weight = float(t)
                noise_batch.append(slerp(weight, noise_a, noise_b))
                embeds_batch.append(torch.lerp(embed_a, embed_b, weight))

                batch_is_ready = len(embeds_batch) == batch_size or i + 1 == T.shape[0]
                if batch_is_ready:
                    outputs = self(
                        latents=torch.cat(noise_batch, dim=0),
                        text_embeddings=torch.cat(embeds_batch, dim=0),
                        height=height,
                        width=width,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                    )
                    noise_batch, embeds_batch = [], []

                    for image in outputs["images"]:
                        frame_filepath = str(save_path / f"frame_{frame_idx}.png")
                        image.save(frame_filepath)
                        frame_filepaths.append(frame_filepath)
                        frame_idx += 1
        return frame_filepaths
|