load_spandrel_model: make `half` `prefer_half`
As discussed with the Spandrel folks, it's good to heed Spandrel's "supports half precision" flag to avoid e.g. black blotches and what-not.
This commit is contained in:
parent
51f1cca852
commit
2cacbc124c
|
@ -139,23 +139,31 @@ def load_upscalers():
|
|||
|
||||
|
||||
def load_spandrel_model(
|
||||
path: str,
|
||||
path: str | os.PathLike,
|
||||
*,
|
||||
device: str | torch.device | None,
|
||||
half: bool = False,
|
||||
prefer_half: bool = False,
|
||||
dtype: str | torch.dtype | None = None,
|
||||
expected_architecture: str | None = None,
|
||||
) -> spandrel.ModelDescriptor:
|
||||
import spandrel
|
||||
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
|
||||
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
|
||||
if expected_architecture and model_descriptor.architecture != expected_architecture:
|
||||
logger.warning(
|
||||
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
|
||||
)
|
||||
if half:
|
||||
half = False
|
||||
if prefer_half:
|
||||
if model_descriptor.supports_half:
|
||||
model_descriptor.model.half()
|
||||
half = True
|
||||
else:
|
||||
logger.info("Model %s does not support half precision, ignoring --half", path)
|
||||
if dtype:
|
||||
model_descriptor.model.to(dtype=dtype)
|
||||
model_descriptor.model.eval()
|
||||
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
|
||||
logger.debug(
|
||||
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
|
||||
model_descriptor, path, device, half, dtype,
|
||||
)
|
||||
return model_descriptor
|
||||
|
|
|
@ -39,7 +39,7 @@ class UpscalerRealESRGAN(Upscaler):
|
|||
model_descriptor = modelloader.load_spandrel_model(
|
||||
info.local_data_path,
|
||||
device=self.device,
|
||||
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
|
||||
)
|
||||
return upscale_with_model(
|
||||
|
|
Loading…
Reference in New Issue