2023-12-27 02:04:33 -07:00
|
|
|
import logging
|
|
|
|
from typing import Callable
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import tqdm
|
|
|
|
from PIL import Image
|
|
|
|
|
2023-12-31 12:38:30 -07:00
|
|
|
from modules import images, shared, torch_utils
|
2023-12-27 02:04:33 -07:00
|
|
|
|
|
|
|
# Module-level logger; the upscale helpers below report whole-image and
# per-tile progress through it at DEBUG level.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
def upscale_without_tiling(model, img: Image.Image) -> Image.Image:
    """Run ``model`` once over the whole of ``img`` and return a PIL image.

    Input path:
    - the image is converted to RGB and flipped to BGR channel order;
    - it is laid out channels-first and scaled to [0, 1];
    - a batch axis is added, giving a 1xCxHxW tensor;
    - the tensor is cast/moved to the dtype and device of the parameter
      returned by ``torch_utils.get_param(model)``.

    Output path:
    - the model output is computed under ``torch.no_grad()``;
    - it is clamped to [0, 1] and rescaled to [0, 255] with rounding;
    - it is flipped back from BGR to RGB and returned as an RGB image.

    Args:
        model: Upscaling network. It is called with the prepared tensor and
            presumably returns a (1, 3, H', W') BGR tensor in roughly [0, 1]
            (TODO confirm for every model family that uses this helper).
        img: Source image. Non-RGB modes (e.g. "L", "RGBA", "P") are
            converted to RGB first; any alpha channel is discarded.

    Returns:
        The model output as an ``"RGB"`` PIL image.
    """
    # Normalise the mode first: grayscale arrays have no channel axis and RGBA
    # arrays have four channels, neither of which the BGR flip below expects.
    arr = np.array(img.convert("RGB"))
    arr = arr[:, :, ::-1]  # RGB -> BGR
    arr = np.ascontiguousarray(np.transpose(arr, (2, 0, 1))) / 255  # HWC -> CHW, [0, 1]
    tensor = torch.from_numpy(arr).float()

    # Match the model's device and dtype (e.g. half precision on GPU).
    param = torch_utils.get_param(model)
    tensor = tensor.unsqueeze(0).to(device=param.device, dtype=param.dtype)

    with torch.no_grad():
        output = model(tensor)

    # Remove only the batch axis: a bare squeeze() would also collapse a
    # 1-pixel-high or 1-pixel-wide result and break the moveaxis below.
    output = output.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    output = 255.0 * np.moveaxis(output, 0, 2)  # CHW -> HWC, [0, 255]
    # Round rather than truncate so that e.g. 254.9 becomes 255, not 254.
    output = output.round().astype(np.uint8)
    output = output[:, :, ::-1]  # BGR -> RGB
    return Image.fromarray(output, "RGB")
|
|
|
|
|
|
|
|
|
|
|
|
def upscale_with_model(
    model: Callable[[torch.Tensor], torch.Tensor],
    img: Image.Image,
    *,
    tile_size: int,
    tile_overlap: int = 0,
    desc="tiled upscale",
) -> Image.Image:
    """Upscale ``img`` with ``model``, splitting it into tiles when requested.

    With ``tile_size <= 0`` the whole image is passed through
    ``upscale_without_tiling`` in a single call. Otherwise:
    - the image is cut into a grid by ``images.split_grid``;
    - every tile is upscaled independently;
    - the pieces are stitched back together by ``images.combine_grid``.

    Overlaps are resolved by the ``images.Grid`` machinery in Pillow space,
    not by the weighted averaging used in ``tiled_upscale_2``.

    Args:
        model: Upscaling model, invoked through ``upscale_without_tiling``.
        img: Image to upscale.
        tile_size: Edge length in pixels of the square tiles; ``<= 0``
            disables tiling.
        tile_overlap: Overlap in pixels between neighbouring tiles, forwarded
            to ``images.split_grid``.
        desc: Label for the tqdm progress bar.

    Returns:
        The upscaled image.

    Raises:
        ValueError: If ``images.split_grid`` yields no tiles, so that no
            scale factor can be determined.
    """
    if tile_size <= 0:
        logger.debug("Upscaling %s without tiling", img)
        output = upscale_without_tiling(model, img)
        logger.debug("=> %s", output)
        return output

    grid = images.split_grid(img, tile_size, tile_size, tile_overlap)

    # Rows as (y, h, tiles). Tile x/w are scaled as tiles are produced, but row
    # y/h are kept in source pixels until the scale factor is known for sure.
    upscaled_rows = []
    # Stays None when the grid contains no tiles at all; checked below so that
    # case fails with a clear error instead of an UnboundLocalError.
    scale_factor = None

    with tqdm.tqdm(total=grid.tile_count, desc=desc) as p:
        for y, h, row in grid.tiles:
            newrow = []
            for x, w, tile in row:
                logger.debug("Tile (%d, %d) %s...", x, y, tile)
                output = upscale_without_tiling(model, tile)
                # Integer upscale factor of the model, inferred from its output.
                scale_factor = output.width // tile.width
                logger.debug("=> %s (scale factor %s)", output, scale_factor)
                newrow.append([x * scale_factor, w * scale_factor, output])
                p.update(1)
            upscaled_rows.append((y, h, newrow))

    if scale_factor is None:
        raise ValueError(
            f"images.split_grid produced no tiles for {img} "
            f"(tile_size={tile_size}, tile_overlap={tile_overlap})"
        )

    # All tiles are assumed to share a single scale factor; the grid geometry
    # is scaled by the factor of the last tile processed.
    newtiles = [
        [y * scale_factor, h * scale_factor, newrow]
        for y, h, newrow in upscaled_rows
    ]

    newgrid = images.Grid(
        newtiles,
        tile_w=grid.tile_w * scale_factor,
        tile_h=grid.tile_h * scale_factor,
        image_w=grid.image_w * scale_factor,
        image_h=grid.image_h * scale_factor,
        overlap=grid.overlap * scale_factor,
    )
    return images.combine_grid(newgrid)
|
2023-12-30 13:53:49 -07:00
|
|
|
|
|
|
|
|
|
|
|
def tiled_upscale_2(
    img,
    model,
    *,
    tile_size: int,
    tile_overlap: int,
    scale: int,
    device,
    desc="Tiled upscale",
):
    """Upscale a batched image tensor with ``model`` using overlapping tiles.

    Alternative implementation of ``upscale_with_model`` originally used by
    SwinIR and ScuNET. It differs from ``upscale_with_model`` in that tiling
    and weighting are done in PyTorch space, as opposed to ``images.Grid``
    doing it in Pillow space without weighting. Tile outputs are summed into
    an accumulator. Each output pixel is divided by the number of tiles that
    covered it, so overlapping regions are averaged.

    Args:
        img: 4-D tensor, unpacked below as (b, c, h, w).
        model: Callable applied to each (b, c, t, t) patch, where t is the
            effective tile size. The slice arithmetic below assumes it
            returns (b, c, t * scale, t * scale).
        tile_size: Requested tile edge length. Clamped to ``min(h, w)``;
            ``<= 0`` runs ``model`` on the whole tensor instead.
        tile_overlap: Requested overlap between neighbouring tiles.
            - Negative values are treated as 0.
            - Values ``>= `` the effective tile size are reduced to half a
              tile, so the stride always stays positive.
        scale: Integer upscale factor applied by ``model``.
        device: Device for the output accumulator; each input patch is also
            moved there. ``None`` means "the device ``img`` is on".
        desc: Label for the tqdm progress bar.

    Returns:
        Tensor of shape (b, c, h * scale, w * scale) with ``img``'s dtype.
        If processing is interrupted or skipped (``shared.state``), tiles
        not yet processed leave zeros in the output.
    """
    b, c, h, w = img.size()
    tile_size = min(tile_size, h, w)

    if tile_size <= 0:
        logger.debug("Upscaling %s without tiling", img.shape)
        return model(img)

    # Valid overlaps satisfy 0 <= overlap < tile_size. Otherwise the stride is
    # zero (range() raises), negative (tiles are skipped and their pixels end
    # up 0/0 = NaN), or larger than a tile (gaps between tiles). tile_size may
    # just have been clamped to a small image side, so check here.
    if tile_overlap < 0:
        tile_overlap = 0
    elif tile_overlap >= tile_size:
        tile_overlap = tile_size // 2
    stride = tile_size - tile_overlap

    # Regular steps plus a final tile flush with the bottom/right edge, so
    # every pixel is covered at least once.
    h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
    w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]

    if device is None:
        device = img.device
    # Allocate directly on the requested device. The former `.type_as(img)`
    # silently moved the accumulator back to img's device type, so `device`
    # had no effect.
    result = torch.zeros(
        b,
        c,
        h * scale,
        w * scale,
        device=device,
        dtype=img.dtype,
    )
    # Number of tiles that contributed to each output pixel.
    weights = torch.zeros_like(result)
    logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)

    with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
        for h_idx in h_idx_list:
            if shared.state.interrupted or shared.state.skipped:
                break

            for w_idx in w_idx_list:
                if shared.state.interrupted or shared.state.skipped:
                    break

                in_patch = img[
                    ...,
                    h_idx : h_idx + tile_size,
                    w_idx : w_idx + tile_size,
                ].to(device=device)  # no-op when already on `device`
                out_patch = model(in_patch)

                out_rows = slice(h_idx * scale, (h_idx + tile_size) * scale)
                out_cols = slice(w_idx * scale, (w_idx + tile_size) * scale)
                result[..., out_rows, out_cols].add_(out_patch)
                weights[..., out_rows, out_cols].add_(1)

                pbar.update(1)

    # Pixels never reached (only possible after an interrupt/skip) have weight
    # 0; clamping makes them 0 instead of 0/0 = NaN. Covered pixels are unaffected.
    return result.div_(weights.clamp_(min=1))
|