Verify architecture for loaded Spandrel models
This commit is contained in:
parent
c756133541
commit
4ad0c0c0a8
|
@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
|
filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
return modelloader.load_spandrel_model(filename, device=device)
|
return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
|
||||||
|
|
||||||
|
|
||||||
def on_ui_settings():
|
def on_ui_settings():
|
||||||
|
|
|
@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler):
|
||||||
filename,
|
filename,
|
||||||
device=self._get_device(),
|
device=self._get_device(),
|
||||||
dtype=devices.dtype,
|
dtype=devices.dtype,
|
||||||
|
expected_architecture="SwinIR",
|
||||||
)
|
)
|
||||||
if getattr(opts, 'SWIN_torch_compile', False):
|
if getattr(opts, 'SWIN_torch_compile', False):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
|
||||||
return modelloader.load_spandrel_model(
|
return modelloader.load_spandrel_model(
|
||||||
model_path,
|
model_path,
|
||||||
device=devices.device_codeformer,
|
device=devices.device_codeformer,
|
||||||
|
expected_architecture='CodeFormer',
|
||||||
).model
|
).model
|
||||||
raise ValueError("No codeformer model found")
|
raise ValueError("No codeformer model found")
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler):
|
||||||
return modelloader.load_spandrel_model(
|
return modelloader.load_spandrel_model(
|
||||||
filename,
|
filename,
|
||||||
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
|
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
|
||||||
|
expected_architecture='ESRGAN',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
|
||||||
net = modelloader.load_spandrel_model(
|
net = modelloader.load_spandrel_model(
|
||||||
model_path,
|
model_path,
|
||||||
device=self.get_device(),
|
device=self.get_device(),
|
||||||
|
expected_architecture='GFPGAN',
|
||||||
).model
|
).model
|
||||||
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
|
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
|
||||||
return net
|
return net
|
||||||
|
|
|
@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler):
|
||||||
return modelloader.load_spandrel_model(
|
return modelloader.load_spandrel_model(
|
||||||
path,
|
path,
|
||||||
device=devices.device_esrgan, # TODO: should probably be device_hat
|
device=devices.device_esrgan, # TODO: should probably be device_hat
|
||||||
|
expected_architecture='HAT',
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,6 +6,8 @@ import shutil
|
||||||
import importlib
|
import importlib
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||||
from modules.paths import script_path, models_path
|
from modules.paths import script_path, models_path
|
||||||
|
@ -183,9 +185,18 @@ def load_upscalers():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
|
def load_spandrel_model(
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
device: str | torch.device | None,
|
||||||
|
half: bool = False,
|
||||||
|
dtype: str | None = None,
|
||||||
|
expected_architecture: str | None = None,
|
||||||
|
):
|
||||||
import spandrel
|
import spandrel
|
||||||
model = spandrel.ModelLoader(device=device).load_from_file(path)
|
model = spandrel.ModelLoader(device=device).load_from_file(path)
|
||||||
|
if expected_architecture and model.architecture != expected_architecture:
|
||||||
|
raise TypeError(f"Model {path} is not a {expected_architecture} model")
|
||||||
if half:
|
if half:
|
||||||
model = model.model.half()
|
model = model.model.half()
|
||||||
if dtype:
|
if dtype:
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from modules.upscaler_utils import upscale_with_model
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
|
||||||
from modules.shared import cmd_opts, opts
|
|
||||||
from modules import modelloader, errors
|
from modules import modelloader, errors
|
||||||
|
from modules.shared import cmd_opts, opts
|
||||||
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
from modules.upscaler_utils import upscale_with_model
|
||||||
|
|
||||||
|
|
||||||
class UpscalerRealESRGAN(Upscaler):
|
class UpscalerRealESRGAN(Upscaler):
|
||||||
|
@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler):
|
||||||
info.local_data_path,
|
info.local_data_path,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||||
|
expected_architecture="RealESRGAN",
|
||||||
)
|
)
|
||||||
return upscale_with_model(
|
return upscale_with_model(
|
||||||
mod,
|
mod,
|
||||||
|
|
Loading…
Reference in New Issue