Verify architecture for loaded Spandrel models

This commit is contained in:
Aarni Koskela 2023-12-30 16:37:03 +02:00
parent c756133541
commit 4ad0c0c0a8
8 changed files with 22 additions and 5 deletions

View File

@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
return modelloader.load_spandrel_model(filename, device=device)
return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
def on_ui_settings():

View File

@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler):
filename,
device=self._get_device(),
dtype=devices.dtype,
expected_architecture="SwinIR",
)
if getattr(opts, 'SWIN_torch_compile', False):
try:

View File

@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
return modelloader.load_spandrel_model(
model_path,
device=devices.device_codeformer,
expected_architecture='CodeFormer',
).model
raise ValueError("No codeformer model found")

View File

@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler):
return modelloader.load_spandrel_model(
filename,
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
expected_architecture='ESRGAN',
)

View File

@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
net = modelloader.load_spandrel_model(
model_path,
device=self.get_device(),
expected_architecture='GFPGAN',
).model
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
return net

View File

@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler):
return modelloader.load_spandrel_model(
path,
device=devices.device_esrgan, # TODO: should probably be device_hat
expected_architecture='HAT',
)

View File

@ -6,6 +6,8 @@ import shutil
import importlib
from urllib.parse import urlparse
import torch
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@ -183,9 +185,18 @@ def load_upscalers():
)
def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
def load_spandrel_model(
path: str,
*,
device: str | torch.device | None,
half: bool = False,
dtype: str | None = None,
expected_architecture: str | None = None,
):
import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model.architecture != expected_architecture:
raise TypeError(f"Model {path} is not a {expected_architecture} model")
if half:
model = model.model.half()
if dtype:

View File

@ -1,9 +1,9 @@
import os
from modules.upscaler_utils import upscale_with_model
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
class UpscalerRealESRGAN(Upscaler):
@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path,
device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="RealESRGAN",
)
return upscale_with_model(
mod,