load_spandrel_model: make `half` `prefer_half`

As discussed with the Spandrel folks, it's good to heed Spandrel's
"supports half precision" flag to avoid e.g. black blotches and what-not.
This commit is contained in:
Aarni Koskela 2023-12-31 19:52:32 +02:00
parent 51f1cca852
commit 2cacbc124c
2 changed files with 15 additions and 7 deletions

View File

@ -139,23 +139,31 @@ def load_upscalers():
def load_spandrel_model( def load_spandrel_model(
path: str, path: str | os.PathLike,
*, *,
device: str | torch.device | None, device: str | torch.device | None,
half: bool = False, prefer_half: bool = False,
dtype: str | torch.dtype | None = None, dtype: str | torch.dtype | None = None,
expected_architecture: str | None = None, expected_architecture: str | None = None,
) -> spandrel.ModelDescriptor: ) -> spandrel.ModelDescriptor:
import spandrel import spandrel
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
if expected_architecture and model_descriptor.architecture != expected_architecture: if expected_architecture and model_descriptor.architecture != expected_architecture:
logger.warning( logger.warning(
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
) )
if half: half = False
if prefer_half:
if model_descriptor.supports_half:
model_descriptor.model.half() model_descriptor.model.half()
half = True
else:
logger.info("Model %s does not support half precision, ignoring --half", path)
if dtype: if dtype:
model_descriptor.model.to(dtype=dtype) model_descriptor.model.to(dtype=dtype)
model_descriptor.model.eval() model_descriptor.model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) logger.debug(
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
model_descriptor, path, device, half, dtype,
)
return model_descriptor return model_descriptor

View File

@ -39,7 +39,7 @@ class UpscalerRealESRGAN(Upscaler):
model_descriptor = modelloader.load_spandrel_model( model_descriptor = modelloader.load_spandrel_model(
info.local_data_path, info.local_data_path,
device=self.device, device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
) )
return upscale_with_model( return upscale_with_model(