Deduplicate tiled inference code from SwinIR/ScuNET

This commit is contained in:
Aarni Koskela 2023-12-30 22:53:49 +02:00
parent ce21840a04
commit 6f86b62a1b
3 changed files with 87 additions and 97 deletions

View File

@ -3,12 +3,11 @@ import sys
import PIL.Image import PIL.Image
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm
import modules.upscaler import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors from modules import devices, modelloader, script_callbacks, errors
from modules.shared import opts from modules.shared import opts
from modules.upscaler_utils import tiled_upscale_2
class UpscalerScuNET(modules.upscaler.Upscaler): class UpscalerScuNET(modules.upscaler.Upscaler):
@ -40,47 +39,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers.append(scaler_data2) scalers.append(scaler_data2)
self.scalers = scalers self.scalers = scalers
@staticmethod
@torch.no_grad()
def tiled_inference(img, model):
# test the image tile by tile
h, w = img.shape[2:]
tile = opts.SCUNET_tile
tile_overlap = opts.SCUNET_tile_overlap
if tile == 0:
return model(img)
device = devices.get_device_for('scunet')
assert tile % 8 == 0, "tile size should be a multiple of window_size"
sf = 1
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)
return output
def do_upscale(self, img: PIL.Image.Image, selected_file): def do_upscale(self, img: PIL.Image.Image, selected_file):
devices.torch_gc() devices.torch_gc()
@ -104,7 +62,16 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
_img[:, :, :h, :w] = torch_img # pad image _img[:, :, :h, :w] = torch_img # pad image
torch_img = _img torch_img = _img
torch_output = self.tiled_inference(torch_img, model).squeeze(0) with torch.no_grad():
torch_output = tiled_upscale_2(
torch_img,
model,
tile_size=opts.SCUNET_tile,
tile_overlap=opts.SCUNET_tile_overlap,
scale=1,
device=devices.get_device_for('scunet'),
desc="ScuNET tiles",
).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output del torch_img, torch_output

View File

@ -4,11 +4,11 @@ import sys
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import tiled_upscale_2
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
@ -110,14 +110,14 @@ def upscale(
w_pad = (w_old // window_size + 1) * window_size - w_old w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
output = inference( output = tiled_upscale_2(
img, img,
model, model,
tile=tile, tile_size=tile,
tile_overlap=tile_overlap, tile_overlap=tile_overlap,
window_size=window_size,
scale=scale, scale=scale,
device=device, device=device,
desc="SwinIR tiles",
) )
output = output[..., : h_old * scale, : w_old * scale] output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
@ -129,53 +129,6 @@ def upscale(
return Image.fromarray(output, "RGB") return Image.fromarray(output, "RGB")
def inference(
img,
model,
*,
tile: int,
tile_overlap: int,
window_size: int,
scale: int,
device,
):
# test the image tile by tile
b, c, h, w = img.size()
tile = min(tile, h, w)
assert tile % window_size == 0, "tile size should be a multiple of window_size"
sf = scale
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
if state.interrupted or state.skipped:
break
for w_idx in w_idx_list:
if state.interrupted or state.skipped:
break
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)
return output
def on_ui_settings(): def on_ui_settings():
import gradio as gr import gradio as gr

View File

@ -6,7 +6,7 @@ import torch
import tqdm import tqdm
from PIL import Image from PIL import Image
from modules import images from modules import images, shared
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,3 +68,73 @@ def upscale_with_model(
overlap=grid.overlap * scale_factor, overlap=grid.overlap * scale_factor,
) )
return images.combine_grid(newgrid) return images.combine_grid(newgrid)
def tiled_upscale_2(
img,
model,
*,
tile_size: int,
tile_overlap: int,
scale: int,
device,
desc="Tiled upscale",
):
# Alternative implementation of `upscale_with_model` originally used by
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
# Pillow space without weighting.
b, c, h, w = img.size()
tile_size = min(tile_size, h, w)
if tile_size <= 0:
logger.debug("Upscaling %s without tiling", img.shape)
return model(img)
stride = tile_size - tile_overlap
h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
result = torch.zeros(
b,
c,
h * scale,
w * scale,
device=device,
).type_as(img)
weights = torch.zeros_like(result)
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
for h_idx in h_idx_list:
if shared.state.interrupted or shared.state.skipped:
break
for w_idx in w_idx_list:
if shared.state.interrupted or shared.state.skipped:
break
in_patch = img[
...,
h_idx : h_idx + tile_size,
w_idx : w_idx + tile_size,
]
out_patch = model(in_patch)
result[
...,
h_idx * scale : (h_idx + tile_size) * scale,
w_idx * scale : (w_idx + tile_size) * scale,
].add_(out_patch)
out_patch_mask = torch.ones_like(out_patch)
weights[
...,
h_idx * scale : (h_idx + tile_size) * scale,
w_idx * scale : (w_idx + tile_size) * scale,
].add_(out_patch_mask)
pbar.update(1)
output = result.div_(weights)
return output