diff --git a/riffusion/external/README.md b/riffusion/external/README.md
new file mode 100644
index 0000000..698f0b6
--- /dev/null
+++ b/riffusion/external/README.md
@@ -0,0 +1,3 @@
+# external
+
+This package contains scripts and tools from external sources.
diff --git a/riffusion/external/__init__.py b/riffusion/external/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/riffusion/prompt_weighting.py b/riffusion/external/prompt_weighting.py
similarity index 99%
rename from riffusion/prompt_weighting.py
rename to riffusion/external/prompt_weighting.py
index b544347..5c71486 100644
--- a/riffusion/prompt_weighting.py
+++ b/riffusion/external/prompt_weighting.py
@@ -5,6 +5,9 @@ This code is taken from the diffusers community pipeline:
 License: Apache 2.0
 """
 
+# ruff: noqa
+# mypy: ignore-errors
+
 import re
 from typing import List, Optional, Union
 
@@ -193,7 +196,7 @@ def get_unweighted_text_embeddings(
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
-):
+) -> torch.FloatTensor:
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
     it should be split into chunks and sent to the text encoder individually.
@@ -239,7 +242,7 @@ def get_weighted_text_embeddings(
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
     **kwargs,
-):
+) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
     r"""
     Prompts can be assigned with local weights using brackets. For example,
     prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
@@ -269,8 +272,6 @@
 
     if not skip_parsing:
         prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
-        print(f"tokens: {prompt_tokens}")
-        print(f"weights: {prompt_weights}")
 
         if uncond_prompt is not None:
             if isinstance(uncond_prompt, str):
diff --git a/riffusion/riffusion_pipeline.py b/riffusion/riffusion_pipeline.py
index 9352d77..997cd85 100644
--- a/riffusion/riffusion_pipeline.py
+++ b/riffusion/riffusion_pipeline.py
@@ -1,13 +1,15 @@
 """
 Riffusion inference pipeline.
 """
+from __future__ import annotations
 
+import dataclasses
 import functools
 import inspect
 import typing as T
 
 import numpy as np
-import PIL
+from PIL import Image
 import torch
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -15,9 +17,12 @@
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from diffusers.utils import logging
+from huggingface_hub import hf_hub_download
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
-from .datatypes import InferenceInput
+from riffusion.datatypes import InferenceInput
+from riffusion.external.prompt_weighting import get_weighted_text_embeddings
+from riffusion.util import torch_util
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -56,8 +61,111 @@ class RiffusionPipeline(DiffusionPipeline):
             feature_extractor=feature_extractor,
         )
 
+    @classmethod
+    def load_checkpoint(
+        cls,
+        checkpoint: str,
+        use_traced_unet: bool = True,
+        channels_last: bool = False,
+        dtype: torch.dtype = torch.float16,
+        device: str = "cuda",
+    ) -> RiffusionPipeline:
+        """
+        Load the riffusion model pipeline.
+
+        Args:
+            checkpoint: Model checkpoint on disk in diffusers format
+            use_traced_unet: Whether to use the traced unet for speedups
+            channels_last: Whether to use channels_last memory format
+            dtype: Data type to load the model weights in
+            device: Device to load the model on
+        """
+        device = torch_util.check_device(device)
+
+        if device == "cpu" or device.lower().startswith("mps"):
+            print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
+            dtype = torch.float32
+
+        pipeline = RiffusionPipeline.from_pretrained(
+            checkpoint,
+            revision="main",
+            torch_dtype=dtype,
+            # Disable the NSFW filter, since it causes false positives
+            # TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
+            safety_checker=lambda images, **kwargs: (images, False),
+            # Optionally attempt to use less memory
+            low_cpu_mem_usage=False,
+        ).to(device)
+
+        if channels_last:
+            pipeline.unet.to(memory_format=torch.channels_last)
+
+        # Optionally load a traced unet
+        if checkpoint == "riffusion/riffusion-model-v1" and use_traced_unet:
+            traced_unet = cls.load_traced_unet(
+                checkpoint=checkpoint,
+                subfolder="unet_traced",
+                filename="unet_traced.pt",
+                in_channels=pipeline.unet.in_channels,
+                dtype=dtype,
+                device=device,
+            )
+
+            if traced_unet is not None:
+                pipeline.unet = traced_unet
+
+        model = pipeline.to(device)
+
+        return model
+
+    @staticmethod
+    def load_traced_unet(
+        checkpoint: str,
+        subfolder: str,
+        filename: str,
+        in_channels: int,
+        dtype: torch.dtype,
+        device: str = "cuda",
+    ) -> T.Optional[torch.nn.Module]:
+        """
+        Load a traced unet from the huggingface hub. This can improve performance.
+        """
+        if device == "cpu" or device.lower().startswith("mps"):
+            print("WARNING: Traced UNet only available for CUDA, skipping")
+            return None
+
+        # Download and load the traced unet
+        unet_file = hf_hub_download(
+            checkpoint,
+            subfolder=subfolder,
+            filename=filename,
+        )
+        unet_traced = torch.jit.load(unet_file)
+
+        # Wrap it in a torch module
+        class TracedUNet(torch.nn.Module):
+            @dataclasses.dataclass
+            class UNet2DConditionOutput:
+                sample: torch.FloatTensor
+
+            def __init__(self):
+                super().__init__()
+                self.in_channels = in_channels
+                self.device = device
+                self.dtype = dtype
+
+            def forward(self, latent_model_input, t, encoder_hidden_states):
+                sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
+                return self.UNet2DConditionOutput(sample=sample)
+
+        return TracedUNet()
+
+    @property
+    def device(self) -> str:
+        return str(self.vae.device)
+
     @functools.lru_cache()
-    def embed_text(self, text):
+    def embed_text(self, text) -> torch.FloatTensor:
         """
         Takes in text and turns it into text embeddings.
         """
@@ -73,12 +180,10 @@ class RiffusionPipeline(DiffusionPipeline):
         return embed
 
     @functools.lru_cache()
-    def embed_text_weighted(self, text):
+    def embed_text_weighted(self, text) -> torch.FloatTensor:
         """
         Get text embedding with weights.
         """
-        from .prompt_weighting import get_weighted_text_embeddings
-
         return get_weighted_text_embeddings(
             pipe=self,
             prompt=text,
@@ -93,10 +198,10 @@ class RiffusionPipeline(DiffusionPipeline):
     def riffuse(
         self,
         inputs: InferenceInput,
-        init_image: PIL.Image.Image,
-        mask_image: PIL.Image.Image = None,
+        init_image: Image.Image,
+        mask_image: T.Optional[Image.Image] = None,
         use_reweighting: bool = True,
-    ) -> PIL.Image.Image:
+    ) -> Image.Image:
         """
         Runs inference using interpolation with both img2img and text conditioning.
 
@@ -113,8 +218,14 @@ class RiffusionPipeline(DiffusionPipeline):
         end = inputs.end
 
         guidance_scale = start.guidance * (1.0 - alpha) + end.guidance * alpha
-        generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
-        generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
+
+        # TODO(hayk): Always generate the seed on CPU?
+        if self.device.lower().startswith("mps"):
+            generator_start = torch.Generator(device="cpu").manual_seed(start.seed)
+            generator_end = torch.Generator(device="cpu").manual_seed(end.seed)
+        else:
+            generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
+            generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
 
         # Text encodings
         if use_reweighting:
@@ -123,25 +234,31 @@ class RiffusionPipeline(DiffusionPipeline):
         else:
             embed_start = self.embed_text(start.prompt)
             embed_end = self.embed_text(end.prompt)
-        text_embedding = torch.lerp(embed_start, embed_end, alpha)
+
+        text_embedding = embed_start + alpha * (embed_end - embed_start)
 
         # Image latents
-        init_image = preprocess_image(init_image)
-        init_image_torch = init_image.to(device=self.device, dtype=embed_start.dtype)
+        init_image_torch = preprocess_image(init_image).to(
+            device=self.device, dtype=embed_start.dtype
+        )
         init_latent_dist = self.vae.encode(init_image_torch).latent_dist
 
         # TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
         # result is so close no matter the seed that it doesn't really add variety.
-        generator = torch.Generator(device=self.device).manual_seed(start.seed)
+        if self.device.lower().startswith("mps"):
+            generator = torch.Generator(device="cpu").manual_seed(start.seed)
+        else:
+            generator = torch.Generator(device=self.device).manual_seed(start.seed)
+
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents
         # Prepare mask latent
+        mask: T.Optional[torch.Tensor] = None
         if mask_image:
             vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-            mask_image = preprocess_mask(mask_image, scale_factor=vae_scale_factor)
-            mask = mask_image.to(device=self.device, dtype=embed_start.dtype)
-        else:
-            mask = None
+            mask = preprocess_mask(mask_image, scale_factor=vae_scale_factor).to(
+                device=self.device, dtype=embed_start.dtype
+            )
 
         outputs = self.interpolate_img2img(
             text_embeddings=text_embedding,
@@ -161,18 +278,18 @@ class RiffusionPipeline(DiffusionPipeline):
     @torch.no_grad()
     def interpolate_img2img(
         self,
-        text_embeddings: torch.FloatTensor,
-        init_latents: torch.FloatTensor,
+        text_embeddings: torch.Tensor,
+        init_latents: torch.Tensor,
         generator_a: torch.Generator,
         generator_b: torch.Generator,
         interpolate_alpha: float,
-        mask: T.Optional[torch.FloatTensor] = None,
+        mask: T.Optional[torch.Tensor] = None,
         strength_a: float = 0.8,
         strength_b: float = 0.8,
-        num_inference_steps: T.Optional[int] = 50,
-        guidance_scale: T.Optional[float] = 7.5,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
         negative_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
-        num_images_per_prompt: T.Optional[int] = 1,
+        num_images_per_prompt: int = 1,
         eta: T.Optional[float] = 0.0,
         output_type: T.Optional[str] = "pil",
         **kwargs,
@@ -198,11 +315,6 @@ class RiffusionPipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             if negative_prompt is None:
                 uncond_tokens = [""]
-            elif type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
             elif isinstance(negative_prompt, str):
                 uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
@@ -251,11 +363,11 @@ class RiffusionPipeline(DiffusionPipeline):
         noise_b = torch.randn(
             init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
         )
-        noise = slerp(interpolate_alpha, noise_a, noise_b)
+        noise = torch_util.slerp(interpolate_alpha, noise_a, noise_b)
         init_latents_orig = init_latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same args
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
@@ -295,7 +407,9 @@ class RiffusionPipeline(DiffusionPipeline):
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
             if mask is not None:
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                init_latents_proper = self.scheduler.add_noise(
+                    init_latents_orig, noise, torch.tensor([t])
+                )
                 # import ipdb; ipdb.set_trace()
                 latents = (init_latents_proper * mask) + (latents * (1 - mask))
 
@@ -311,62 +425,42 @@ class RiffusionPipeline(DiffusionPipeline):
         return dict(images=image, latents=latents, nsfw_content_detected=False)
 
 
-def preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
+def preprocess_image(image: Image.Image) -> torch.Tensor:
     """
     Preprocess an image for the model.
     """
     w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    image = image.resize((w, h), resample=Image.LANCZOS)
+
+    image_np = np.array(image).astype(np.float32) / 255.0
+    image_np = image_np[None].transpose(0, 3, 1, 2)
+
+    image_torch = torch.from_numpy(image_np)
+
+    return 2.0 * image_torch - 1.0
 
 
-def preprocess_mask(mask: PIL.Image.Image, scale_factor: int = 8) -> torch.Tensor:
+def preprocess_mask(mask: Image.Image, scale_factor: int = 8) -> torch.Tensor:
     """
     Preprocess a mask for the model.
     """
+    # Convert to grayscale
     mask = mask.convert("L")
+
+    # Resize to integer multiple of 32
     w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize(
-        (w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST
-    )
-    mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask, (4, 1, 1))
-    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-    mask = 1 - mask  # repaint white, keep black
-    mask = torch.from_numpy(mask)
+    w, h = map(lambda x: x - x % 32, (w, h))
+    mask = mask.resize((w // scale_factor, h // scale_factor), resample=Image.NEAREST)
 
-    return mask
+    # Convert to numpy array and rescale
+    mask_np = np.array(mask).astype(np.float32) / 255.0
+
+    # Tile and transpose
+    mask_np = np.tile(mask_np, (4, 1, 1))
+    mask_np = mask_np[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose is a no-op)
 
-
-def slerp(t, v0, v1, dot_threshold=0.9995):
-    """
-    Helper function to spherically interpolate two arrays v1 v2.
- """ + # Invert to repaint white and keep black + mask_np = 1 - mask_np # repaint white, keep black - if not isinstance(v0, np.ndarray): - inputs_are_torch = True - input_device = v0.device - v0 = v0.cpu().numpy() - v1 = v1.cpu().numpy() - - dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) - if np.abs(dot) > dot_threshold: - v2 = (1 - t) * v0 + t * v1 - else: - theta_0 = np.arccos(dot) - sin_theta_0 = np.sin(theta_0) - theta_t = theta_0 * t - sin_theta_t = np.sin(theta_t) - s0 = np.sin(theta_0 - theta_t) / sin_theta_0 - s1 = sin_theta_t / sin_theta_0 - v2 = s0 * v0 + s1 * v1 - - if inputs_are_torch: - v2 = torch.from_numpy(v2).to(input_device) - - return v2 + return torch.from_numpy(mask_np)