From f8fb74b93ab3f78dcec05e52e669b3b89b3a3b26 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jul 2024 08:43:48 +0300 Subject: [PATCH] Bump Spandrel to 0.3.4; add spandrel-extra-arches for CodeFormer --- modules/gfpgan_model.py | 4 +--- modules/modelloader.py | 32 +++++++++++++++++++++++++++++--- requirements_versions.txt | 3 ++- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 445b04092..01ef899e4 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -36,13 +36,11 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): ext_filter=['.pth'], ): if 'GFPGAN' in os.path.basename(model_path): - model = modelloader.load_spandrel_model( + return modelloader.load_spandrel_model( model_path, device=self.get_device(), expected_architecture='GFPGAN', ).model - model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 - return model raise ValueError("No GFPGAN model found") def restore(self, np_image): diff --git a/modules/modelloader.py b/modules/modelloader.py index 5421e59b0..36e7415af 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -139,6 +139,27 @@ def load_upscalers(): key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) +# None: not loaded, False: failed to load, True: loaded +_spandrel_extra_init_state = None + + +def _init_spandrel_extra_archs() -> None: + """ + Try to initialize `spandrel_extra_archs` (exactly once). + """ + global _spandrel_extra_init_state + if _spandrel_extra_init_state is not None: + return + + try: + import spandrel + import spandrel_extra_arches + spandrel.MAIN_REGISTRY.add(*spandrel_extra_arches.EXTRA_REGISTRY) + _spandrel_extra_init_state = True + except Exception: + logger.warning("Failed to load spandrel_extra_arches", exc_info=True) + _spandrel_extra_init_state = False + def load_spandrel_model( path: str | os.PathLike, @@ -148,11 +169,16 @@ def load_spandrel_model( dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: + global _spandrel_extra_init_state + import spandrel + _init_spandrel_extra_archs() + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path)) - if expected_architecture and model_descriptor.architecture != expected_architecture: + arch = model_descriptor.architecture + if expected_architecture and arch.name != expected_architecture: logger.warning( - f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", + f"Model {path!r} is not a {expected_architecture!r} model (got {arch.name!r})", ) half = False if prefer_half: @@ -166,6 +192,6 @@ def load_spandrel_model( model_descriptor.model.eval() logger.debug( "Loaded %s from %s (device=%s, half=%s, dtype=%s)", - model_descriptor, path, device, half, dtype, + arch, path, device, half, dtype, ) return model_descriptor diff --git a/requirements_versions.txt b/requirements_versions.txt index 3037a395b..050b6d1fb 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -23,7 +23,8 @@ pytorch_lightning==1.9.4 resize-right==0.0.2 safetensors==0.4.2 scikit-image==0.21.0 -spandrel==0.1.6 +spandrel==0.3.4 +spandrel-extra-arches==0.1.1 tomesd==0.1.3 torch torchdiffeq==0.2.3