Merge pull request #14484 from akx/swinir-resample-for-div8
Refactor Torch-space upscale fully out of ScuNET/SwinIR
This commit is contained in:
commit
51f1cca852
|
@ -1,13 +1,9 @@
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader, script_callbacks, errors
|
from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils
|
||||||
from modules.shared import opts
|
|
||||||
from modules.upscaler_utils import tiled_upscale_2
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerScuNET(modules.upscaler.Upscaler):
|
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
|
@ -40,46 +36,23 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
self.scalers = scalers
|
self.scalers = scalers
|
||||||
|
|
||||||
def do_upscale(self, img: PIL.Image.Image, selected_file):
|
def do_upscale(self, img: PIL.Image.Image, selected_file):
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = self.load_model(selected_file)
|
model = self.load_model(selected_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
|
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
device = devices.get_device_for('scunet')
|
img = upscaler_utils.upscale_2(
|
||||||
tile = opts.SCUNET_tile
|
img,
|
||||||
h, w = img.height, img.width
|
model,
|
||||||
np_img = np.array(img)
|
tile_size=shared.opts.SCUNET_tile,
|
||||||
np_img = np_img[:, :, ::-1] # RGB to BGR
|
tile_overlap=shared.opts.SCUNET_tile_overlap,
|
||||||
np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
|
scale=1, # ScuNET is a denoising model, not an upscaler
|
||||||
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
|
desc='ScuNET',
|
||||||
|
)
|
||||||
if tile > h or tile > w:
|
|
||||||
_img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
|
|
||||||
_img[:, :, :h, :w] = torch_img # pad image
|
|
||||||
torch_img = _img
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
torch_output = tiled_upscale_2(
|
|
||||||
torch_img,
|
|
||||||
model,
|
|
||||||
tile_size=opts.SCUNET_tile,
|
|
||||||
tile_overlap=opts.SCUNET_tile_overlap,
|
|
||||||
scale=1,
|
|
||||||
device=devices.get_device_for('scunet'),
|
|
||||||
desc="ScuNET tiles",
|
|
||||||
).squeeze(0)
|
|
||||||
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
|
|
||||||
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
|
|
||||||
del torch_img, torch_output
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
return img
|
||||||
output = np_output.transpose((1, 2, 0)) # CHW to HWC
|
|
||||||
output = output[:, :, ::-1] # BGR to RGB
|
|
||||||
return PIL.Image.fromarray((output * 255).astype(np.uint8))
|
|
||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
device = devices.get_device_for('scunet')
|
device = devices.get_device_for('scunet')
|
||||||
|
@ -93,7 +66,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
|
|
||||||
def on_ui_settings():
|
def on_ui_settings():
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
|
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
|
||||||
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
|
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
|
||||||
|
|
|
@ -1,14 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from modules import modelloader, devices, script_callbacks, shared
|
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
|
||||||
from modules.shared import opts
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.upscaler_utils import tiled_upscale_2
|
|
||||||
|
|
||||||
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
||||||
|
|
||||||
|
@ -36,9 +32,7 @@ class UpscalerSwinIR(Upscaler):
|
||||||
self.scalers = scalers
|
self.scalers = scalers
|
||||||
|
|
||||||
def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
|
def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
|
||||||
current_config = (model_file, opts.SWIN_tile)
|
current_config = (model_file, shared.opts.SWIN_tile)
|
||||||
|
|
||||||
device = self._get_device()
|
|
||||||
|
|
||||||
if self._cached_model_config == current_config:
|
if self._cached_model_config == current_config:
|
||||||
model = self._cached_model
|
model = self._cached_model
|
||||||
|
@ -51,12 +45,13 @@ class UpscalerSwinIR(Upscaler):
|
||||||
self._cached_model = model
|
self._cached_model = model
|
||||||
self._cached_model_config = current_config
|
self._cached_model_config = current_config
|
||||||
|
|
||||||
img = upscale(
|
img = upscaler_utils.upscale_2(
|
||||||
img,
|
img,
|
||||||
model,
|
model,
|
||||||
tile=opts.SWIN_tile,
|
tile_size=shared.opts.SWIN_tile,
|
||||||
tile_overlap=opts.SWIN_tile_overlap,
|
tile_overlap=shared.opts.SWIN_tile_overlap,
|
||||||
device=device,
|
scale=4, # TODO: This was hard-coded before too...
|
||||||
|
desc="SwinIR",
|
||||||
)
|
)
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
return img
|
return img
|
||||||
|
@ -77,7 +72,7 @@ class UpscalerSwinIR(Upscaler):
|
||||||
dtype=devices.dtype,
|
dtype=devices.dtype,
|
||||||
expected_architecture="SwinIR",
|
expected_architecture="SwinIR",
|
||||||
)
|
)
|
||||||
if getattr(opts, 'SWIN_torch_compile', False):
|
if getattr(shared.opts, 'SWIN_torch_compile', False):
|
||||||
try:
|
try:
|
||||||
model_descriptor.model.compile()
|
model_descriptor.model.compile()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -88,47 +83,6 @@ class UpscalerSwinIR(Upscaler):
|
||||||
return devices.get_device_for('swinir')
|
return devices.get_device_for('swinir')
|
||||||
|
|
||||||
|
|
||||||
def upscale(
|
|
||||||
img,
|
|
||||||
model,
|
|
||||||
*,
|
|
||||||
tile: int,
|
|
||||||
tile_overlap: int,
|
|
||||||
window_size=8,
|
|
||||||
scale=4,
|
|
||||||
device,
|
|
||||||
):
|
|
||||||
|
|
||||||
img = np.array(img)
|
|
||||||
img = img[:, :, ::-1]
|
|
||||||
img = np.moveaxis(img, 2, 0) / 255
|
|
||||||
img = torch.from_numpy(img).float()
|
|
||||||
img = img.unsqueeze(0).to(device, dtype=devices.dtype)
|
|
||||||
with torch.no_grad(), devices.autocast():
|
|
||||||
_, _, h_old, w_old = img.size()
|
|
||||||
h_pad = (h_old // window_size + 1) * window_size - h_old
|
|
||||||
w_pad = (w_old // window_size + 1) * window_size - w_old
|
|
||||||
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
|
||||||
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
|
||||||
output = tiled_upscale_2(
|
|
||||||
img,
|
|
||||||
model,
|
|
||||||
tile_size=tile,
|
|
||||||
tile_overlap=tile_overlap,
|
|
||||||
scale=scale,
|
|
||||||
device=device,
|
|
||||||
desc="SwinIR tiles",
|
|
||||||
)
|
|
||||||
output = output[..., : h_old * scale, : w_old * scale]
|
|
||||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
if output.ndim == 3:
|
|
||||||
output = np.transpose(
|
|
||||||
output[[2, 1, 0], :, :], (1, 2, 0)
|
|
||||||
) # CHW-RGB to HCW-BGR
|
|
||||||
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
|
||||||
return Image.fromarray(output, "RGB")
|
|
||||||
|
|
||||||
|
|
||||||
def on_ui_settings():
|
def on_ui_settings():
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
|
@ -11,23 +11,40 @@ from modules import images, shared, torch_utils
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def upscale_without_tiling(model, img: Image.Image):
|
def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
|
||||||
img = np.array(img)
|
img = np.array(img.convert("RGB"))
|
||||||
img = img[:, :, ::-1]
|
img = img[:, :, ::-1] # flip RGB to BGR
|
||||||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
img = np.transpose(img, (2, 0, 1)) # HWC to CHW
|
||||||
img = torch.from_numpy(img).float()
|
img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1]
|
||||||
|
return torch.from_numpy(img)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
|
||||||
|
if tensor.ndim == 4:
|
||||||
|
# If we're given a tensor with a batch dimension, squeeze it out
|
||||||
|
# (but only if it's a batch of size 1).
|
||||||
|
if tensor.shape[0] != 1:
|
||||||
|
raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
|
||||||
|
tensor = tensor.squeeze(0)
|
||||||
|
assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor"
|
||||||
|
# TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
|
||||||
|
arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp
|
||||||
|
arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale
|
||||||
|
arr = arr.astype(np.uint8)
|
||||||
|
arr = arr[:, :, ::-1] # flip BGR to RGB
|
||||||
|
return Image.fromarray(arr, "RGB")
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
|
||||||
|
"""
|
||||||
|
Upscale a given PIL image using the given model.
|
||||||
|
"""
|
||||||
param = torch_utils.get_param(model)
|
param = torch_utils.get_param(model)
|
||||||
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(img)
|
tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
|
||||||
|
tensor = tensor.to(device=param.device, dtype=param.dtype)
|
||||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
return torch_bgr_to_pil_image(model(tensor))
|
||||||
output = 255. * np.moveaxis(output, 0, 2)
|
|
||||||
output = output.astype(np.uint8)
|
|
||||||
output = output[:, :, ::-1]
|
|
||||||
return Image.fromarray(output, 'RGB')
|
|
||||||
|
|
||||||
|
|
||||||
def upscale_with_model(
|
def upscale_with_model(
|
||||||
|
@ -40,7 +57,7 @@ def upscale_with_model(
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
if tile_size <= 0:
|
if tile_size <= 0:
|
||||||
logger.debug("Upscaling %s without tiling", img)
|
logger.debug("Upscaling %s without tiling", img)
|
||||||
output = upscale_without_tiling(model, img)
|
output = upscale_pil_patch(model, img)
|
||||||
logger.debug("=> %s", output)
|
logger.debug("=> %s", output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -52,7 +69,7 @@ def upscale_with_model(
|
||||||
newrow = []
|
newrow = []
|
||||||
for x, w, tile in row:
|
for x, w, tile in row:
|
||||||
logger.debug("Tile (%d, %d) %s...", x, y, tile)
|
logger.debug("Tile (%d, %d) %s...", x, y, tile)
|
||||||
output = upscale_without_tiling(model, tile)
|
output = upscale_pil_patch(model, tile)
|
||||||
scale_factor = output.width // tile.width
|
scale_factor = output.width // tile.width
|
||||||
logger.debug("=> %s (scale factor %s)", output, scale_factor)
|
logger.debug("=> %s (scale factor %s)", output, scale_factor)
|
||||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||||
|
@ -71,19 +88,22 @@ def upscale_with_model(
|
||||||
|
|
||||||
|
|
||||||
def tiled_upscale_2(
|
def tiled_upscale_2(
|
||||||
img,
|
img: torch.Tensor,
|
||||||
model,
|
model,
|
||||||
*,
|
*,
|
||||||
tile_size: int,
|
tile_size: int,
|
||||||
tile_overlap: int,
|
tile_overlap: int,
|
||||||
scale: int,
|
scale: int,
|
||||||
device,
|
|
||||||
desc="Tiled upscale",
|
desc="Tiled upscale",
|
||||||
):
|
):
|
||||||
# Alternative implementation of `upscale_with_model` originally used by
|
# Alternative implementation of `upscale_with_model` originally used by
|
||||||
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
|
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
|
||||||
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
|
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
|
||||||
# Pillow space without weighting.
|
# Pillow space without weighting.
|
||||||
|
|
||||||
|
# Grab the device the model is on, and use it.
|
||||||
|
device = torch_utils.get_param(model).device
|
||||||
|
|
||||||
b, c, h, w = img.size()
|
b, c, h, w = img.size()
|
||||||
tile_size = min(tile_size, h, w)
|
tile_size = min(tile_size, h, w)
|
||||||
|
|
||||||
|
@ -100,7 +120,8 @@ def tiled_upscale_2(
|
||||||
h * scale,
|
h * scale,
|
||||||
w * scale,
|
w * scale,
|
||||||
device=device,
|
device=device,
|
||||||
).type_as(img)
|
dtype=img.dtype,
|
||||||
|
)
|
||||||
weights = torch.zeros_like(result)
|
weights = torch.zeros_like(result)
|
||||||
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
|
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
|
||||||
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
|
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
|
||||||
|
@ -112,11 +133,13 @@ def tiled_upscale_2(
|
||||||
if shared.state.interrupted or shared.state.skipped:
|
if shared.state.interrupted or shared.state.skipped:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Only move this patch to the device if it's not already there.
|
||||||
in_patch = img[
|
in_patch = img[
|
||||||
...,
|
...,
|
||||||
h_idx : h_idx + tile_size,
|
h_idx : h_idx + tile_size,
|
||||||
w_idx : w_idx + tile_size,
|
w_idx : w_idx + tile_size,
|
||||||
]
|
].to(device=device)
|
||||||
|
|
||||||
out_patch = model(in_patch)
|
out_patch = model(in_patch)
|
||||||
|
|
||||||
result[
|
result[
|
||||||
|
@ -138,3 +161,29 @@ def tiled_upscale_2(
|
||||||
output = result.div_(weights)
|
output = result.div_(weights)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_2(
|
||||||
|
img: Image.Image,
|
||||||
|
model,
|
||||||
|
*,
|
||||||
|
tile_size: int,
|
||||||
|
tile_overlap: int,
|
||||||
|
scale: int,
|
||||||
|
desc: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Convenience wrapper around `tiled_upscale_2` that handles PIL images.
|
||||||
|
"""
|
||||||
|
tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = tiled_upscale_2(
|
||||||
|
tensor,
|
||||||
|
model,
|
||||||
|
tile_size=tile_size,
|
||||||
|
tile_overlap=tile_overlap,
|
||||||
|
scale=scale,
|
||||||
|
desc=desc,
|
||||||
|
)
|
||||||
|
return torch_bgr_to_pil_image(output)
|
||||||
|
|
Loading…
Reference in New Issue