helpful error message when trying to load 2.0 without config
failing to load model weights from settings won't break generation for currently loaded model anymore
This commit is contained in:
parent
7e549468b3
commit
02d7abf514
|
@ -2,9 +2,30 @@ import sys
|
|||
import traceback
|
||||
|
||||
|
||||
def print_error_explanation(message):
|
||||
lines = message.strip().split("\n")
|
||||
max_len = max([len(x) for x in lines])
|
||||
|
||||
print('=' * max_len, file=sys.stderr)
|
||||
for line in lines:
|
||||
print(line, file=sys.stderr)
|
||||
print('=' * max_len, file=sys.stderr)
|
||||
|
||||
|
||||
def display(e: Exception, task):
|
||||
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
message = str(e)
|
||||
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
||||
print_error_explanation("""
|
||||
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
|
||||
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
|
||||
""")
|
||||
|
||||
|
||||
def run(code, task):
|
||||
try:
|
||||
code()
|
||||
except Exception as e:
|
||||
print(f"{task}: {type(e).__name__}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
display(task, e)
|
||||
|
|
|
@ -278,6 +278,7 @@ def enable_midas_autodownload():
|
|||
|
||||
midas.api.load_model = load_model_wrapper
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
|
@ -312,6 +313,7 @@ def load_model(checkpoint_info=None):
|
|||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
|
@ -336,10 +338,12 @@ def load_model(checkpoint_info=None):
|
|||
def reload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
|
@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
try:
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
except Exception as e:
|
||||
print("Failed to load checkpoint, restoring previous")
|
||||
load_model_weights(sd_model, current_checkpoint_info)
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
print("Weights loaded.")
|
||||
|
||||
return sd_model
|
||||
|
|
|
@ -14,7 +14,7 @@ import modules.interrogate
|
|||
import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, sd_vae, extensions, script_loading
|
||||
from modules import localization, sd_vae, extensions, script_loading, errors
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
|
||||
|
||||
|
@ -494,7 +494,12 @@ class Options:
|
|||
return False
|
||||
|
||||
if self.data_labels[key].onchange is not None:
|
||||
self.data_labels[key].onchange()
|
||||
try:
|
||||
self.data_labels[key].onchange()
|
||||
except Exception as e:
|
||||
errors.display(e, f"changing setting {key} to {value}")
|
||||
setattr(self, key, oldval)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
|
12
webui.py
12
webui.py
|
@ -9,7 +9,7 @@ from fastapi import FastAPI
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
|
||||
from modules import import_hook
|
||||
from modules import import_hook, errors
|
||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||
from modules.paths import script_path
|
||||
|
||||
|
@ -61,7 +61,15 @@ def initialize():
|
|||
modelloader.load_upscalers()
|
||||
|
||||
modules.sd_vae.refresh_vae_list()
|
||||
modules.sd_models.load_model()
|
||||
|
||||
try:
|
||||
modules.sd_models.load_model()
|
||||
except Exception as e:
|
||||
errors.display(e, "loading stable diffusion model")
|
||||
print("", file=sys.stderr)
|
||||
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
|
|
Loading…
Reference in New Issue