New Pipeline: Tiled-upscaling with depth perception to avoid blurry spots (#1615)
* added first version of the tiled upscaling pipeline
* reformatted to pass code quality tests
This commit is contained in:
parent 75d53cc839
commit 67e2f95cc4
@@ -0,0 +1,298 @@
# Copyright 2022 Peter Willemsen <peter@codebuffet.co>. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Callable, List, Optional, Union

import numpy as np
import torch

import PIL
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer


def make_transparency_mask(size, overlap_pixels, remove_borders=[]):
    # Build an alpha mask that is fully opaque in the tile center and ramps
    # linearly down to transparent over `overlap_pixels` at each border, so that
    # adjacent tiles blend smoothly. Borders listed in `remove_borders` (any of
    # "l", "r", "t", "b") stay fully opaque, for tiles at the image edges.
    size_x = size[0] - overlap_pixels * 2
    size_y = size[1] - overlap_pixels * 2
    for letter in ["l", "r"]:
        if letter in remove_borders:
            size_x += overlap_pixels
    for letter in ["t", "b"]:
        if letter in remove_borders:
            size_y += overlap_pixels
    mask = np.ones((size_y, size_x), dtype=np.uint8) * 255
    mask = np.pad(mask, mode="linear_ramp", pad_width=overlap_pixels, end_values=0)

    if "l" in remove_borders:
        mask = mask[:, overlap_pixels : mask.shape[1]]
    if "r" in remove_borders:
        mask = mask[:, 0 : mask.shape[1] - overlap_pixels]
    if "t" in remove_borders:
        mask = mask[overlap_pixels : mask.shape[0], :]
    if "b" in remove_borders:
        mask = mask[0 : mask.shape[0] - overlap_pixels, :]
    return mask


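# Illustrative behavior of make_transparency_mask (a sketch, not executed here):
# make_transparency_mask((64, 64), 8) returns a 64x64 uint8 array whose 48x48
# center is 255 (opaque) and whose 8-pixel border ramps linearly down to 0
# (transparent) at the outermost pixels.
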
def clamp(n, smallest, largest):
    return max(smallest, min(n, largest))


def clamp_rect(rect: List[int], min_rect: List[int], max_rect: List[int]):
    # Clamp each corner of (x0, y0, x1, y1) into the box spanned by min_rect/max_rect.
    return (
        clamp(rect[0], min_rect[0], max_rect[0]),
        clamp(rect[1], min_rect[1], max_rect[1]),
        clamp(rect[2], min_rect[0], max_rect[0]),
        clamp(rect[3], min_rect[1], max_rect[1]),
    )


def add_overlap_rect(rect: List[int], overlap: int, image_size: List[int]):
    # Grow the rect by `overlap` pixels on every side, clamped to the image bounds.
    rect = list(rect)
    rect[0] -= overlap
    rect[1] -= overlap
    rect[2] += overlap
    rect[3] += overlap
    rect = clamp_rect(rect, [0, 0], [image_size[0], image_size[1]])
    return rect


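# E.g. (illustrative): add_overlap_rect((128, 0, 256, 128), 32, (512, 512))
# returns (96, 0, 288, 160) -- expanded by the overlap on every side, then
# clamped so it stays inside the 512x512 image.
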
def squeeze_tile(tile, original_image, original_slice, slice_x):
    # Widen the tile by `original_slice` pixels and fill the new left strip with
    # the corresponding slice of the whole image downscaled to tile size. This
    # gives the upscaler a view of the full image's depth and structure, which
    # avoids blurry spots in tiles that lack global context.
    result = Image.new("RGB", (tile.size[0] + original_slice, tile.size[1]))
    result.paste(
        original_image.resize((tile.size[0], tile.size[1]), Image.BICUBIC).crop(
            (slice_x, 0, slice_x + original_slice, tile.size[1])
        ),
        (0, 0),
    )
    result.paste(tile, (original_slice, 0))
    return result


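# Illustrative sizes (matching the worked example in _process_tile below):
# squeezing a 192x160 tile with original_slice=32 yields a 224x160 image whose
# left 32 columns are a strip of the downscaled full image; unsqueeze_tile then
# crops those 32 * 4 = 128 columns back off the 4x-upscaled result.
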
def unsqueeze_tile(tile, original_image_slice):
    # Crop the (now 4x-upscaled) context strip back off the left of the tile.
    crop_rect = (original_image_slice * 4, 0, tile.size[0], tile.size[1])
    tile = tile.crop(crop_rect)
    return tile


def next_divisible(n, d):
    # Round n down to the nearest multiple of d.
    remainder = n % d
    return n - remainder


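# E.g. next_divisible(130, 8) == 128 (130 rounded down to a multiple of 8).
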
class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
    r"""
    Pipeline for tile-based text-guided image super-resolution using Stable Diffusion 2, trading memory for compute
    to create gigantic images.

    This model inherits from [`StableDiffusionUpscalePipeline`]. Check the superclass documentation for the generic
    methods the library implements for all the pipelines (such as downloading or saving, running on a particular
    device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        low_res_scheduler ([`SchedulerMixin`]):
            A scheduler used to add initial noise to the low res conditioning image. It must be an instance of
            [`DDPMScheduler`].
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """

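    # A minimal usage sketch (mirrors the demo in main() below; file names are
    # illustrative):
    #
    #     pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(
    #         "stabilityai/stable-diffusion-x4-upscaler", revision="fp16", torch_dtype=torch.float16
    #     ).to("cuda")
    #     upscaled = pipe(image=Image.open("input.png"), prompt="high quality photo", tile_size=128, tile_border=32)
    #     upscaled.save("output.png")
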
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        max_noise_level: int = 350,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            max_noise_level=max_noise_level,
        )

    def _process_tile(self, original_image_slice, x, y, tile_size, tile_border, image, final_image, **kwargs):
        torch.manual_seed(0)  # re-seed per tile so all tiles are denoised consistently
        crop_rect = (
            min(image.size[0] - (tile_size + original_image_slice), x * tile_size),
            min(image.size[1] - (tile_size + original_image_slice), y * tile_size),
            min(image.size[0], (x + 1) * tile_size),
            min(image.size[1], (y + 1) * tile_size),
        )
        crop_rect_with_overlap = add_overlap_rect(crop_rect, tile_border, image.size)
        tile = image.crop(crop_rect_with_overlap)
        # Map the tile's horizontal position into the downscaled full image to
        # pick which slice of global context to squeeze in next to the tile.
        translated_slice_x = ((crop_rect[0] + ((crop_rect[2] - crop_rect[0]) / 2)) / image.size[0]) * tile.size[0]
        translated_slice_x = translated_slice_x - (original_image_slice / 2)
        translated_slice_x = max(0, translated_slice_x)
        to_input = squeeze_tile(tile, image, original_image_slice, translated_slice_x)
        orig_input_size = to_input.size
        to_input = to_input.resize((tile_size, tile_size), Image.BICUBIC)
        upscaled_tile = super(StableDiffusionTiledUpscalePipeline, self).__call__(image=to_input, **kwargs).images[0]
        upscaled_tile = upscaled_tile.resize((orig_input_size[0] * 4, orig_input_size[1] * 4), Image.BICUBIC)
        upscaled_tile = unsqueeze_tile(upscaled_tile, original_image_slice)
        upscaled_tile = upscaled_tile.resize((tile.size[0] * 4, tile.size[1] * 4), Image.BICUBIC)
        remove_borders = []
        if x == 0:
            remove_borders.append("l")
        elif crop_rect[2] == image.size[0]:
            remove_borders.append("r")
        if y == 0:
            remove_borders.append("t")
        elif crop_rect[3] == image.size[1]:
            remove_borders.append("b")
        transparency_mask = Image.fromarray(
            make_transparency_mask(
                (upscaled_tile.size[0], upscaled_tile.size[1]), tile_border * 4, remove_borders=remove_borders
            ),
            mode="L",
        )
        final_image.paste(
            upscaled_tile, (crop_rect_with_overlap[0] * 4, crop_rect_with_overlap[1] * 4), transparency_mask
        )

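    # Worked example (illustrative): for a 512x512 input with tile_size=128,
    # tile_border=32, and original_image_slice=32, the tile at (x=1, y=0) covers
    # crop_rect=(128, 0, 256, 128); with the border it becomes (96, 0, 288, 160),
    # and the 4x-upscaled result is pasted at (384, 0) in the 2048x2048 canvas.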
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[PIL.Image.Image, List[PIL.Image.Image]],
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        noise_level: int = 50,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        callback: Optional[Callable[[dict], None]] = None,
        callback_steps: Optional[int] = 1,
        tile_size: int = 128,
        tile_border: int = 32,
        original_image_slice: int = 32,
    ):
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
|
||||
`Image`, or tensor representing an image batch which will be upscaled. *
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
tile_size (`int`, *optional*):
|
||||
The size of the tiles. Too big can result in an OOM-error.
|
||||
tile_border (`int`, *optional*):
|
||||
The number of pixels around a tile to consider (bigger means less seams, too big can lead to an OOM-error).
|
||||
original_image_slice (`int`, *optional*):
|
||||
The amount of pixels of the original image to calculate with the current tile (bigger means more depth
|
||||
is preserved, less blur occurs in the final image, too big can lead to an OOM-error or loss in detail).
|
||||
callback (`Callable`, *optional*):
|
||||
A function that take a callback function with a single argument, a dict,
|
||||
that contains the (partially) processed image under "image",
|
||||
as well as the progress (0 to 1, where 1 is completed) under "progress".
|
||||
|
||||
Returns: A PIL.Image that is 4 times larger than the original input image.
|
||||
|
||||
"""
|
||||
|
||||
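        # For example, a 512x512 input with tile_size=128 is processed as a
        # 4x4 grid of 16 tiles and produces a 2048x2048 output.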
        final_image = Image.new("RGB", (image.size[0] * 4, image.size[1] * 4))
        tcx = math.ceil(image.size[0] / tile_size)
        tcy = math.ceil(image.size[1] / tile_size)
        total_tile_count = tcx * tcy
        current_count = 0
        for y in range(tcy):
            for x in range(tcx):
                self._process_tile(
                    original_image_slice,
                    x,
                    y,
                    tile_size,
                    tile_border,
                    image,
                    final_image,
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    noise_level=noise_level,
                    negative_prompt=negative_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    eta=eta,
                    generator=generator,
                    latents=latents,
                )
                current_count += 1
                if callback is not None:
                    callback({"progress": current_count / total_tile_count, "image": final_image})
        return final_image


def main():
    # Run a demo
    model_id = "stabilityai/stable-diffusion-x4-upscaler"
    pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    image = Image.open("../../docs/source/imgs/diffusers_library.jpg")

    def callback(obj):
        print(f"progress: {obj['progress']:.4f}")
        obj["image"].save("diffusers_library_progress.jpg")

    final_image = pipe(image=image, prompt="Black font, white background, vector", noise_level=40, callback=callback)
    final_image.save("diffusers_library.jpg")


if __name__ == "__main__":
    main()