Merge branch 'dag' into dev

This commit is contained in:
AUTOMATIC1111 2023-11-20 14:50:01 +03:00
commit 1463cea949
2 changed files with 146 additions and 151 deletions

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import configparser
import functools
import os
import threading
import re
@ -8,7 +9,6 @@ from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = []
os.makedirs(extensions_dir, exist_ok=True)
@ -22,13 +22,56 @@ def active():
return [x for x in extensions if x.enabled]
class ExtensionMetadata:
filename = "metadata.ini"
config: configparser.ConfigParser
canonical_name: str
requires: list
def __init__(self, path, canonical_name):
self.config = configparser.ConfigParser()
filepath = os.path.join(path, self.filename)
if os.path.isfile(filepath):
try:
self.config.read(filepath)
except Exception:
errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
self.canonical_name = canonical_name.lower().strip()
self.requires = self.get_script_requirements("Requires", "Extension")
def get_script_requirements(self, field, section, extra_section=None):
"""reads a list of requirements from the config; field is the name of the field in the ini file,
like Requires or Before, and section is the name of the [section] in the ini file; additionally,
reads more requirements from [extra_section] if specified."""
x = self.config.get(section, field, fallback='')
if extra_section:
x = x + ', ' + self.config.get(extra_section, field, fallback='')
return self.parse_list(x.lower())
def parse_list(self, text):
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
if not text:
return []
# both "," and " " are accepted as separator
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
class Extension:
lock = threading.Lock()
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
metadata: ExtensionMetadata
def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None):
def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
self.name = name
self.canonical_name = canonical_name or name.lower()
self.path = path
self.enabled = enabled
self.status = ''
@ -40,18 +83,8 @@ class Extension:
self.branch = None
self.remote = None
self.have_info_from_repo = False
@functools.cached_property
def metadata(self):
if os.path.isfile(os.path.join(self.path, "metadata.ini")):
try:
config = configparser.ConfigParser()
config.read(os.path.join(self.path, "metadata.ini"))
return config
except Exception:
errors.report(f"Error reading metadata.ini for extension {self.canonical_name}.",
exc_info=True)
return None
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
self.canonical_name = metadata.canonical_name
def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields}
@ -162,7 +195,7 @@ def list_extensions():
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
extension_dependency_map = {}
loaded_extensions = {}
# scan through extensions directory and load metadata
for dirname in [extensions_builtin_dir, extensions_dir]:
@ -175,55 +208,30 @@ def list_extensions():
continue
canonical_name = extension_dirname
requires = None
if os.path.isfile(os.path.join(path, "metadata.ini")):
try:
config = configparser.ConfigParser()
config.read(os.path.join(path, "metadata.ini"))
canonical_name = config.get("Extension", "Name", fallback=canonical_name)
requires = config.get("Extension", "Requires", fallback=None)
except Exception:
errors.report(f"Error reading metadata.ini for extension {extension_dirname}. "
f"Will load regardless.", exc_info=True)
canonical_name = canonical_name.lower().strip()
metadata = ExtensionMetadata(path, canonical_name)
# check for duplicated canonical names
if canonical_name in extension_dependency_map:
errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions "
f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". "
f"The current loading extension will be discarded.", exc_info=False)
already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
if already_loaded_extension is not None:
errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
continue
# both "," and " " are accepted as separator
requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
extension_dependency_map[canonical_name] = {
"dirname": extension_dirname,
"path": path,
"requires": requires,
}
is_builtin = dirname == extensions_builtin_dir
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
extensions.append(extension)
loaded_extensions[canonical_name] = extension
# check for requirements
for (_, extension_data) in extension_dependency_map.items():
dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires']
requirement_met = True
for req in requires:
if req not in extension_dependency_map:
errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. "
f"The current loading extension will be discarded.", exc_info=False)
requirement_met = False
break
dep_dirname = extension_dependency_map[req]['dirname']
if dep_dirname in shared.opts.disabled_extensions:
errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. "
f"The current loading extension will be discarded.", exc_info=False)
requirement_met = False
break
for extension in extensions:
for req in extension.metadata.requires:
required_extension = loaded_extensions.get(req)
if required_extension is None:
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
continue
is_builtin = dirname == extensions_builtin_dir
extension = Extension(name=dirname, path=path,
enabled=dirname not in shared.opts.disabled_extensions and requirement_met,
is_builtin=is_builtin)
extensions.append(extension)
if not extension.enabled:
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
continue
extensions: list[Extension] = []

View File

@ -2,7 +2,6 @@ import os
import re
import sys
import inspect
from graphlib import TopologicalSorter, CycleError
from collections import namedtuple
from dataclasses import dataclass
@ -312,27 +311,57 @@ scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def topological_sort(dependencies):
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
Ignores errors relating to missing dependeencies or circular dependencies
"""
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result
@dataclass
class ScriptWithDependencies:
script_canonical_name: str
file: ScriptFile
requires: list
load_before: list
load_after: list
def list_scripts(scriptdirname, extension, *, include_extensions=True):
scripts_list = []
script_dependency_map = {}
scripts = {}
loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}
# build script dependency map
root_script_basedir = os.path.join(paths.script_path, scriptdirname)
if os.path.exists(root_script_basedir):
for filename in sorted(os.listdir(root_script_basedir)):
if not os.path.isfile(os.path.join(root_script_basedir, filename)):
continue
script_dependency_map[filename] = {
"extension": None,
"extension_dirname": None,
"script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
"requires": [],
"load_before": [],
"load_after": [],
}
if os.path.splitext(filename)[1].lower() != extension:
continue
script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])
if include_extensions:
for ext in extensions.active():
@ -341,96 +370,54 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True):
if not os.path.isfile(extension_script.path):
continue
script_canonical_name = ext.canonical_name + "/" + extension_script.filename
if ext.is_builtin:
script_canonical_name = "builtin/" + script_canonical_name
script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
relative_path = scriptdirname + "/" + extension_script.filename
requires = ''
load_before = ''
load_after = ''
script = ScriptWithDependencies(
script_canonical_name=script_canonical_name,
file=extension_script,
requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
)
if ext.metadata is not None:
requires = ext.metadata.get(relative_path, "Requires", fallback='')
load_before = ext.metadata.get(relative_path, "Before", fallback='')
load_after = ext.metadata.get(relative_path, "After", fallback='')
scripts[script_canonical_name] = script
loaded_extensions_scripts[ext.canonical_name].append(script)
# propagate directory level metadata
requires = requires + ',' + ext.metadata.get(scriptdirname, "Requires", fallback='')
load_before = load_before + ',' + ext.metadata.get(scriptdirname, "Before", fallback='')
load_after = load_after + ',' + ext.metadata.get(scriptdirname, "After", fallback='')
requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else []
load_before = list(filter(None, re.split(r"[,\s]+", load_before.lower()))) if load_before else []
script_dependency_map[script_canonical_name] = {
"extension": ext.canonical_name,
"extension_dirname": ext.name,
"script_file": extension_script,
"requires": requires,
"load_before": load_before,
"load_after": load_after,
}
# resolve dependencies
loaded_extensions = set()
for ext in extensions.active():
loaded_extensions.add(ext.canonical_name)
for script_canonical_name, script_data in script_dependency_map.items():
for script_canonical_name, script in scripts.items():
# load before requires inverse dependency
# in this case, append the script name into the load_after list of the specified script
for load_before_script in script_data['load_before']:
for load_before in script.load_before:
# if this requires an individual script to be loaded before
if load_before_script in script_dependency_map:
script_dependency_map[load_before_script]['load_after'].append(script_canonical_name)
elif load_before_script in loaded_extensions:
for _, script_data2 in script_dependency_map.items():
if script_data2['extension'] == load_before_script:
script_data2['load_after'].append(script_canonical_name)
break
other_script = scripts.get(load_before)
if other_script:
other_script.load_after.append(script_canonical_name)
# resolve extension name in load_after lists
for load_after_script in list(script_data['load_after']):
if load_after_script not in script_dependency_map and load_after_script in loaded_extensions:
script_data['load_after'].remove(load_after_script)
for script_canonical_name2, script_data2 in script_dependency_map.items():
if script_data2['extension'] == load_after_script:
script_data['load_after'].append(script_canonical_name2)
break
# if this requires an extension
other_extension_scripts = loaded_extensions_scripts.get(load_before)
if other_extension_scripts:
for other_script in other_extension_scripts:
other_script.load_after.append(script_canonical_name)
# build the DAG
sorter = TopologicalSorter()
for script_canonical_name, script_data in script_dependency_map.items():
requirement_met = True
for required_script in script_data['requires']:
# if this requires an individual script to be loaded
if required_script not in script_dependency_map and required_script not in loaded_extensions:
errors.report(f"Script \"{script_canonical_name}\" "
f"requires \"{required_script}\" to "
f"be loaded, but it is not. Skipping.",
exc_info=False)
requirement_met = False
break
if not requirement_met:
continue
# if After mentions an extension, remove it and instead add all of its scripts
for load_after in list(script.load_after):
if load_after not in scripts and load_after in loaded_extensions_scripts:
script.load_after.remove(load_after)
sorter.add(script_canonical_name, *script_data['load_after'])
for other_script in loaded_extensions_scripts.get(load_after, []):
script.load_after.append(other_script.script_canonical_name)
# sort the scripts
try:
ordered_script = sorter.static_order()
except CycleError:
errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
ordered_script = script_dependency_map.keys()
dependencies = {}
for script_canonical_name in ordered_script:
script_data = script_dependency_map[script_canonical_name]
scripts_list.append(script_data['script_file'])
for script_canonical_name, script in scripts.items():
for required_script in script.requires:
if required_script not in scripts and required_script not in loaded_extensions:
errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
dependencies[script_canonical_name] = script.load_after
ordered_scripts = topological_sort(dependencies)
scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]
return scripts_list