riffusion-inference/riffusion/riffusion_pipeline.py

"""
Riffusion inference pipeline.
"""
from __future__ import annotations

import dataclasses
import functools
import inspect
import typing as T

import numpy as np
import torch
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from riffusion.datatypes import InferenceInput
from riffusion.external.prompt_weighting import get_weighted_text_embeddings
from riffusion.util import torch_util

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class RiffusionPipeline(DiffusionPipeline):
    """
    Diffusers pipeline for doing a controlled img2img interpolation for audio generation.

    # TODO(hayk): Document more

    Part of this code was adapted from the non-img2img interpolation pipeline at:
    https://github.com/huggingface/diffusers/blob/main/examples/community/interpolate_stable_diffusion.py

    Check the documentation for DiffusionPipeline for full information.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: T.Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @classmethod
    def load_checkpoint(
        cls,
        checkpoint: str,
        use_traced_unet: bool = True,
        channels_last: bool = False,
        dtype: torch.dtype = torch.float16,
        device: str = "cuda",
        local_files_only: bool = False,
        low_cpu_mem_usage: bool = False,
        cache_dir: T.Optional[str] = None,
    ) -> RiffusionPipeline:
        """
        Load the riffusion model pipeline.

        Args:
            checkpoint: Model checkpoint on disk in diffusers format
            use_traced_unet: Whether to use the traced unet for speedups
            channels_last: Whether to use channels_last memory format
            dtype: Torch dtype to load the model weights in
            device: Device to load the model on
            local_files_only: Don't download, only use local files
            low_cpu_mem_usage: Attempt to use less memory on CPU
            cache_dir: Optional cache directory for downloaded files
        """
        device = torch_util.check_device(device)

        if device == "cpu" or device.lower().startswith("mps"):
            print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
            dtype = torch.float32

        pipeline = RiffusionPipeline.from_pretrained(
            checkpoint,
            revision="main",
            torch_dtype=dtype,
            # Disable the NSFW filter, causes incorrect false positives
            # TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
            safety_checker=lambda images, **kwargs: (images, False),
            low_cpu_mem_usage=low_cpu_mem_usage,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
        ).to(device)

        if channels_last:
            pipeline.unet.to(memory_format=torch.channels_last)

        # Optionally load a traced unet
        if checkpoint == "riffusion/riffusion-model-v1" and use_traced_unet:
            traced_unet = cls.load_traced_unet(
                checkpoint=checkpoint,
                subfolder="unet_traced",
                filename="unet_traced.pt",
                in_channels=pipeline.unet.in_channels,
                dtype=dtype,
                device=device,
                local_files_only=local_files_only,
                cache_dir=cache_dir,
            )

            if traced_unet is not None:
                pipeline.unet = traced_unet

        model = pipeline.to(device)

        return model
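    # Illustrative usage of load_checkpoint (the keyword values below are just the
    # defaults shown in the signature above, not requirements):
    #
    #   pipeline = RiffusionPipeline.load_checkpoint(
    #       checkpoint="riffusion/riffusion-model-v1",
    #       use_traced_unet=True,
    #       device="cuda",
    #   )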

    @staticmethod
    def load_traced_unet(
        checkpoint: str,
        subfolder: str,
        filename: str,
        in_channels: int,
        dtype: torch.dtype,
        device: str = "cuda",
        local_files_only=False,
        cache_dir: T.Optional[str] = None,
    ) -> T.Optional[torch.nn.Module]:
        """
        Load a traced unet from the huggingface hub. This can improve performance.
        """
        if device == "cpu" or device.lower().startswith("mps"):
            print("WARNING: Traced UNet only available for CUDA, skipping")
            return None

        # Download and load the traced unet
        unet_file = hf_hub_download(
            checkpoint,
            subfolder=subfolder,
            filename=filename,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
        )
        unet_traced = torch.jit.load(unet_file)

        # Wrap it in a torch module
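        # (descriptive note: the wrapper below mimics the pieces of a diffusers
        #  UNet2DConditionModel that this pipeline relies on: `in_channels`, `device`,
        #  `dtype`, and a forward that returns an object with a `.sample` attribute)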
        class TracedUNet(torch.nn.Module):
            @dataclasses.dataclass
            class UNet2DConditionOutput:
                sample: torch.FloatTensor

            def __init__(self):
                super().__init__()
                self.in_channels = in_channels
                self.device = device
                self.dtype = dtype

            def forward(self, latent_model_input, t, encoder_hidden_states):
                sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
                return self.UNet2DConditionOutput(sample=sample)

        return TracedUNet()

    @property
    def device(self) -> str:
        return str(self.vae.device)

    @functools.lru_cache()
    def embed_text(self, text) -> torch.FloatTensor:
        """
        Takes in text and turns it into text embeddings.
        """
        text_input = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
        return embed

    @functools.lru_cache()
    def embed_text_weighted(self, text) -> torch.FloatTensor:
        """
        Get text embedding with weights.
        """
        return get_weighted_text_embeddings(
            pipe=self,
            prompt=text,
            uncond_prompt=None,
            max_embeddings_multiples=3,
            no_boseos_middle=False,
            skip_parsing=False,
            skip_weighting=False,
        )[0]
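    # Note on the embedding helpers above: both are cached with functools.lru_cache, so
    # repeated prompts reuse their embedding tensors. embed_text returns a tensor of shape
    # (1, tokenizer.model_max_length, hidden_size), typically (1, 77, 768) for the Stable
    # Diffusion v1 text encoder, while the weighted variant may return a longer sequence
    # for prompts that exceed the token limit.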

    @torch.no_grad()
    def riffuse(
        self,
        inputs: InferenceInput,
        init_image: Image.Image,
        mask_image: T.Optional[Image.Image] = None,
        use_reweighting: bool = True,
    ) -> Image.Image:
        """
        Runs inference using interpolation with both img2img and text conditioning.

        Args:
            inputs: Parameter dataclass
            init_image: Image used for conditioning
            mask_image: White pixels in the mask will be replaced by noise and therefore repainted,
                        while black pixels will be preserved. It will be converted to a single
                        channel (luminance) before use.
            use_reweighting: Use prompt reweighting
        """
        alpha = inputs.alpha
        start = inputs.start
        end = inputs.end

        guidance_scale = start.guidance * (1.0 - alpha) + end.guidance * alpha

        # TODO(hayk): Always generate the seed on CPU?
        if self.device.lower().startswith("mps"):
            generator_start = torch.Generator(device="cpu").manual_seed(start.seed)
            generator_end = torch.Generator(device="cpu").manual_seed(end.seed)
        else:
            generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
            generator_end = torch.Generator(device=self.device).manual_seed(end.seed)

        # Text encodings
        if use_reweighting:
            embed_start = self.embed_text_weighted(start.prompt)
            embed_end = self.embed_text_weighted(end.prompt)
        else:
            embed_start = self.embed_text(start.prompt)
            embed_end = self.embed_text(end.prompt)

        text_embedding = embed_start + alpha * (embed_end - embed_start)

        # Image latents
        init_image_torch = preprocess_image(init_image).to(
            device=self.device, dtype=embed_start.dtype
        )
        init_latent_dist = self.vae.encode(init_image_torch).latent_dist
        # TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
        # result is so close no matter the seed that it doesn't really add variety.
        if self.device.lower().startswith("mps"):
            generator = torch.Generator(device="cpu").manual_seed(start.seed)
        else:
            generator = torch.Generator(device=self.device).manual_seed(start.seed)
        init_latents = init_latent_dist.sample(generator=generator)
        init_latents = 0.18215 * init_latents
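        # Descriptive note: 0.18215 is the latent scaling factor of the Stable Diffusion v1
        # VAE (exposed as `vae.config.scaling_factor` in newer diffusers releases); it is
        # divided back out before decoding in interpolate_img2img below.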

        # Prepare mask latent
        mask: T.Optional[torch.Tensor] = None
        if mask_image:
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
            mask = preprocess_mask(mask_image, scale_factor=vae_scale_factor).to(
                device=self.device, dtype=embed_start.dtype
            )

        outputs = self.interpolate_img2img(
            text_embeddings=text_embedding,
            init_latents=init_latents,
            mask=mask,
            generator_a=generator_start,
            generator_b=generator_end,
            interpolate_alpha=alpha,
            strength_a=start.denoising,
            strength_b=end.denoising,
            num_inference_steps=inputs.num_inference_steps,
            guidance_scale=guidance_scale,
        )

        return outputs["images"][0]
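    # Hypothetical call sketch for riffuse: `inputs` is an InferenceInput whose `start`
    # and `end` entries carry the prompt, seed, guidance, and denoising values read above,
    # and `seed_spectrogram` (an illustrative name) is the conditioning spectrogram image.
    # The interpolated spectrogram comes back as a PIL image:
    #
    #   image = pipeline.riffuse(inputs, init_image=seed_spectrogram)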

    @torch.no_grad()
    def interpolate_img2img(
        self,
        text_embeddings: torch.Tensor,
        init_latents: torch.Tensor,
        generator_a: torch.Generator,
        generator_b: torch.Generator,
        interpolate_alpha: float,
        mask: T.Optional[torch.Tensor] = None,
        strength_a: float = 0.8,
        strength_b: float = 0.8,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
        num_images_per_prompt: int = 1,
        eta: T.Optional[float] = 0.0,
        output_type: T.Optional[str] = "pil",
        **kwargs,
    ):
        """
        Run the shared img2img denoising loop. The initial noise is a spherical
        interpolation between the two generators, and the denoising strength is a linear
        interpolation between strength_a and strength_b, both controlled by
        interpolate_alpha. An optional mask restricts which latent pixels are repainted.
        """
        batch_size = text_embeddings.shape[0]

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            if negative_prompt is None:
                uncond_tokens = [""]
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
            else:
                uncond_tokens = negative_prompt

            # max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt
            uncond_embeddings = uncond_embeddings.repeat_interleave(
                batch_size * num_images_per_prompt, dim=0
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
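            # The concatenated batch is ordered [unconditional, conditional], which is what
            # the chunk(2) split in the denoising loop below assumes.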

        latents_dtype = text_embeddings.dtype

        strength = (1 - interpolate_alpha) * strength_a + interpolate_alpha * strength_b

        # get the original timestep using init_timestep
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)

        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor(
            [timesteps] * batch_size * num_images_per_prompt, device=self.device
        )

        # add noise to latents using the timesteps
        noise_a = torch.randn(
            init_latents.shape, generator=generator_a, device=self.device, dtype=latents_dtype
        )
        noise_b = torch.randn(
            init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
        )
        noise = torch_util.slerp(interpolate_alpha, noise_a, noise_b)
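        # Spherical interpolation keeps the blended tensor at roughly the same norm as the
        # endpoint noise samples, so it still looks like Gaussian noise to the scheduler
        # instead of a washed-out linear average.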

        init_latents_orig = init_latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same args
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        latents = init_latents.clone()

        t_start = max(num_inference_steps - init_timestep + offset, 0)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if mask is not None:
                init_latents_proper = self.scheduler.add_noise(
                    init_latents_orig, noise, torch.tensor([t])
                )
                # import ipdb; ipdb.set_trace()
                latents = (init_latents_proper * mask) + (latents * (1 - mask))

        latents = 1.0 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return dict(images=image, latents=latents, nsfw_content_detected=False)


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """
    Preprocess an image for the model.
    """
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32

    image = image.resize((w, h), resample=Image.LANCZOS)

    image_np = np.array(image).astype(np.float32) / 255.0
    image_np = image_np[None].transpose(0, 3, 1, 2)

    image_torch = torch.from_numpy(image_np)

    return 2.0 * image_torch - 1.0
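# For example, a 512x512 RGB spectrogram image becomes a (1, 3, 512, 512) tensor with
# values rescaled from [0, 255] to [-1, 1], which is the range the VAE encoder expects.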


def preprocess_mask(mask: Image.Image, scale_factor: int = 8) -> torch.Tensor:
    """
    Preprocess a mask for the model.
    """
    # Convert to grayscale
    mask = mask.convert("L")

    # Resize to integer multiple of 32
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))
    mask = mask.resize((w // scale_factor, h // scale_factor), resample=Image.NEAREST)

    # Convert to numpy array and rescale
    mask_np = np.array(mask).astype(np.float32) / 255.0

    # Tile to the latent channel count and add a batch dimension
    mask_np = np.tile(mask_np, (4, 1, 1))
    mask_np = mask_np[None]  # shape (1, 4, h // scale_factor, w // scale_factor)

    # Invert to repaint white and keep black
    mask_np = 1 - mask_np  # repaint white, keep black

    return torch.from_numpy(mask_np)
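# After inversion the mask is 1.0 where the original latents are preserved and 0.0 where
# they are repainted, matching `init_latents_proper * mask + latents * (1 - mask)` in the
# denoising loop of interpolate_img2img.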