diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 6ea58d616..d8aa6f008 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -3,6 +3,8 @@ import traceback from collections import namedtuple import inspect +from fastapi import FastAPI +from gradio import Blocks def report_exception(c, job): print(f"Error executing callback {job} for {c.script}", file=sys.stderr) @@ -25,6 +27,7 @@ class ImageSaveParams: ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) +callbacks_app_started = [] callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] @@ -40,6 +43,14 @@ def clear_callbacks(): callbacks_image_saved.clear() +def app_started_callback(demo: Blocks, app: FastAPI): + for c in callbacks_app_started: + try: + c.callback(demo, app) + except Exception: + report_exception(c, 'app_started_callback') + + def model_loaded_callback(sd_model): for c in callbacks_model_loaded: try: @@ -91,6 +102,12 @@ def add_callback(callbacks, fun): callbacks.append(ScriptCallback(filename, fun)) +def on_app_started(callback): + """register a function to be called when the webui started, the gradio `Block` component and + fastapi `FastAPI` object are passed as the arguments""" + add_callback(callbacks_app_started, callback) + + def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" diff --git a/webui.py b/webui.py index 6ff95dc47..8d8479f1d 100644 --- a/webui.py +++ b/webui.py @@ -23,6 +23,7 @@ import modules.sd_hijack import modules.sd_models import modules.shared as shared import modules.txt2img +import modules.script_callbacks import modules.ui from modules import devices @@ -140,6 +141,8 @@ def webui(): if launch_api: create_api(app) + modules.script_callbacks.app_started_callback(demo, app) + wait_on_server(demo) sd_samplers.set_samplers()