load_spandrel_model: make `half` `prefer_half`
As discussed with the Spandrel folks, it's good to heed Spandrel's "supports half precision" flag to avoid e.g. black blotches and what-not.
This commit is contained in:
parent
51f1cca852
commit
2cacbc124c
|
@ -139,23 +139,31 @@ def load_upscalers():
|
||||||
|
|
||||||
|
|
||||||
def load_spandrel_model(
|
def load_spandrel_model(
|
||||||
path: str,
|
path: str | os.PathLike,
|
||||||
*,
|
*,
|
||||||
device: str | torch.device | None,
|
device: str | torch.device | None,
|
||||||
half: bool = False,
|
prefer_half: bool = False,
|
||||||
dtype: str | torch.dtype | None = None,
|
dtype: str | torch.dtype | None = None,
|
||||||
expected_architecture: str | None = None,
|
expected_architecture: str | None = None,
|
||||||
) -> spandrel.ModelDescriptor:
|
) -> spandrel.ModelDescriptor:
|
||||||
import spandrel
|
import spandrel
|
||||||
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
|
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
|
||||||
if expected_architecture and model_descriptor.architecture != expected_architecture:
|
if expected_architecture and model_descriptor.architecture != expected_architecture:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
|
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
|
||||||
)
|
)
|
||||||
if half:
|
half = False
|
||||||
|
if prefer_half:
|
||||||
|
if model_descriptor.supports_half:
|
||||||
model_descriptor.model.half()
|
model_descriptor.model.half()
|
||||||
|
half = True
|
||||||
|
else:
|
||||||
|
logger.info("Model %s does not support half precision, ignoring --half", path)
|
||||||
if dtype:
|
if dtype:
|
||||||
model_descriptor.model.to(dtype=dtype)
|
model_descriptor.model.to(dtype=dtype)
|
||||||
model_descriptor.model.eval()
|
model_descriptor.model.eval()
|
||||||
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
|
logger.debug(
|
||||||
|
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
|
||||||
|
model_descriptor, path, device, half, dtype,
|
||||||
|
)
|
||||||
return model_descriptor
|
return model_descriptor
|
||||||
|
|
|
@ -39,7 +39,7 @@ class UpscalerRealESRGAN(Upscaler):
|
||||||
model_descriptor = modelloader.load_spandrel_model(
|
model_descriptor = modelloader.load_spandrel_model(
|
||||||
info.local_data_path,
|
info.local_data_path,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||||
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
|
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
|
||||||
)
|
)
|
||||||
return upscale_with_model(
|
return upscale_with_model(
|
||||||
|
|
Loading…
Reference in New Issue