Deduplicate tiled inference code from SwinIR/ScuNET

This commit is contained in:
Aarni Koskela 2023-12-30 22:53:49 +02:00
parent ce21840a04
commit 6f86b62a1b
3 changed files with 87 additions and 97 deletions

View File

@ -3,12 +3,11 @@ import sys
import PIL.Image
import numpy as np
import torch
from tqdm import tqdm
import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
from modules.shared import opts
from modules.upscaler_utils import tiled_upscale_2
class UpscalerScuNET(modules.upscaler.Upscaler):
@ -40,47 +39,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers.append(scaler_data2)
self.scalers = scalers
@staticmethod
@torch.no_grad()
def tiled_inference(img, model):
# test the image tile by tile
h, w = img.shape[2:]
tile = opts.SCUNET_tile
tile_overlap = opts.SCUNET_tile_overlap
if tile == 0:
return model(img)
device = devices.get_device_for('scunet')
assert tile % 8 == 0, "tile size should be a multiple of window_size"
sf = 1
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)
return output
def do_upscale(self, img: PIL.Image.Image, selected_file):
devices.torch_gc()
@ -104,7 +62,16 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
_img[:, :, :h, :w] = torch_img # pad image
torch_img = _img
torch_output = self.tiled_inference(torch_img, model).squeeze(0)
with torch.no_grad():
torch_output = tiled_upscale_2(
torch_img,
model,
tile_size=opts.SCUNET_tile,
tile_overlap=opts.SCUNET_tile_overlap,
scale=1,
device=devices.get_device_for('scunet'),
desc="ScuNET tiles",
).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output

View File

@ -4,11 +4,11 @@ import sys
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import tiled_upscale_2
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
@ -110,14 +110,14 @@ def upscale(
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
output = inference(
output = tiled_upscale_2(
img,
model,
tile=tile,
tile_size=tile,
tile_overlap=tile_overlap,
window_size=window_size,
scale=scale,
device=device,
desc="SwinIR tiles",
)
output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
@ -129,53 +129,6 @@ def upscale(
return Image.fromarray(output, "RGB")
def inference(
img,
model,
*,
tile: int,
tile_overlap: int,
window_size: int,
scale: int,
device,
):
# test the image tile by tile
b, c, h, w = img.size()
tile = min(tile, h, w)
assert tile % window_size == 0, "tile size should be a multiple of window_size"
sf = scale
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
if state.interrupted or state.skipped:
break
for w_idx in w_idx_list:
if state.interrupted or state.skipped:
break
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)
return output
def on_ui_settings():
import gradio as gr

View File

@ -6,7 +6,7 @@ import torch
import tqdm
from PIL import Image
from modules import images
from modules import images, shared
logger = logging.getLogger(__name__)
@ -68,3 +68,73 @@ def upscale_with_model(
overlap=grid.overlap * scale_factor,
)
return images.combine_grid(newgrid)
def tiled_upscale_2(
img,
model,
*,
tile_size: int,
tile_overlap: int,
scale: int,
device,
desc="Tiled upscale",
):
# Alternative implementation of `upscale_with_model` originally used by
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
# Pillow space without weighting.
b, c, h, w = img.size()
tile_size = min(tile_size, h, w)
if tile_size <= 0:
logger.debug("Upscaling %s without tiling", img.shape)
return model(img)
stride = tile_size - tile_overlap
h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
result = torch.zeros(
b,
c,
h * scale,
w * scale,
device=device,
).type_as(img)
weights = torch.zeros_like(result)
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
for h_idx in h_idx_list:
if shared.state.interrupted or shared.state.skipped:
break
for w_idx in w_idx_list:
if shared.state.interrupted or shared.state.skipped:
break
in_patch = img[
...,
h_idx : h_idx + tile_size,
w_idx : w_idx + tile_size,
]
out_patch = model(in_patch)
result[
...,
h_idx * scale : (h_idx + tile_size) * scale,
w_idx * scale : (w_idx + tile_size) * scale,
].add_(out_patch)
out_patch_mask = torch.ones_like(out_patch)
weights[
...,
h_idx * scale : (h_idx + tile_size) * scale,
w_idx * scale : (w_idx + tile_size) * scale,
].add_(out_patch_mask)
pbar.update(1)
output = result.div_(weights)
return output