upscale_2: cast image to model's dtype
This commit is contained in:
parent
3d31d5c27b
commit
62470ee234
|
@ -94,6 +94,7 @@ def tiled_upscale_2(
|
|||
tile_size: int,
|
||||
tile_overlap: int,
|
||||
scale: int,
|
||||
device: torch.device,
|
||||
desc="Tiled upscale",
|
||||
):
|
||||
# Alternative implementation of `upscale_with_model` originally used by
|
||||
|
@ -101,9 +102,6 @@ def tiled_upscale_2(
|
|||
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
|
||||
# Pillow space without weighting.
|
||||
|
||||
# Grab the device the model is on, and use it.
|
||||
device = torch_utils.get_param(model).device
|
||||
|
||||
b, c, h, w = img.size()
|
||||
tile_size = min(tile_size, h, w)
|
||||
|
||||
|
@ -175,7 +173,8 @@ def upscale_2(
|
|||
"""
|
||||
Convenience wrapper around `tiled_upscale_2` that handles PIL images.
|
||||
"""
|
||||
tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension
|
||||
param = torch_utils.get_param(model)
|
||||
tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension
|
||||
|
||||
with torch.no_grad():
|
||||
output = tiled_upscale_2(
|
||||
|
@ -185,5 +184,6 @@ def upscale_2(
|
|||
tile_overlap=tile_overlap,
|
||||
scale=scale,
|
||||
desc=desc,
|
||||
device=param.device,
|
||||
)
|
||||
return torch_bgr_to_pil_image(output)
|
||||
|
|
Loading…
Reference in New Issue