From 2f55d669a26ca2c785b92656b8f32b56839f3750 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 10 Mar 2024 15:14:04 +0300 Subject: [PATCH] add support for specifying callback order in metadata --- modules/extensions.py | 24 ++++++++++++++++++++++++ modules/script_callbacks.py | 34 +++++++++++++++++++++++++++++++++- modules/scripts.py | 27 +++------------------------ modules/util.py | 24 ++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 25 deletions(-) diff --git a/modules/extensions.py b/modules/extensions.py index ab835d3f2..1f620ff16 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,6 +1,7 @@ from __future__ import annotations import configparser +import dataclasses import os import threading import re @@ -22,6 +23,13 @@ def active(): return [x for x in extensions if x.enabled] +@dataclasses.dataclass +class CallbackOrderInfo: + name: str + before: list + after: list + + class ExtensionMetadata: filename = "metadata.ini" config: configparser.ConfigParser @@ -65,6 +73,22 @@ class ExtensionMetadata: # both "," and " " are accepted as separator return [x for x in re.split(r"[,\s]+", text.strip()) if x] + def list_callback_order_instructions(self): + for section in self.config.sections(): + if not section.startswith("callbacks/"): + continue + + callback_name = section[10:] + + if not callback_name.startswith(self.canonical_name): + errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}") + continue + + before = self.parse_list(self.config.get(section, 'Before', fallback='')) + after = self.parse_list(self.config.get(section, 'After', fallback='')) + + yield CallbackOrderInfo(callback_name, before, after) + class Extension: lock = threading.Lock() diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a0ecf5a53..4d80b433b 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -8,7 +8,7 @@ from typing import Optional, Any from fastapi import FastAPI from gradio import Blocks -from modules import errors, timer, extensions, shared +from modules import errors, timer, extensions, shared, util def report_exception(c, job): @@ -149,6 +149,38 @@ def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True): callbacks = unordered_callbacks.copy() + callback_lookup = {x.name: x for x in callbacks} + dependencies = {} + + order_instructions = {} + for extension in extensions.extensions: + for order_instruction in extension.metadata.list_callback_order_instructions(): + if order_instruction.name in callback_lookup: + if order_instruction.name not in order_instructions: + order_instructions[order_instruction.name] = [] + + order_instructions[order_instruction.name].append(order_instruction) + + if order_instructions: + for callback in callbacks: + dependencies[callback.name] = [] + + for callback in callbacks: + for order_instruction in order_instructions.get(callback.name, []): + for after in order_instruction.after: + if after not in callback_lookup: + continue + + dependencies[callback.name].append(after) + + for before in order_instruction.before: + if before not in callback_lookup: + continue + + dependencies[before].append(callback.name) + + sorted_names = util.topological_sort(dependencies) + callbacks = [callback_lookup[x] for x in sorted_names] if enable_user_sort: for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])): diff --git a/modules/scripts.py b/modules/scripts.py index e1a435827..20710b37d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,9 @@ from dataclasses import dataclass import gradio as gr -from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util + +topological_sort = util.topological_sort AlwaysVisible = object() @@ -368,29 +370,6 @@ scripts_data = [] postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) -def topological_sort(dependencies): - """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies. - Ignores errors relating to missing dependeencies or circular dependencies - """ - - visited = {} - result = [] - - def inner(name): - visited[name] = True - - for dep in dependencies.get(name, []): - if dep in dependencies and dep not in visited: - inner(dep) - - result.append(name) - - for depname in dependencies: - if depname not in visited: - inner(depname) - - return result - @dataclass class ScriptWithDependencies: diff --git a/modules/util.py b/modules/util.py index 8d1aea44f..48268f04c 100644 --- a/modules/util.py +++ b/modules/util.py @@ -136,3 +136,27 @@ class MassFileLister: def reset(self): """Clear the cache of all directories.""" self.cached_dirs.clear() + + +def topological_sort(dependencies): + """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies. + Ignores errors relating to missing dependeencies or circular dependencies + """ + + visited = {} + result = [] + + def inner(name): + visited[name] = True + + for dep in dependencies.get(name, []): + if dep in dependencies and dep not in visited: + inner(dep) + + result.append(name) + + for depname in dependencies: + if depname not in visited: + inner(depname) + + return result