308 lines
12 KiB
Python
308 lines
12 KiB
Python
from argparse import Namespace
|
|
from typing import List
|
|
import os
|
|
import re
|
|
import json
|
|
import logging
|
|
from colorama import Fore, Style
|
|
import importlib, pkgutil
|
|
|
|
class TestBase():
    """Minimal class that stores a single attribute ``a`` (always 1).

    NOTE(review): looks like scaffolding for experimenting with subclass
    discovery rather than a real plugin -- confirm before relying on it.
    """

    def __init__(self):
        self.a = 1

    def __repr__(self) -> str:
        """Return ``"TestBase: <a>"``."""
        return "TestBase: {}".format(self.a)
|
|
|
|
class TestSub(TestBase):
    """Subclass of ``TestBase`` that adds a second attribute ``b`` (always 2)."""

    def __init__(self):
        # The parent initialiser sets ``self.a``.
        super().__init__()
        self.b = 2

    def __repr__(self) -> str:
        """Return ``"TestSub: <a>, <b>"``."""
        return "TestSub: {}, {}".format(self.a, self.b)
|
|
|
|
class PromptIdentityPlugin():
    """
    Base class for prompt alteration plugins, useful for captioning, etc.

    A plugin wraps a function ``fn(args) -> str`` that produces the prompt to
    use for one image. Calling the plugin stores the image path on ``args``
    and delegates to ``fn``. When no ``fn`` is supplied, the plugin returns
    ``args.prompt`` unchanged (the "identity" behaviour).
    """
    def __init__(self, description: str="identity", key: str="indentity_plugin", fn: callable=None, args: Namespace=None):
        """
        Args:
            description: human readable description, shown by ``__repr__``.
            key: identifier exposed through the read-only ``key`` property.
                The default spelling is kept as-is because callers may match
                on it.
            fn: callable taking the ``args`` namespace and returning the
                prompt string; defaults to returning ``args.prompt``.
            args: namespace stored on the instance as ``self.args``.
        """
        self.description = description
        if fn is None:
            fn = self._prompt_identity_from_args
        self.fn = fn
        self._key = key
        self.args = args

    @property
    def key(self) -> str:
        """Identifier of this plugin."""
        return self._key

    def _prompt_identity_from_args(self, args: Namespace) -> str:
        """Return ``args.prompt`` unchanged.

        Raises:
            ValueError: if ``args`` has no ``prompt`` entry.
        """
        if "prompt" not in args:
            raise ValueError("prompt is required for prompt_identity_from_args")
        return args.prompt

    def __repr__(self) -> str:
        return f"Plugin Function: \"{self.key}\" - {self.description}"

    def __str__(self) -> str:
        return self.__repr__()

    def __call__(self, image_path, args: Namespace) -> str:
        """Record ``image_path`` on ``args`` and return the result of ``self.fn(args)``.

        Note that ``args`` is mutated: ``args.image_path`` is overwritten.
        """
        args.image_path = image_path
        return self.fn(args)

    @staticmethod
    def _add_hint_to_prompt(hint: str, prompt: str) -> str:
        """Insert ``hint`` into ``prompt``.

        If the prompt contains the literal placeholder ``\\{hint\\}``
        (backslashes included) every occurrence is replaced by the hint;
        otherwise ``"Hint: <hint>"`` is prepended on its own line.
        """
        # Raw string: identical runtime value to the former "\{hint\}", but
        # without the invalid escape sequences that newer Python versions
        # warn about.
        placeholder = r"\{hint\}"
        if placeholder in prompt:
            return prompt.replace(placeholder, hint)
        return f"Hint: {hint}\n{prompt}"
|
|
|
|
class HintFromFilename(PromptIdentityPlugin):
    """Plugin that adds the image's file name (without extension) as a hint."""

    def __init__(self, args:Namespace=None):
        super().__init__(key="hint_from_filename",
                         description="Add a hint to the prompt using the filename of the image (without extension)",
                         fn=self._from_filename,
                         args=args)

    def _from_filename(self, args: Namespace) -> str:
        """Return ``args.prompt`` with the image's base file name added as a hint.

        Args:
            args: namespace providing ``prompt`` and ``image_path``
                (``image_path`` is set by ``__call__``).
        """
        # Namespace has no .get(); fall back to "" when the attribute is absent.
        image_path = getattr(args, "image_path", "") or ""
        prompt = args.prompt
        # Drop the directory part as well as the extension so that only the
        # file name itself is used, as the plugin description states.
        filename = os.path.splitext(os.path.basename(image_path))[0]
        return self._add_hint_to_prompt(filename, prompt)
|
|
|
|
class RemoveUsingCSV(PromptIdentityPlugin):
    """Plugin that deletes whole-word matches of comma-separated terms from the prompt."""

    def __init__(self, args:Namespace=None):
        super().__init__(key="remove_using_csv",
                         description="Removes whole word matches of the csv passed in from the prompt",
                         fn=self._remove_using_csv,
                         args=args)

    def _filter_logic(self, prompt: str, filters: List[str]) -> str:
        """Remove every whole-word occurrence of ``filters`` from ``prompt``.

        After removal, runs of whitespace are collapsed to a single space,
        whitespace before ``, . ! ? ;`` is dropped and the result is stripped.
        Empty filter entries are ignored.
        """
        escaped = [re.escape(word) for word in filters if word]
        result = prompt
        if escaped:
            # word boundary filter
            pattern = r'\b(?:' + '|'.join(escaped) + r')\b'
            result = re.sub(pattern, '', result)

        # fix up extra space and punctuation
        result = re.sub(r'\s{2,}', ' ', result)  # Remove extra spaces
        result = re.sub(r'\s([,.!?;])', r'\1', result)  # Fix punctuation and spaces
        return result.strip()

    def _remove_using_csv(self, args: Namespace) -> str:
        """Return ``args.prompt`` with each term listed in ``args.csv`` removed.

        ``args.csv`` is a comma-separated string. Surrounding whitespace of
        each term is ignored and blank terms are skipped. When ``csv`` is
        missing or empty an error is logged and the prompt is returned
        unchanged.
        """
        prompt = args.prompt
        csv = getattr(args, "csv", None)
        if not csv:
            logging.error(f"** {Fore.RED}Error: csv is required for remove_using_csv{Style.RESET_ALL}")
            return prompt

        # Terms are applied one after another (not as a single alternation)
        # so that a later term can match text joined up by an earlier removal.
        for raw_word in csv.split(","):
            word = raw_word.strip()
            if word:
                prompt = self._filter_logic(prompt, [word])
        return prompt
|
|
|
|
class HintFromLeafDirectory(PromptIdentityPlugin):
    """Plugin that uses the name of the folder containing the image as a hint."""

    def __init__(self, args:Namespace=None):
        super().__init__(
            description="Adds a hint to the prompt using the leaf directory name (last folder in path)",
            key="from_leaf_directory",
            fn=self._from_leaf_directory,
            args=args,
        )

    def _from_leaf_directory(self, args:Namespace) -> str:
        """Return ``args.prompt`` with the image's parent folder name added as a hint."""
        containing_dir = os.path.dirname(args.image_path)
        leaf_name = os.path.basename(containing_dir)
        return self._add_hint_to_prompt(leaf_name, args.prompt)
|
|
|
|
class MetadataProvider():
    """ provides and caches metadata"""
    def __init__(self):
        # Parsed metadata keyed by the path it was loaded from.
        self._datadict = {}

    def _from_metadata(self, args) -> str:
        """Return the prompt prefixed with a ``metadata: ...`` line.

        Args:
            args: either a dict or an ``argparse.Namespace``-like object with
                optional ``image_path`` and ``prompt`` entries.

        NOTE(review): ``image_path`` is passed straight to
        ``_get_metadata_dict`` as the metadata file path -- confirm callers
        really pass the path of the JSON file here.
        """
        values = args if isinstance(args, dict) else vars(args)
        image_path = values.get("image_path", "")
        prompt = values.get("prompt", "")
        metadata = self._get_metadata_dict(image_path)
        return f"metadata: {metadata}\n{prompt}"

    def _get_metadata_dict(self, metadata_path: str) -> dict:
        """Return the parsed JSON at ``metadata_path``, loading it at most once.

        A missing file is not an error: a warning is logged and an empty dict
        is cached and returned for that path.
        """
        if metadata_path not in self._datadict:
            if os.path.exists(metadata_path):
                with open(metadata_path, "r", encoding="utf-8") as f:
                    self._datadict[metadata_path] = json.load(f)
            else:
                metadata_dirname = os.path.dirname(metadata_path)
                logging.warning(f" metadata.json not found in {metadata_dirname}, ignoring{Style.RESET_ALL}")
                self._datadict[metadata_path] = {}

        return self._datadict[metadata_path]
|
|
|
|
class FromFolderMetadataJson(PromptIdentityPlugin):
    """Plugin that adds the contents of the image folder's ``metadata.json`` as a hint."""

    def __init__(self, args:Namespace=None):
        super().__init__(key="from_folder_metadata",
                         description="Looks for metadata.json in the folder of the images",
                         fn=self._from_metadata_json,
                         args=args)
        self.metadata_provider = MetadataProvider()

    def _from_metadata_json(self, args:Namespace) -> str:
        """Return the prompt with the folder's metadata added as a hint.

        The metadata is read from ``metadata.json`` next to ``args.image_path``
        via the metadata provider. When no metadata is available the prompt is
        returned unchanged (previously the prompt was always replaced by an
        empty string and the loaded metadata was discarded).
        """
        image_dir = os.path.dirname(args.image_path)
        metadata_json_path = os.path.join(image_dir, "metadata.json")
        metadata = self.metadata_provider._get_metadata_dict(metadata_json_path)

        prompt = getattr(args, "prompt", "")
        if not metadata:
            return prompt
        # Serialise as JSON so the hint is stable and readable text.
        metadata_text = json.dumps(metadata, ensure_ascii=False)
        return self._add_hint_to_prompt(f"metadata: {metadata_text}", prompt)
|
|
|
|
class TagsFromFolderMetadataJson(PromptIdentityPlugin):
    """Plugin that adds the ``tags`` entry of the image folder's ``metadata.json`` as a hint."""

    def __init__(self, args:Namespace=None):
        # Kept for backward compatibility; lookups go through metadata_provider.
        self.cache = {}
        super().__init__(key = "tags_from_metadata_json",
                         description="Adds tags hint from metadata.json (in the samefolder as the image) to the prompt",
                         fn=self._tags_from_metadata_json,
                         args=args)
        self.metadata_provider = MetadataProvider()

    def _tags_from_metadata_json(self, args:Namespace) -> str:
        """Return ``args.prompt`` with a ``tags: a, b, c`` hint added.

        Tags are read from ``metadata.json`` in the folder of
        ``args.image_path``; they may be a list or a comma-separated string.
        When there are no tags the prompt is returned unchanged.
        """
        image_path = args.image_path
        current_dir = os.path.dirname(image_path)
        metadata_json_path = os.path.join(current_dir, "metadata.json")

        metadata = self.metadata_provider._get_metadata_dict(metadata_json_path)
        # The lookup result must be kept: it was previously discarded, which
        # left ``tags`` undefined below.
        tags = metadata.get("tags") or []
        if isinstance(tags, str):  # can be csv or list
            tags = tags.split(",")
        tags = [str(tag).strip() for tag in tags]
        tags = [tag for tag in tags if tag]

        prompt = args.prompt
        if tags:
            joined_tags = ", ".join(tags)
            return self._add_hint_to_prompt(f"tags: {joined_tags}", prompt)
        return prompt
|
|
|
|
class TitleAndTagsFromFolderMetadataJson(PromptIdentityPlugin):
    """Plugin that adds ``title`` and ``tags`` from the image folder's ``metadata.json`` as a hint."""

    def __init__(self, args:Namespace=None):
        # Parsed metadata keyed by the metadata.json path.
        self.cache = {}
        super().__init__(key="title_and_tags_from_metadata_json",
                         description="Adds title and tags hint from metadata.json (in the samefolder as the image) to the prompt",
                         fn=self._title_and_tags_from_metadata_json,
                         args=args)

    def _title_and_tags_from_metadata_json(self, args:Namespace) -> str:
        """Return ``args.prompt`` with a ``title: ..., tags: ...`` hint added.

        ``metadata.json`` is looked up in the folder of ``args.image_path`` and
        cached per path. If the file is missing an error is logged and the
        prompt is returned unchanged. Only the parts that are actually present
        are included in the hint; when neither a title nor tags exist no hint
        is added at all.
        """
        prompt = args.prompt
        logging.debug(f" {self.key}: prompt before: {prompt}")
        image_path = args.image_path
        current_dir = os.path.dirname(image_path)
        metadata_json_path = os.path.join(current_dir, "metadata.json")

        if metadata_json_path not in self.cache:
            if not os.path.exists(metadata_json_path):
                logging.error(f"** {Fore.RED}Error: metadata.json not found in {current_dir}, skippin prompt modification{Style.RESET_ALL}")
                return prompt
            with open(metadata_json_path, "r", encoding="utf-8") as f:
                self.cache[metadata_json_path] = json.load(f)

        metadata = self.cache[metadata_json_path]
        hint_parts = []

        # ``or ""`` also covers an explicit JSON null title.
        title = str(metadata.get("title") or "").strip()
        if title:
            hint_parts.append(f"title: {title}")

        tags = metadata.get("tags") or []
        if isinstance(tags, str):  # can be csv or list
            tags = tags.split(",")
        tags = [str(tag).strip() for tag in tags]
        tags = [tag for tag in tags if tag]
        if tags:
            hint_parts.append("tags: " + ", ".join(tags))

        if hint_parts:
            # Joining the parts avoids a leading ", " when the title is empty.
            prompt = self._add_hint_to_prompt(", ".join(hint_parts), prompt)
        else:
            # Nothing to hint: just drop the placeholder if the prompt has one.
            prompt = prompt.replace(r"\{hint\}", "")

        logging.debug(f" {self.key}: prompt after: {prompt}")
        return prompt
|
|
|
|
class TitleAndTagsFromGlobalMetadataJson(PromptIdentityPlugin):
    """
    Adds title and tags hint from global metadata json given by '--metadatafilename'
    Note: you could just put your metadata in the prompt instead of using this plugin, but perhaps useful?
    """
    def __init__(self, args:Namespace=None):
        # Parsed metadata keyed by the metadata file path.
        self.cache = {}
        self.metadata_loaded = False
        super().__init__(key="title_and_tags_from_global_metadata_json",
                         description="Adds title and tags hint from global metadata json given by '--metadatafilename mydata/somefile.json'",
                         fn=self._title_and_tags_from_global_metadata_json,
                         args=args)

    def _title_and_tags_from_global_metadata_json(self, image_path, **kwargs) -> str:
        """Return the prompt with a ``title: ..., tags: ...`` hint from the global metadata file.

        Args:
            image_path: when the plugin is invoked through ``__call__`` this is
                the ``args`` namespace (``fn`` is always called as
                ``fn(args)``); ``prompt`` and the metadata file path are then
                read from it. A plain path with ``prompt=`` and
                ``metadata_json_path=`` keyword arguments is still accepted.
            **kwargs: ``prompt`` and ``metadata_json_path`` for the legacy
                calling convention.

        Raises:
            FileNotFoundError: if the metadata file does not exist.
        """
        if isinstance(image_path, Namespace):
            args = image_path
            prompt = getattr(args, "prompt", "")
            # '--metadatafilename' is the option named in the plugin
            # description; presumably it is exposed as args.metadatafilename
            # -- verify against the argument parser.
            metadata_json_path = (getattr(args, "metadatafilename", None)
                                  or getattr(args, "metadata_json_path", None)
                                  or "")
        else:
            prompt = kwargs.get("prompt", "")
            metadata_json_path = kwargs.get("metadata_json_path", "")

        # Cache per path rather than relying on a single "loaded" flag, which
        # failed with a KeyError as soon as a second path was used.
        if metadata_json_path not in self.cache:
            if not os.path.exists(metadata_json_path):
                raise FileNotFoundError(f"metadata.json not found in {metadata_json_path}")
            with open(metadata_json_path, "r", encoding="utf-8") as f:
                self.cache[metadata_json_path] = json.load(f)
            self.metadata_loaded = True

        metadata = self.cache[metadata_json_path]

        title = metadata.get("title", "")
        hint = f"title: {title}"

        tags = metadata.get("tags") or []
        if isinstance(tags, str):  # can be csv or list
            tags = [tag.strip() for tag in tags.split(",")]
        if len(tags) > 0:
            joined_tags = ", ".join(str(tag) for tag in tags)
            hint += f", tags: {joined_tags}"

        return self._add_hint_to_prompt(hint, prompt)
|
|
|
|
def is_subclass_of_subclass(attribute, base_class, recursion_depth=5):
    """Tell whether ``attribute`` derives from ``base_class`` within ``base_class``'s module.

    Returns True when ``attribute`` is a strict subclass of ``base_class`` and
    is defined in the same module as ``base_class``. Otherwise each of its
    direct bases is examined the same way, descending at most
    ``recursion_depth`` levels. Returns False when nothing qualifies.
    """
    defined_alongside_base = attribute.__module__ == base_class.__module__
    if defined_alongside_base and issubclass(attribute, base_class) and attribute is not base_class:
        return True

    # Depth budget exhausted: do not look at the bases any further.
    if recursion_depth == 0:
        return False

    remaining_depth = recursion_depth - 1
    return any(
        is_subclass_of_subclass(parent, base_class, remaining_depth)
        for parent in attribute.__bases__
    )
|
|
|
|
def get_prompt_alteration_plugin_list() -> list:
    """Collect the plugin classes defined in ``plugins.caption_plugins``.

    Scans the modules found under the ``plugins`` path, imports
    ``plugins.caption_plugins`` when present and returns every class defined
    in that module that passes ``is_subclass_of_subclass`` against
    ``PromptIdentityPlugin`` (the base class itself is excluded). The classes
    are returned uninstantiated.
    """
    wanted_module = "plugins.caption_plugins"
    found = []

    for _finder, short_name, _is_pkg in pkgutil.iter_modules(["plugins"]):
        if f"plugins.{short_name}" != wanted_module:
            continue

        module = importlib.import_module(wanted_module)
        for member_name in dir(module):
            member = getattr(module, member_name)

            # Only classes that are defined in the module itself count;
            # classes merely imported into it are skipped.
            if not isinstance(member, type):
                continue
            if member.__module__ != module.__name__:
                continue
            if not is_subclass_of_subclass(member, PromptIdentityPlugin, recursion_depth=5):
                continue
            if member is PromptIdentityPlugin:
                continue

            found.append(member)

    return found
|
|
|
|
def load_prompt_alteration_plugin(plugin_key: str, args) -> callable:
    """Return an instance of the plugin whose ``key`` equals ``plugin_key``.

    Args:
        plugin_key: key of the plugin to load; ``None`` selects the identity
            plugin.
        args: passed to the constructor of every plugin that is instantiated.

    Raises:
        ValueError: if ``plugin_key`` is given but no plugin has that key.
    """
    if plugin_key is None:
        logging.info("No plugin specified")
        return PromptIdentityPlugin(args=args)

    # Each candidate has to be instantiated because ``key`` is an instance
    # property.
    for candidate_cls in get_prompt_alteration_plugin_list():
        candidate = candidate_cls(args)
        if plugin_key == candidate.key:
            logging.info(f" **** Found plugin: {candidate.key}")
            return candidate

    raise ValueError(f"plugin_key: {plugin_key} not found in prompt_alteration_plugins")
|