Upscaler.load_model: don't return None, just use exceptions
This commit is contained in:
parent
e3a973a68d
commit
bf67a5dcf4
|
@ -46,16 +46,13 @@ class UpscalerLDSR(Upscaler):
|
||||||
|
|
||||||
yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
|
yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
|
||||||
|
|
||||||
try:
|
return LDSR(model, yaml)
|
||||||
return LDSR(model, yaml)
|
|
||||||
except Exception:
|
|
||||||
errors.report("Error importing LDSR", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def do_upscale(self, img, path):
|
def do_upscale(self, img, path):
|
||||||
ldsr = self.load_model(path)
|
try:
|
||||||
if ldsr is None:
|
ldsr = self.load_model(path)
|
||||||
print("NO LDSR!")
|
except Exception:
|
||||||
|
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
|
||||||
return img
|
return img
|
||||||
ddim_steps = shared.opts.ldsr_steps
|
ddim_steps = shared.opts.ldsr_steps
|
||||||
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import os.path
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
|
@ -8,7 +7,7 @@ from tqdm import tqdm
|
||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader, script_callbacks, errors
|
from modules import devices, modelloader, script_callbacks, errors
|
||||||
from scunet_model_arch import SCUNet as net
|
from scunet_model_arch import SCUNet
|
||||||
|
|
||||||
from modules.modelloader import load_file_from_url
|
from modules.modelloader import load_file_from_url
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
@ -88,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
model = self.load_model(selected_file)
|
try:
|
||||||
if model is None:
|
model = self.load_model(selected_file)
|
||||||
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
|
except Exception as e:
|
||||||
|
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
device = devices.get_device_for('scunet')
|
device = devices.get_device_for('scunet')
|
||||||
|
@ -123,11 +123,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
|
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
|
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||||
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
|
|
||||||
return None
|
|
||||||
|
|
||||||
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
|
||||||
model.load_state_dict(torch.load(filename), strict=True)
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
for _, v in model.named_parameters():
|
for _, v in model.named_parameters():
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import os
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -7,8 +7,8 @@ from tqdm import tqdm
|
||||||
|
|
||||||
from modules import modelloader, devices, script_callbacks, shared
|
from modules import modelloader, devices, script_callbacks, shared
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
from swinir_model_arch import SwinIR as net
|
from swinir_model_arch import SwinIR
|
||||||
from swinir_model_arch_v2 import Swin2SR as net2
|
from swinir_model_arch_v2 import Swin2SR
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,8 +36,10 @@ class UpscalerSwinIR(Upscaler):
|
||||||
self.scalers = scalers
|
self.scalers = scalers
|
||||||
|
|
||||||
def do_upscale(self, img, model_file):
|
def do_upscale(self, img, model_file):
|
||||||
model = self.load_model(model_file)
|
try:
|
||||||
if model is None:
|
model = self.load_model(model_file)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
model = model.to(device_swinir, dtype=devices.dtype)
|
model = model.to(device_swinir, dtype=devices.dtype)
|
||||||
img = upscale(img, model)
|
img = upscale(img, model)
|
||||||
|
@ -56,25 +58,23 @@ class UpscalerSwinIR(Upscaler):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if filename is None or not os.path.exists(filename):
|
|
||||||
return None
|
|
||||||
if filename.endswith(".v2.pth"):
|
if filename.endswith(".v2.pth"):
|
||||||
model = net2(
|
model = Swin2SR(
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
img_size=64,
|
img_size=64,
|
||||||
window_size=8,
|
window_size=8,
|
||||||
img_range=1.0,
|
img_range=1.0,
|
||||||
depths=[6, 6, 6, 6, 6, 6],
|
depths=[6, 6, 6, 6, 6, 6],
|
||||||
embed_dim=180,
|
embed_dim=180,
|
||||||
num_heads=[6, 6, 6, 6, 6, 6],
|
num_heads=[6, 6, 6, 6, 6, 6],
|
||||||
mlp_ratio=2,
|
mlp_ratio=2,
|
||||||
upsampler="nearest+conv",
|
upsampler="nearest+conv",
|
||||||
resi_connection="1conv",
|
resi_connection="1conv",
|
||||||
)
|
)
|
||||||
params = None
|
params = None
|
||||||
else:
|
else:
|
||||||
model = net(
|
model = SwinIR(
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
img_size=64,
|
img_size=64,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import os
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -6,9 +6,8 @@ from PIL import Image
|
||||||
|
|
||||||
import modules.esrgan_model_arch as arch
|
import modules.esrgan_model_arch as arch
|
||||||
from modules import modelloader, images, devices
|
from modules import modelloader, images, devices
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
|
|
||||||
def mod2normal(state_dict):
|
def mod2normal(state_dict):
|
||||||
|
@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler):
|
||||||
self.scalers.append(scaler_data)
|
self.scalers.append(scaler_data)
|
||||||
|
|
||||||
def do_upscale(self, img, selected_model):
|
def do_upscale(self, img, selected_model):
|
||||||
model = self.load_model(selected_model)
|
try:
|
||||||
if model is None:
|
model = self.load_model(selected_model)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
model.to(devices.device_esrgan)
|
model.to(devices.device_esrgan)
|
||||||
img = esrgan_upscale(model, img)
|
img = esrgan_upscale(model, img)
|
||||||
|
@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(filename) or filename is None:
|
|
||||||
print(f"Unable to load {self.model_path} from {filename}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,6 @@ from modules.shared import cmd_opts, opts
|
||||||
from modules import modelloader, errors
|
from modules import modelloader, errors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerRealESRGAN(Upscaler):
|
class UpscalerRealESRGAN(Upscaler):
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self.name = "RealESRGAN"
|
self.name = "RealESRGAN"
|
||||||
|
@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
|
||||||
if not self.enable:
|
if not self.enable:
|
||||||
return img
|
return img
|
||||||
|
|
||||||
info = self.load_model(path)
|
try:
|
||||||
if not os.path.exists(info.local_data_path):
|
info = self.load_model(path)
|
||||||
print(f"Unable to load RealESRGAN model: {info.name}")
|
except Exception:
|
||||||
|
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
|
@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def load_model(self, path):
|
def load_model(self, path):
|
||||||
try:
|
for scaler in self.scalers:
|
||||||
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
|
if scaler.data_path == path:
|
||||||
|
if scaler.local_data_path.startswith("http"):
|
||||||
if info is None:
|
scaler.local_data_path = modelloader.load_file_from_url(
|
||||||
print(f"Unable to find model info: {path}")
|
scaler.data_path,
|
||||||
return None
|
model_dir=self.model_download_path,
|
||||||
|
)
|
||||||
if info.local_data_path.startswith("http"):
|
if not os.path.exists(scaler.local_data_path):
|
||||||
info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path)
|
raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
|
||||||
|
return scaler
|
||||||
return info
|
raise ValueError(f"Unable to find model info: {path}")
|
||||||
except Exception:
|
|
||||||
errors.report("Error making Real-ESRGAN models list", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def load_models(self, _):
|
def load_models(self, _):
|
||||||
return get_realesrgan_models(self)
|
return get_realesrgan_models(self)
|
||||||
|
|
Loading…
Reference in New Issue