Verify architecture for loaded Spandrel models
This commit is contained in:
parent
c756133541
commit
4ad0c0c0a8
|
@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||
filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
|
||||
else:
|
||||
filename = path
|
||||
return modelloader.load_spandrel_model(filename, device=device)
|
||||
return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
|
|
|
@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler):
|
|||
filename,
|
||||
device=self._get_device(),
|
||||
dtype=devices.dtype,
|
||||
expected_architecture="SwinIR",
|
||||
)
|
||||
if getattr(opts, 'SWIN_torch_compile', False):
|
||||
try:
|
||||
|
|
|
@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
|
|||
return modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=devices.device_codeformer,
|
||||
expected_architecture='CodeFormer',
|
||||
).model
|
||||
raise ValueError("No codeformer model found")
|
||||
|
||||
|
|
|
@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler):
|
|||
return modelloader.load_spandrel_model(
|
||||
filename,
|
||||
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
|
||||
expected_architecture='ESRGAN',
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
|
|||
net = modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=self.get_device(),
|
||||
expected_architecture='GFPGAN',
|
||||
).model
|
||||
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
|
||||
return net
|
||||
|
|
|
@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler):
|
|||
return modelloader.load_spandrel_model(
|
||||
path,
|
||||
device=devices.device_esrgan, # TODO: should probably be device_hat
|
||||
expected_architecture='HAT',
|
||||
)
|
||||
|
|
|
@ -6,6 +6,8 @@ import shutil
|
|||
import importlib
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.paths import script_path, models_path
|
||||
|
@ -183,9 +185,18 @@ def load_upscalers():
|
|||
)
|
||||
|
||||
|
||||
def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
|
||||
def load_spandrel_model(
|
||||
path: str,
|
||||
*,
|
||||
device: str | torch.device | None,
|
||||
half: bool = False,
|
||||
dtype: str | None = None,
|
||||
expected_architecture: str | None = None,
|
||||
):
|
||||
import spandrel
|
||||
model = spandrel.ModelLoader(device=device).load_from_file(path)
|
||||
if expected_architecture and model.architecture != expected_architecture:
|
||||
raise TypeError(f"Model {path} is not a {expected_architecture} model")
|
||||
if half:
|
||||
model = model.model.half()
|
||||
if dtype:
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import os
|
||||
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
class UpscalerRealESRGAN(Upscaler):
|
||||
|
@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler):
|
|||
info.local_data_path,
|
||||
device=self.device,
|
||||
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
expected_architecture="RealESRGAN",
|
||||
)
|
||||
return upscale_with_model(
|
||||
mod,
|
||||
|
|
Loading…
Reference in New Issue