commit
3f7f61e541
|
@ -1,6 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
|
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
|
||||||
|
@ -50,7 +51,7 @@ class UpscalerSwinIR(Upscaler):
|
||||||
model,
|
model,
|
||||||
tile_size=shared.opts.SWIN_tile,
|
tile_size=shared.opts.SWIN_tile,
|
||||||
tile_overlap=shared.opts.SWIN_tile_overlap,
|
tile_overlap=shared.opts.SWIN_tile_overlap,
|
||||||
scale=4, # TODO: This was hard-coded before too...
|
scale=model.scale,
|
||||||
desc="SwinIR",
|
desc="SwinIR",
|
||||||
)
|
)
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
@ -69,7 +70,7 @@ class UpscalerSwinIR(Upscaler):
|
||||||
model_descriptor = modelloader.load_spandrel_model(
|
model_descriptor = modelloader.load_spandrel_model(
|
||||||
filename,
|
filename,
|
||||||
device=self._get_device(),
|
device=self._get_device(),
|
||||||
dtype=devices.dtype,
|
prefer_half=(devices.dtype == torch.float16),
|
||||||
expected_architecture="SwinIR",
|
expected_architecture="SwinIR",
|
||||||
)
|
)
|
||||||
if getattr(shared.opts, 'SWIN_torch_compile', False):
|
if getattr(shared.opts, 'SWIN_torch_compile', False):
|
||||||
|
|
|
@ -94,6 +94,7 @@ def tiled_upscale_2(
|
||||||
tile_size: int,
|
tile_size: int,
|
||||||
tile_overlap: int,
|
tile_overlap: int,
|
||||||
scale: int,
|
scale: int,
|
||||||
|
device: torch.device,
|
||||||
desc="Tiled upscale",
|
desc="Tiled upscale",
|
||||||
):
|
):
|
||||||
# Alternative implementation of `upscale_with_model` originally used by
|
# Alternative implementation of `upscale_with_model` originally used by
|
||||||
|
@ -101,9 +102,6 @@ def tiled_upscale_2(
|
||||||
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
|
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
|
||||||
# Pillow space without weighting.
|
# Pillow space without weighting.
|
||||||
|
|
||||||
# Grab the device the model is on, and use it.
|
|
||||||
device = torch_utils.get_param(model).device
|
|
||||||
|
|
||||||
b, c, h, w = img.size()
|
b, c, h, w = img.size()
|
||||||
tile_size = min(tile_size, h, w)
|
tile_size = min(tile_size, h, w)
|
||||||
|
|
||||||
|
@ -175,7 +173,8 @@ def upscale_2(
|
||||||
"""
|
"""
|
||||||
Convenience wrapper around `tiled_upscale_2` that handles PIL images.
|
Convenience wrapper around `tiled_upscale_2` that handles PIL images.
|
||||||
"""
|
"""
|
||||||
tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension
|
param = torch_utils.get_param(model)
|
||||||
|
tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = tiled_upscale_2(
|
output = tiled_upscale_2(
|
||||||
|
@ -185,5 +184,6 @@ def upscale_2(
|
||||||
tile_overlap=tile_overlap,
|
tile_overlap=tile_overlap,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
desc=desc,
|
desc=desc,
|
||||||
|
device=param.device,
|
||||||
)
|
)
|
||||||
return torch_bgr_to_pil_image(output)
|
return torch_bgr_to_pil_image(output)
|
||||||
|
|
Loading…
Reference in New Issue