From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:35:07 +0300 Subject: [PATCH] helpful error message when trying to load 2.0 without config failing to load model weights from settings won't break generation for currently loaded model anymore --- modules/errors.py | 25 +++++++++++++++++++++++-- modules/sd_models.py | 24 +++++++++++++++++------- modules/shared.py | 9 +++++++-- webui.py | 12 ++++++++++-- 4 files changed, 57 insertions(+), 13 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index 372dc51a0..a668c0147 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -2,9 +2,30 @@ import sys import traceback +def print_error_explanation(message): + lines = message.strip().split("\n") + max_len = max([len(x) for x in lines]) + + print('=' * max_len, file=sys.stderr) + for line in lines: + print(line, file=sys.stderr) + print('=' * max_len, file=sys.stderr) + + +def display(e: Exception, task): + print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + message = str(e) + if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: + print_error_explanation(""" +The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file. +See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this. + """) + + def run(code, task): try: code() except Exception as e: - print(f"{task}: {type(e).__name__}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + display(task, e) diff --git a/modules/sd_models.py b/modules/sd_models.py index b98b05fc2..6846b74ab 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -278,6 +278,7 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -312,6 +313,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.use_fp16 = False sd_model = instantiate_from_config(sd_config.model) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -336,10 +338,12 @@ def load_model(checkpoint_info=None): def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - + if not sd_model: sd_model = shared.sd_model + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return @@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) + try: + load_model_weights(sd_model, checkpoint_info) + except Exception as e: + print("Failed to load checkpoint, restoring previous") + load_model_weights(sd_model, current_checkpoint_info) + raise + finally: + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) - sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) - - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: - sd_model.to(devices.device) + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) print("Weights loaded.") + return sd_model diff --git a/modules/shared.py b/modules/shared.py index 23657a939..7588c47b4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading +from modules import localization, sd_vae, extensions, script_loading, errors from modules.paths import models_path, script_path, sd_path @@ -494,7 +494,12 @@ class Options: return False if self.data_labels[key].onchange is not None: - self.data_labels[key].onchange() + try: + self.data_labels[key].onchange() + except Exception as e: + errors.display(e, f"changing setting {key} to {value}") + setattr(self, key, oldval) + return False return True diff --git a/webui.py b/webui.py index c7d55a978..13375e71a 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook +from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path @@ -61,7 +61,15 @@ def initialize(): modelloader.load_upscalers() modules.sd_vae.refresh_vae_list() - modules.sd_models.load_model() + + try: + modules.sd_models.load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load, exiting", file=sys.stderr) + exit(1) + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)