Refactor RiffusionPipeline

* Bunch of cleanup and typing
* Move prompt_weighting to be marked as external
* Add helpers for loading the checkpoint and traced unet

Topic: clean_rewrite
This commit is contained in:
Hayk Martiros 2022-12-26 17:22:56 -08:00
parent 4c78e1a228
commit 2fb1153ec8
4 changed files with 178 additions and 80 deletions

riffusion/external/README.md vendored Normal file
@@ -0,0 +1,3 @@
# external
This package contains scripts and tools from external sources.

riffusion/external/__init__.py vendored Normal file

riffusion/external/prompt_weighting.py vendored

@@ -5,6 +5,9 @@ This code is taken from the diffusers community pipeline:
License: Apache 2.0
"""
# ruff: noqa
# mypy: ignore-errors
import re
from typing import List, Optional, Union
@@ -193,7 +196,7 @@ def get_unweighted_text_embeddings(
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
) -> torch.FloatTensor:
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
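The chunking described in this docstring can be pictured with a small sketch. This is an illustration only, using a hypothetical helper `encode_in_chunks` and a plain CLIP-style `text_encoder` with a fixed token window, and it ignores the BOS/EOS bookkeeping the vendored function actually performs:

```python
import torch

def encode_in_chunks(text_encoder, text_input: torch.Tensor, chunk_length: int) -> torch.Tensor:
    # Split a long token sequence into encoder-sized chunks, encode each
    # chunk separately, and concatenate the resulting embeddings.
    embeddings = []
    for i in range(0, text_input.shape[1], chunk_length):
        embeddings.append(text_encoder(text_input[:, i : i + chunk_length])[0])
    return torch.cat(embeddings, dim=1)
```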
@@ -239,7 +242,7 @@ def get_weighted_text_embeddings(
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
**kwargs,
):
) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
@@ -269,8 +272,6 @@
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
print(f"tokens: {prompt_tokens}")
print(f"weights: {prompt_weights}")
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):

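The bracket syntax described in the get_weighted_text_embeddings docstring above can be exercised directly. A hedged usage sketch, assuming `pipeline` is an already-loaded RiffusionPipeline (the function only needs its tokenizer and text encoder); in this pipeline family `(word)` boosts a token's weight and an explicit number like `(word:1.3)` is commonly supported:

```python
from riffusion.external.prompt_weighting import get_weighted_text_embeddings

text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
    pipe=pipeline,
    prompt="acoustic folk guitar, (driving:1.3) drums",
    uncond_prompt="",
)
```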
riffusion/riffusion_pipeline.py

@@ -1,13 +1,15 @@
"""
Riffusion inference pipeline.
"""
from __future__ import annotations
import dataclasses
import functools
import inspect
import typing as T
import numpy as np
import PIL
from PIL import Image
import torch
from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -15,9 +17,12 @@ from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
from huggingface_hub import hf_hub_download
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from .datatypes import InferenceInput
from riffusion.datatypes import InferenceInput
from riffusion.external.prompt_weighting import get_weighted_text_embeddings
from riffusion.util import torch_util
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -56,8 +61,110 @@ class RiffusionPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
@classmethod
def load_checkpoint(
cls,
checkpoint: str,
use_traced_unet: bool = True,
channels_last: bool = False,
dtype: torch.dtype = torch.float16,
device: str = "cuda",
) -> RiffusionPipeline:
"""
Load the riffusion model pipeline.
Args:
checkpoint: Model checkpoint on disk in diffusers format
use_traced_unet: Whether to use the traced unet for speedups
channels_last: Whether to use channels_last memory format
dtype: Data type to load the model weights in
device: Device to load the model on
"""
device = torch_util.check_device(device)
if device == "cpu" or device.lower().startswith("mps"):
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
dtype = torch.float32
pipeline = RiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=dtype,
# Disable the NSFW filter, since it causes false positives
# TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
safety_checker=lambda images, **kwargs: (images, False),
# Optionally attempt to use less memory
low_cpu_mem_usage=False,
).to(device)
if channels_last:
pipeline.unet.to(memory_format=torch.channels_last)
# Optionally load a traced unet
if checkpoint == "riffusion/riffusion-model-v1" and use_traced_unet:
traced_unet = cls.load_traced_unet(
checkpoint=checkpoint,
subfolder="unet_traced",
filename="unet_traced.pt",
in_channels=pipeline.unet.in_channels,
dtype=dtype,
device=device,
)
if traced_unet is not None:
pipeline.unet = traced_unet
model = pipeline.to(device)
return model
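A minimal usage sketch of the new helper, with the checkpoint name used elsewhere in this commit (the `"cuda"` device assumes a GPU host; CPU/MPS fall back to float32 automatically):

```python
pipeline = RiffusionPipeline.load_checkpoint(
    checkpoint="riffusion/riffusion-model-v1",
    use_traced_unet=True,
    device="cuda",
)
```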
@staticmethod
def load_traced_unet(
checkpoint: str,
subfolder: str,
filename: str,
in_channels: int,
dtype: torch.dtype,
device: str = "cuda",
) -> T.Optional[torch.nn.Module]:
"""
Load a traced unet from the huggingface hub. This can improve performance.
"""
if device == "cpu" or device.lower().startswith("mps"):
print("WARNING: Traced UNet only available for CUDA, skipping")
return None
# Download and load the traced unet
unet_file = hf_hub_download(
checkpoint,
subfolder=subfolder,
filename=filename,
)
unet_traced = torch.jit.load(unet_file)
# Wrap it in a torch module
class TracedUNet(torch.nn.Module):
@dataclasses.dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor
def __init__(self):
super().__init__()
self.in_channels = in_channels
self.device = device
self.dtype = dtype
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return self.UNet2DConditionOutput(sample=sample)
return TracedUNet()
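For context, a traced UNet file like the one this helper downloads can be produced roughly as follows. This is a hedged sketch following the common diffusers tracing recipe, not part of this commit; it assumes `pipeline` is already loaded and the shapes of the standard Stable Diffusion configuration:

```python
import functools
import torch

unet = pipeline.unet.eval()
# torch.jit.trace cannot handle the dataclass the UNet returns by default,
# so force plain tuple outputs before tracing.
unet.forward = functools.partial(unet.forward, return_dict=False)

sample = torch.randn(1, unet.in_channels, 64, 64, dtype=torch.float16, device="cuda")
timestep = torch.tensor([999], device="cuda")
text_embed = torch.randn(1, 77, 768, dtype=torch.float16, device="cuda")

with torch.inference_mode():
    traced = torch.jit.trace(unet, (sample, timestep, text_embed))
traced.save("unet_traced.pt")
```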
@property
def device(self) -> str:
return str(self.vae.device)
@functools.lru_cache()
def embed_text(self, text):
def embed_text(self, text) -> torch.FloatTensor:
"""
Takes in text and turns it into text embeddings.
"""
@@ -73,12 +180,10 @@ class RiffusionPipeline(DiffusionPipeline):
return embed
@functools.lru_cache()
def embed_text_weighted(self, text):
def embed_text_weighted(self, text) -> torch.FloatTensor:
"""
Get text embedding with weights.
"""
from .prompt_weighting import get_weighted_text_embeddings
return get_weighted_text_embeddings(
pipe=self,
prompt=text,
@@ -93,10 +198,10 @@ class RiffusionPipeline(DiffusionPipeline):
def riffuse(
self,
inputs: InferenceInput,
init_image: PIL.Image.Image,
mask_image: PIL.Image.Image = None,
init_image: Image.Image,
mask_image: T.Optional[Image.Image] = None,
use_reweighting: bool = True,
) -> PIL.Image.Image:
) -> Image.Image:
"""
Runs inference using interpolation with both img2img and text conditioning.
@@ -113,8 +218,14 @@
end = inputs.end
guidance_scale = start.guidance * (1.0 - alpha) + end.guidance * alpha
generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
# TODO(hayk): Always generate the seed on CPU?
if self.device.lower().startswith("mps"):
generator_start = torch.Generator(device="cpu").manual_seed(start.seed)
generator_end = torch.Generator(device="cpu").manual_seed(end.seed)
else:
generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
# Text encodings
if use_reweighting:
@@ -123,25 +234,31 @@
else:
embed_start = self.embed_text(start.prompt)
embed_end = self.embed_text(end.prompt)
text_embedding = torch.lerp(embed_start, embed_end, alpha)
text_embedding = embed_start + alpha * (embed_end - embed_start)
# Image latents
init_image = preprocess_image(init_image)
init_image_torch = init_image.to(device=self.device, dtype=embed_start.dtype)
init_image_torch = preprocess_image(init_image).to(
device=self.device, dtype=embed_start.dtype
)
init_latent_dist = self.vae.encode(init_image_torch).latent_dist
# TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
# result is so close no matter the seed that it doesn't really add variety.
generator = torch.Generator(device=self.device).manual_seed(start.seed)
if self.device.lower().startswith("mps"):
generator = torch.Generator(device="cpu").manual_seed(start.seed)
else:
generator = torch.Generator(device=self.device).manual_seed(start.seed)
init_latents = init_latent_dist.sample(generator=generator)
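# (0.18215 is the standard Stable Diffusion VAE latent scaling factor)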
init_latents = 0.18215 * init_latents
# Prepare mask latent
mask: T.Optional[torch.Tensor] = None
if mask_image:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
mask_image = preprocess_mask(mask_image, scale_factor=vae_scale_factor)
mask = mask_image.to(device=self.device, dtype=embed_start.dtype)
else:
mask = None
mask = preprocess_mask(mask_image, scale_factor=vae_scale_factor).to(
device=self.device, dtype=embed_start.dtype
)
outputs = self.interpolate_img2img(
text_embeddings=text_embedding,
@@ -161,18 +278,18 @@
@torch.no_grad()
def interpolate_img2img(
self,
text_embeddings: torch.FloatTensor,
init_latents: torch.FloatTensor,
text_embeddings: torch.Tensor,
init_latents: torch.Tensor,
generator_a: torch.Generator,
generator_b: torch.Generator,
interpolate_alpha: float,
mask: T.Optional[torch.FloatTensor] = None,
mask: T.Optional[torch.Tensor] = None,
strength_a: float = 0.8,
strength_b: float = 0.8,
num_inference_steps: T.Optional[int] = 50,
guidance_scale: T.Optional[float] = 7.5,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
num_images_per_prompt: T.Optional[int] = 1,
num_images_per_prompt: int = 1,
eta: T.Optional[float] = 0.0,
output_type: T.Optional[str] = "pil",
**kwargs,
@@ -198,11 +315,6 @@
if do_classifier_free_guidance:
if negative_prompt is None:
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
@@ -251,11 +363,11 @@
noise_b = torch.randn(
init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
)
noise = slerp(interpolate_alpha, noise_a, noise_b)
noise = torch_util.slerp(interpolate_alpha, noise_a, noise_b)
init_latents_orig = init_latents
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# prepare extra kwargs for the scheduler step, since not all schedulers have the same args
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
@@ -295,7 +407,9 @@
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if mask is not None:
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
init_latents_proper = self.scheduler.add_noise(
init_latents_orig, noise, torch.tensor([t])
)
# import ipdb; ipdb.set_trace()
latents = (init_latents_proper * mask) + (latents * (1 - mask))
@@ -311,62 +425,42 @@
return dict(images=image, latents=latents, nsfw_content_detected=False)
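For readers tracing the guidance path in this function: when classifier-free guidance is enabled, the UNet runs on a doubled batch and the two noise predictions are combined. That combine step is not visible in the hunks above; a standard sketch using the usual diffusers naming:

```python
# Split the doubled-batch prediction and apply classifier-free guidance.
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```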
def preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
def preprocess_image(image: Image.Image) -> torch.Tensor:
"""
Preprocess an image for the model.
"""
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
image = image.resize((w, h), resample=Image.LANCZOS)
image_np = np.array(image).astype(np.float32) / 255.0
image_np = image_np[None].transpose(0, 3, 1, 2)
image_torch = torch.from_numpy(image_np)
return 2.0 * image_torch - 1.0
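A quick sketch of the contract this helper establishes: an NCHW float tensor scaled to [-1, 1], with dimensions snapped down to a multiple of 32 (hypothetical check, not part of the commit):

```python
from PIL import Image

img = Image.new("RGB", (512, 512))
x = preprocess_image(img)
assert x.shape == (1, 3, 512, 512)
assert x.min() >= -1.0 and x.max() <= 1.0
```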
def preprocess_mask(mask: PIL.Image.Image, scale_factor: int = 8) -> torch.Tensor:
def preprocess_mask(mask: Image.Image, scale_factor: int = 8) -> torch.Tensor:
"""
Preprocess a mask for the model.
"""
# Convert to grayscale
mask = mask.convert("L")
# Resize to integer multiple of 32
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize(
(w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST
)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
mask = 1 - mask # repaint white, keep black
mask = torch.from_numpy(mask)
return mask
w, h = map(lambda x: x - x % 32, (w, h))
mask = mask.resize((w // scale_factor, h // scale_factor), resample=Image.NEAREST)
# Convert to numpy array and rescale
mask_np = np.array(mask).astype(np.float32) / 255.0
# Tile and transpose
mask_np = np.tile(mask_np, (4, 1, 1))
mask_np = mask_np[None].transpose(0, 1, 2, 3) # adds a batch dimension; the transpose is the identity
# Invert to repaint white and keep black
mask_np = 1 - mask_np # repaint white, keep black
return torch.from_numpy(mask_np)
def slerp(t, v0, v1, dot_threshold=0.9995):
"""
Helper function to spherically interpolate two arrays v0 v1.
"""
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
input_device = v0.device
v0 = v0.cpu().numpy()
v1 = v1.cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > dot_threshold:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2 = torch.from_numpy(v2).to(input_device)
return v2
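For reference, the spherical interpolation computed by the removed helper (now called via riffusion.util.torch_util.slerp) is the standard slerp formula:

```latex
\mathrm{slerp}(t; v_0, v_1)
  = \frac{\sin\big((1-t)\,\theta\big)}{\sin\theta}\, v_0
  + \frac{\sin(t\,\theta)}{\sin\theta}\, v_1,
\qquad
\theta = \arccos\!\left(\frac{v_0 \cdot v_1}{\lVert v_0 \rVert\,\lVert v_1 \rVert}\right)
```

When the cosine similarity exceeds dot_threshold the vectors are nearly parallel and the code falls back to plain linear interpolation, avoiding division by a near-zero sin(θ).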