Merge pull request #14524 from akx/fix-swinir-issues

Fix SwinIR issues
This commit is contained in:
AUTOMATIC1111 2024-01-04 11:17:20 +03:00 committed by GitHub
commit 3f7f61e541
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 6 deletions

View File

@ -1,6 +1,7 @@
import logging import logging
import sys import sys
import torch
from PIL import Image from PIL import Image
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
@ -50,7 +51,7 @@ class UpscalerSwinIR(Upscaler):
model, model,
tile_size=shared.opts.SWIN_tile, tile_size=shared.opts.SWIN_tile,
tile_overlap=shared.opts.SWIN_tile_overlap, tile_overlap=shared.opts.SWIN_tile_overlap,
scale=4, # TODO: This was hard-coded before too... scale=model.scale,
desc="SwinIR", desc="SwinIR",
) )
devices.torch_gc() devices.torch_gc()
@ -69,7 +70,7 @@ class UpscalerSwinIR(Upscaler):
model_descriptor = modelloader.load_spandrel_model( model_descriptor = modelloader.load_spandrel_model(
filename, filename,
device=self._get_device(), device=self._get_device(),
dtype=devices.dtype, prefer_half=(devices.dtype == torch.float16),
expected_architecture="SwinIR", expected_architecture="SwinIR",
) )
if getattr(shared.opts, 'SWIN_torch_compile', False): if getattr(shared.opts, 'SWIN_torch_compile', False):

View File

@ -94,6 +94,7 @@ def tiled_upscale_2(
tile_size: int, tile_size: int,
tile_overlap: int, tile_overlap: int,
scale: int, scale: int,
device: torch.device,
desc="Tiled upscale", desc="Tiled upscale",
): ):
# Alternative implementation of `upscale_with_model` originally used by # Alternative implementation of `upscale_with_model` originally used by
@ -101,9 +102,6 @@ def tiled_upscale_2(
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
# Pillow space without weighting. # Pillow space without weighting.
# Grab the device the model is on, and use it.
device = torch_utils.get_param(model).device
b, c, h, w = img.size() b, c, h, w = img.size()
tile_size = min(tile_size, h, w) tile_size = min(tile_size, h, w)
@ -175,7 +173,8 @@ def upscale_2(
""" """
Convenience wrapper around `tiled_upscale_2` that handles PIL images. Convenience wrapper around `tiled_upscale_2` that handles PIL images.
""" """
tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension param = torch_utils.get_param(model)
tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension
with torch.no_grad(): with torch.no_grad():
output = tiled_upscale_2( output = tiled_upscale_2(
@ -185,5 +184,6 @@ def upscale_2(
tile_overlap=tile_overlap, tile_overlap=tile_overlap,
scale=scale, scale=scale,
desc=desc, desc=desc,
device=param.device,
) )
return torch_bgr_to_pil_image(output) return torch_bgr_to_pil_image(output)