Refactor RiffusionPipeline
* Bunch of cleanup and typing
* Move prompt_weighting to be marked as external
* Add helpers for loading the checkpoint and traced unet

Topic: clean_rewrite
parent 4c78e1a228
commit 2fb1153ec8
@@ -0,0 +1,3 @@
+# external
+
+This package contains scripts and tools from external sources.
@@ -5,6 +5,9 @@ This code is taken from the diffusers community pipeline:
 License: Apache 2.0
 """
+# ruff: noqa
+# mypy: ignore-errors
 
 import re
 from typing import List, Optional, Union
 
@@ -193,7 +196,7 @@ def get_unweighted_text_embeddings(
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
-):
+) -> torch.FloatTensor:
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
     it should be split into chunks and sent to the text encoder individually.
@@ -239,7 +242,7 @@ def get_weighted_text_embeddings(
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
     **kwargs,
-):
+) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
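Note: the bracket syntax parsed here follows the diffusers long-prompt-weighting convention, where "(word)" scales a token's attention weight up (by roughly 1.1x), "[word]" scales it down, and "(word:1.5)" sets an explicit weight. A minimal sketch of calling the helper under that assumption (the pipeline handle and prompt strings are illustrative, not from this diff):

    # Hypothetical usage; `pipe` is any loaded StableDiffusion-style pipeline.
    text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
        pipe=pipe,
        prompt="A (very beautiful:1.3) masterpiece",
        uncond_prompt="",  # negative prompt; None skips the unconditional branch
    )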
@@ -269,8 +272,6 @@ def get_weighted_text_embeddings(
 
     if not skip_parsing:
         prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
-        print(f"tokens: {prompt_tokens}")
-        print(f"weights: {prompt_weights}")
 
     if uncond_prompt is not None:
         if isinstance(uncond_prompt, str):
@@ -1,13 +1,15 @@
 """
 Riffusion inference pipeline.
 """
+from __future__ import annotations
+
 import dataclasses
 import functools
 import inspect
 import typing as T
 
 import numpy as np
-import PIL
+from PIL import Image
 import torch
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -15,9 +17,12 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from diffusers.utils import logging
+from huggingface_hub import hf_hub_download
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
-from .datatypes import InferenceInput
+from riffusion.datatypes import InferenceInput
+from riffusion.external.prompt_weighting import get_weighted_text_embeddings
+from riffusion.util import torch_util
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -56,8 +61,110 @@ class RiffusionPipeline(DiffusionPipeline):
             feature_extractor=feature_extractor,
         )
 
+    @classmethod
+    def load_checkpoint(
+        cls,
+        checkpoint: str,
+        use_traced_unet: bool = True,
+        channels_last: bool = False,
+        dtype: torch.dtype = torch.float16,
+        device: str = "cuda",
+    ) -> RiffusionPipeline:
+        """
+        Load the riffusion model pipeline.
+
+        Args:
+            checkpoint: Model checkpoint on disk in diffusers format
+            use_traced_unet: Whether to use the traced unet for speedups
+            device: Device to load the model on
+            channels_last: Whether to use channels_last memory format
+        """
+        device = torch_util.check_device(device)
+
+        if device == "cpu" or device.lower().startswith("mps"):
+            print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
+            dtype = torch.float32
+
+        pipeline = RiffusionPipeline.from_pretrained(
+            checkpoint,
+            revision="main",
+            torch_dtype=dtype,
+            # Disable the NSFW filter, causes incorrect false positives
+            # TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
+            safety_checker=lambda images, **kwargs: (images, False),
+            # Optionally attempt to use less memory
+            low_cpu_mem_usage=False,
+        ).to(device)
+
+        if channels_last:
+            pipeline.unet.to(memory_format=torch.channels_last)
+
+        # Optionally load a traced unet
+        if checkpoint == "riffusion/riffusion-model-v1" and use_traced_unet:
+            traced_unet = cls.load_traced_unet(
+                checkpoint=checkpoint,
+                subfolder="unet_traced",
+                filename="unet_traced.pt",
+                in_channels=pipeline.unet.in_channels,
+                dtype=dtype,
+                device=device,
+            )
+
+            if traced_unet is not None:
+                pipeline.unet = traced_unet
+
+        model = pipeline.to(device)
+
+        return model
+
+    @staticmethod
+    def load_traced_unet(
+        checkpoint: str,
+        subfolder: str,
+        filename: str,
+        in_channels: int,
+        dtype: torch.dtype,
+        device: str = "cuda",
+    ) -> T.Optional[torch.nn.Module]:
+        """
+        Load a traced unet from the huggingface hub. This can improve performance.
+        """
+        if device == "cpu" or device.lower().startswith("mps"):
+            print("WARNING: Traced UNet only available for CUDA, skipping")
+            return None
+
+        # Download and load the traced unet
+        unet_file = hf_hub_download(
+            checkpoint,
+            subfolder=subfolder,
+            filename=filename,
+        )
+        unet_traced = torch.jit.load(unet_file)
+
+        # Wrap it in a torch module
+        class TracedUNet(torch.nn.Module):
+            @dataclasses.dataclass
+            class UNet2DConditionOutput:
+                sample: torch.FloatTensor
+
+            def __init__(self):
+                super().__init__()
+                self.in_channels = in_channels
+                self.device = device
+                self.dtype = dtype
+
+            def forward(self, latent_model_input, t, encoder_hidden_states):
+                sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
+                return self.UNet2DConditionOutput(sample=sample)
+
+        return TracedUNet()
+
+    @property
+    def device(self) -> str:
+        return str(self.vae.device)
+
     @functools.lru_cache()
-    def embed_text(self, text):
+    def embed_text(self, text) -> torch.FloatTensor:
         """
         Takes in text and turns it into text embeddings.
         """
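For context, a minimal sketch of how the new classmethod might be invoked (the checkpoint name is the hard-coded default from the traced-unet branch above; everything else mirrors the parameter defaults):

    # Assumes a CUDA machine; on CPU/MPS the helper falls back to float32
    # and skips the traced UNet.
    pipeline = RiffusionPipeline.load_checkpoint(
        checkpoint="riffusion/riffusion-model-v1",
        use_traced_unet=True,
        dtype=torch.float16,
        device="cuda",
    )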
@@ -73,12 +180,10 @@ class RiffusionPipeline(DiffusionPipeline):
         return embed
 
     @functools.lru_cache()
-    def embed_text_weighted(self, text):
+    def embed_text_weighted(self, text) -> torch.FloatTensor:
         """
         Get text embedding with weights.
         """
-        from .prompt_weighting import get_weighted_text_embeddings
-
         return get_weighted_text_embeddings(
             pipe=self,
             prompt=text,
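Both embed_text and embed_text_weighted are memoized with functools.lru_cache, so an interpolation that reuses the same prompt across frames runs the CLIP encoder only once. A toy illustration of the pattern (class and values are hypothetical; note that lru_cache on a method also caches `self`, keeping the instance alive):

    import functools

    class Embedder:
        @functools.lru_cache()
        def embed_text(self, text: str) -> int:
            print(f"encoding: {text}")  # runs once per distinct prompt
            return len(text)

    e = Embedder()
    e.embed_text("acoustic piano riff")  # prints, then caches
    e.embed_text("acoustic piano riff")  # cache hit, no print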
@@ -93,10 +198,10 @@ class RiffusionPipeline(DiffusionPipeline):
     def riffuse(
         self,
         inputs: InferenceInput,
-        init_image: PIL.Image.Image,
-        mask_image: PIL.Image.Image = None,
+        init_image: Image.Image,
+        mask_image: T.Optional[Image.Image] = None,
         use_reweighting: bool = True,
-    ) -> PIL.Image.Image:
+    ) -> Image.Image:
         """
         Runs inference using interpolation with both img2img and text conditioning.
@@ -113,8 +218,14 @@ class RiffusionPipeline(DiffusionPipeline):
         end = inputs.end
 
         guidance_scale = start.guidance * (1.0 - alpha) + end.guidance * alpha
-        generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
-        generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
+
+        # TODO(hayk): Always generate the seed on CPU?
+        if self.device.lower().startswith("mps"):
+            generator_start = torch.Generator(device="cpu").manual_seed(start.seed)
+            generator_end = torch.Generator(device="cpu").manual_seed(end.seed)
+        else:
+            generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
+            generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
 
         # Text encodings
         if use_reweighting:
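The same CPU fallback for MPS reappears below when sampling the VAE latents; a hypothetical helper (not part of this commit) could collapse the duplication:

    def make_generator(device: str, seed: int) -> torch.Generator:
        # torch.Generator is unsupported on MPS, so seed on CPU there.
        gen_device = "cpu" if device.lower().startswith("mps") else device
        return torch.Generator(device=gen_device).manual_seed(seed)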
@@ -123,25 +234,31 @@ class RiffusionPipeline(DiffusionPipeline):
         else:
             embed_start = self.embed_text(start.prompt)
             embed_end = self.embed_text(end.prompt)
-            text_embedding = torch.lerp(embed_start, embed_end, alpha)
+
+        text_embedding = embed_start + alpha * (embed_end - embed_start)
 
         # Image latents
-        init_image = preprocess_image(init_image)
-        init_image_torch = init_image.to(device=self.device, dtype=embed_start.dtype)
+        init_image_torch = preprocess_image(init_image).to(
+            device=self.device, dtype=embed_start.dtype
+        )
         init_latent_dist = self.vae.encode(init_image_torch).latent_dist
         # TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
         # result is so close no matter the seed that it doesn't really add variety.
-        generator = torch.Generator(device=self.device).manual_seed(start.seed)
+        if self.device.lower().startswith("mps"):
+            generator = torch.Generator(device="cpu").manual_seed(start.seed)
+        else:
+            generator = torch.Generator(device=self.device).manual_seed(start.seed)
 
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents
 
         # Prepare mask latent
+        mask: T.Optional[torch.Tensor] = None
         if mask_image:
             vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-            mask_image = preprocess_mask(mask_image, scale_factor=vae_scale_factor)
-            mask = mask_image.to(device=self.device, dtype=embed_start.dtype)
-        else:
-            mask = None
+            mask = preprocess_mask(mask_image, scale_factor=vae_scale_factor).to(
+                device=self.device, dtype=embed_start.dtype
+            )
 
         outputs = self.interpolate_img2img(
             text_embeddings=text_embedding,
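With the interpolation hoisted out of the else branch, both embedding paths now share one formula, and the explicit expression is exactly what torch.lerp computed before. A quick sanity check of the identity (shapes illustrative):

    a, b = torch.randn(2, 77, 768), torch.randn(2, 77, 768)
    alpha = 0.3
    assert torch.allclose(a + alpha * (b - a), torch.lerp(a, b, alpha))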
@@ -161,18 +278,18 @@ class RiffusionPipeline(DiffusionPipeline):
     @torch.no_grad()
     def interpolate_img2img(
         self,
-        text_embeddings: torch.FloatTensor,
-        init_latents: torch.FloatTensor,
+        text_embeddings: torch.Tensor,
+        init_latents: torch.Tensor,
         generator_a: torch.Generator,
         generator_b: torch.Generator,
         interpolate_alpha: float,
-        mask: T.Optional[torch.FloatTensor] = None,
+        mask: T.Optional[torch.Tensor] = None,
         strength_a: float = 0.8,
         strength_b: float = 0.8,
-        num_inference_steps: T.Optional[int] = 50,
-        guidance_scale: T.Optional[float] = 7.5,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
         negative_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
-        num_images_per_prompt: T.Optional[int] = 1,
+        num_images_per_prompt: int = 1,
         eta: T.Optional[float] = 0.0,
         output_type: T.Optional[str] = "pil",
         **kwargs,
@@ -198,11 +315,6 @@ class RiffusionPipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             if negative_prompt is None:
                 uncond_tokens = [""]
-            elif type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
             elif isinstance(negative_prompt, str):
                 uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
@@ -251,11 +363,11 @@ class RiffusionPipeline(DiffusionPipeline):
         noise_b = torch.randn(
             init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
         )
-        noise = slerp(interpolate_alpha, noise_a, noise_b)
+        noise = torch_util.slerp(interpolate_alpha, noise_a, noise_b)
         init_latents_orig = init_latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same args
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
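The noise mixing now uses the shared torch_util.slerp rather than the module-local copy removed at the bottom of this file. Spherical interpolation is used because it roughly preserves the norm of Gaussian noise along the path, where a straight lerp would shrink it. Illustrative check (shapes arbitrary):

    noise_a = torch.randn(1, 4, 64, 64)
    noise_b = torch.randn(1, 4, 64, 64)
    mixed = torch_util.slerp(0.5, noise_a, noise_b)
    # A lerp at 0.5 would have norm ~0.71x the endpoints; slerp stays near 1x.
    print(mixed.norm() / noise_a.norm())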
@@ -295,7 +407,9 @@ class RiffusionPipeline(DiffusionPipeline):
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
             if mask is not None:
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                init_latents_proper = self.scheduler.add_noise(
+                    init_latents_orig, noise, torch.tensor([t])
+                )
                 # import ipdb; ipdb.set_trace()
                 latents = (init_latents_proper * mask) + (latents * (1 - mask))
 
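The blend on the last line is what confines inpainting to the unmasked region: wherever the mask is 1 the latents are reset to a re-noised copy of the original, and wherever it is 0 the denoised latents pass through. A toy example of just the arithmetic (values made up):

    mask = torch.tensor([1.0, 0.0])                 # 1 = keep original, 0 = repaint
    init_latents_proper = torch.tensor([5.0, 5.0])  # re-noised original latents
    latents = torch.tensor([9.0, 9.0])              # current denoised latents
    print(init_latents_proper * mask + latents * (1 - mask))  # tensor([5., 9.])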
@@ -311,62 +425,42 @@ class RiffusionPipeline(DiffusionPipeline):
         return dict(images=image, latents=latents, nsfw_content_detected=False)
 
 
-def preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
+def preprocess_image(image: Image.Image) -> torch.Tensor:
     """
     Preprocess an image for the model.
     """
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    image = image.resize((w, h), resample=Image.LANCZOS)
+
+    image_np = np.array(image).astype(np.float32) / 255.0
+    image_np = image_np[None].transpose(0, 3, 1, 2)
+
+    image_torch = torch.from_numpy(image_np)
+
+    return 2.0 * image_torch - 1.0
 
 
-def preprocess_mask(mask: PIL.Image.Image, scale_factor: int = 8) -> torch.Tensor:
+def preprocess_mask(mask: Image.Image, scale_factor: int = 8) -> torch.Tensor:
     """
     Preprocess a mask for the model.
     """
+    # Convert to grayscale
     mask = mask.convert("L")
+
+    # Resize to integer multiple of 32
     w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize(
-        (w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST
-    )
-    mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask, (4, 1, 1))
-    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-    mask = 1 - mask  # repaint white, keep black
-    mask = torch.from_numpy(mask)
+    w, h = map(lambda x: x - x % 32, (w, h))
+    mask = mask.resize((w // scale_factor, h // scale_factor), resample=Image.NEAREST)
 
-    return mask
+    # Convert to numpy array and rescale
+    mask_np = np.array(mask).astype(np.float32) / 255.0
 
+    # Tile to the latent channel count and add a batch dimension
+    mask_np = np.tile(mask_np, (4, 1, 1))
+    mask_np = mask_np[None].transpose(0, 1, 2, 3)  # identity transpose, kept from the original pipeline
 
-def slerp(t, v0, v1, dot_threshold=0.9995):
-    """
-    Helper function to spherically interpolate two arrays v1 v2.
-    """
-    if not isinstance(v0, np.ndarray):
-        inputs_are_torch = True
-        input_device = v0.device
-        v0 = v0.cpu().numpy()
-        v1 = v1.cpu().numpy()
-
-    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
-    if np.abs(dot) > dot_threshold:
-        v2 = (1 - t) * v0 + t * v1
-    else:
-        theta_0 = np.arccos(dot)
-        sin_theta_0 = np.sin(theta_0)
-        theta_t = theta_0 * t
-        sin_theta_t = np.sin(theta_t)
-        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
-        s1 = sin_theta_t / sin_theta_0
-        v2 = s0 * v0 + s1 * v1
-
-    if inputs_are_torch:
-        v2 = torch.from_numpy(v2).to(input_device)
-
-    return v2
+    # Invert to repaint white and keep black
+    mask_np = 1 - mask_np  # repaint white, keep black
+
+    return torch.from_numpy(mask_np)
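A short sketch of the two preprocessing helpers in use (file names are illustrative):

    init_image = Image.open("seed_spectrogram.png").convert("RGB")
    init_tensor = preprocess_image(init_image)  # shape (1, 3, H, W), values in [-1, 1]

    mask_image = Image.open("mask.png")
    mask_tensor = preprocess_mask(mask_image)   # shape (1, 4, H//8, W//8), 1 = keep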