From d0982230529c71a29ed3b6d143a5710311953df0 Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Sat, 2 Mar 2024 01:20:03 -0500 Subject: [PATCH] merge conflict --- caption_cog.py | 196 ++++++++++++++++++----- plugins/__init__.py | 0 plugins/caption_plugins.py | 307 +++++++++++++++++++++++++++++++++++++ 3 files changed, 464 insertions(+), 39 deletions(-) create mode 100644 plugins/__init__.py create mode 100644 plugins/caption_plugins.py diff --git a/caption_cog.py b/caption_cog.py index a4755ce..cbe42e5 100644 --- a/caption_cog.py +++ b/caption_cog.py @@ -18,24 +18,84 @@ import os import io import argparse import time -from typing import Generator +import json +import logging +import re +from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Literal import torch +from torchvision import transforms from PIL import Image import PIL.ImageOps as ImageOps from pynvml import * -from transformers import AutoModelForCausalLM, LlamaTokenizer +from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer from colorama import Fore, Style -SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"] +from plugins.caption_plugins import load_prompt_alteration_plugin -def image_generator(image_dir) -> Generator[str, None, None]: - for root, dirs, files in os.walk(image_dir): - for file in files: - if any([file.endswith(ext) for ext in SUPPORTED_EXT]): - yield os.path.join(root, file) +SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"] +IMAGE_SIZE: int = 490 +PATCH_SIZE: int = 14 + +def build_conversation_input_ids( + tokenizer: PreTrainedTokenizer, + *, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + images: Optional[List[Image.Image]] = None, + starts_with: Optional[str] = None, + ): + # based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py + image_size: int = IMAGE_SIZE + patch_size: int = PATCH_SIZE + assert images is None or len(images) <= 1, f"not support multi images by now." + history = history or [] + + text = f"Question: {query} Answer: " + text += starts_with if starts_with is not None else "" + + input_ids = [tokenizer.bos_token_id] + token_type_ids = [0] + if images is not None and len(images) == 1: + # vision + transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] + ) + images = [transform(images[0])] + vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2 + input_ids += [tokenizer.pad_token_id] * vision_token_num + token_type_ids += [1] * vision_token_num + text_ids = tokenizer.encode(text, add_special_tokens=False) + + input_ids += text_ids + token_type_ids += [0] * len(text_ids) + attention_mask = [1] * len(input_ids) + + return { + 'input_ids': torch.tensor(input_ids, dtype=torch.long), + 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long), + 'attention_mask': torch.tensor(attention_mask, dtype=torch.long), + 'images': images, + } + +def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]: + if do_recurse: + for root, dirs, files in os.walk(image_dir): + for file in files: + if any(file.endswith(ext) for ext in SUPPORTED_EXT): + yield os.path.join(root, file) + else: + for file in os.listdir(image_dir): + if any(file.endswith(ext) for ext in SUPPORTED_EXT): + yield os.path.join(image_dir, file) def get_gpu_memory_map(): nvmlInit() @@ -44,13 +104,27 @@ def get_gpu_memory_map(): nvmlShutdown() return info.used/1024/1024 +def save_params(args, gen_kwargs): + save_path = os.path.join(args.image_dir, "caption_cog_params.txt") + args_dict = { + "args": vars(args), + "gen_kwargs": gen_kwargs, + } + pretty_print = json.dumps(args_dict, indent=4) + with open(save_path, "w") as f: + f.write(pretty_print) + + def main(args): + prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args) + tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5') model = AutoModelForCausalLM.from_pretrained( 'THUDM/cogvlm-chat-hf', torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, - trust_remote_code=True, + trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor + #revision=... # no one is actually doing this load_in_4bit=not args.disable_4bit, ) @@ -61,8 +135,8 @@ def main(args): args.temp = args.temp or 1.0 args.append = args.append or "" - if len(args.append) > 0 and not args.append.startswith(" "): - args.append = " " + args.append + if len(args.append) > 0: + args.append = " " + args.append.strip() gen_kwargs = { "max_length": args.max_length, @@ -80,52 +154,61 @@ def main(args): } if args.max_new_tokens is not None: - print(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length") + logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length") del gen_kwargs["max_length"] if not do_sample: - print(f"** Using greedy search instead sampling. Generated captions will be deterministic; meaning it will be the same even if you run this program multiple times.") + logging.info(f"** Using greedy sampling") del gen_kwargs["top_k"] del gen_kwargs["top_p"] del gen_kwargs["temperature"] else: - print(f"** Sampling enabled") + logging.info(f"** Sampling enabled") force_words_ids = None if args.force_words is not None: force_words = args.force_words.split(",") if args.force_words is not None else [] - print(f"** force_words: {Fore.LIGHTGREEN_EX}{force_words}{Style.RESET_ALL}") + logging.info(f"** force_words: {Fore.LIGHTGREEN_EX}{force_words}{Style.RESET_ALL}") force_words_ids = tokenizer(force_words, add_special_tokens=False)["input_ids"] if force_words else [] bad_words_ids = None if args.bad_words is not None: bad_words = args.bad_words.split(",") if args.bad_words is not None else [] - print(f"** bad_words: {Fore.LIGHTGREEN_EX}{bad_words}{Style.RESET_ALL}") + logging.info(f"** bad_words: {Fore.LIGHTGREEN_EX}{bad_words}{Style.RESET_ALL}") bad_words_ids = tokenizer(bad_words, add_special_tokens=False)["input_ids"] if bad_words else [] - print(f"** gen_kwargs: \n{Fore.LIGHTGREEN_EX}{gen_kwargs}{Style.RESET_ALL}") + logging.info(f"** gen_kwargs: \n{Fore.LIGHTGREEN_EX}{gen_kwargs}{Style.RESET_ALL}") + + save_params(args, gen_kwargs) total_start_time = time.time() i_processed = 0 - for image_path in image_generator(args.image_dir): + starts_with = args.starts_with.strip() + + for i, image_path in enumerate(image_generator(args.image_dir, do_recurse=not args.no_recurse)): candidate_caption_path = image_path.replace(os.path.splitext(image_path)[-1], ".txt") if args.no_overwrite and os.path.exists(candidate_caption_path): - print(f"Skipping {image_path}, caption already exists.") + logging.warning(f"Skipping {image_path}, caption already exists.") continue - start_time = time.time() + cap_start_time = time.time() image = Image.open(image_path) try: image = image.convert('RGB') image = ImageOps.exif_transpose(image) except Exception as e: - print(f"Non-fatal error processing {image_path}: {e}") + logging.warning(f"Non-fatal error processing {image_path}: {e}") continue + + logging.debug(f" __ Prompt before plugin: {Fore.LIGHTGREEN_EX}{args.prompt}{Style.RESET_ALL}") + prompt = prompt_plugin_fn(image_path, args=args) + logging.debug(f" __ Modified prompt after plugin: {Fore.LIGHTGREEN_EX}{prompt}{Style.RESET_ALL}") + + inputs = build_conversation_input_ids(tokenizer, query=prompt, history=[], images=[image], starts_with=args.starts_with) # chat mode - inputs = model.build_conversation_input_ids(tokenizer, query=args.prompt, history=[], images=[image]) # chat mode inputs = { 'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'), 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'), @@ -134,27 +217,53 @@ def main(args): } with torch.no_grad(): + #input_decoded = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True) + #logging.debug(f"inputs decoded: {input_decoded}") outputs = model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids) - outputs_without_prompt = outputs[:, inputs['input_ids'].shape[1]:] + + len_inputs = inputs['input_ids'].shape[1] + outputs_without_prompt = outputs[:, len_inputs:] + caption = tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True) + if not args.remove_starts_with: + # deal with caption starting with comma, etc + if not re.match(r"^\W", caption): + caption = starts_with + " " + caption + else: + caption = starts_with + caption + caption += args.append - with open(candidate_caption_path, "w", encoding="utf-8") as f: + with open(candidate_caption_path, "w") as f: f.write(caption) vram_gb = get_gpu_memory_map() - elapsed_time = time.time() - start_time - print(f"VRAM: {Fore.LIGHTYELLOW_EX}{vram_gb:0.1f} GB{Style.RESET_ALL}, elapsed: {Fore.LIGHTYELLOW_EX}{elapsed_time:0.1f}{Style.RESET_ALL} sec, Captioned {Fore.LIGHTYELLOW_EX}{image_path}{Style.RESET_ALL}: ") - print(f"{Fore.LIGHTCYAN_EX}{caption}{Style.RESET_ALL}") + elapsed_time = time.time() - cap_start_time + logging.info(f"n:{i:05}, VRAM: {Fore.LIGHTYELLOW_EX}{vram_gb:0.1f} GB{Style.RESET_ALL}, elapsed: {Fore.LIGHTYELLOW_EX}{elapsed_time:0.1f}{Style.RESET_ALL} sec, Captioned {Fore.LIGHTYELLOW_EX}{image_path}{Style.RESET_ALL}: ") + logging.info(f"{Fore.LIGHTCYAN_EX}{caption}{Style.RESET_ALL}") i_processed += 1 if i_processed == 0: - print(f"** No images found in {args.image_dir} with extension in {SUPPORTED_EXT} OR no images left to caption (did you use --no_overwrite?)") + logging.info(f"** No images found in {args.image_dir} with extension in {SUPPORTED_EXT} OR no images left to caption (did you use --no_overwrite?)") exit(1) total_elapsed_time = time.time() - total_start_time avg_time = total_elapsed_time / i_processed hh_mm_ss = time.strftime("%H:%M:%S", time.gmtime(total_elapsed_time)) - print(f"** Done captioning {args.image_dir} with prompt '{args.prompt}', total elapsed: {hh_mm_ss} (hh_mm_ss), avg: {avg_time:0.1f} sec/image") + logging.info(f"** Done captioning {args.image_dir} with prompt '{prompt}', total elapsed: {hh_mm_ss} (hh_mm_ss), avg: {avg_time:0.1f} sec/image") + + +def configure_logging(args: argparse.Namespace): + level = logging.INFO if not args.debug else logging.DEBUG + filemode = "a" if args.append_log else "w" + logging.basicConfig(filename="caption_cog.log", + level=level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + filemode=filemode) + + console = logging.StreamHandler() + console.setLevel(level) + console.setFormatter(logging.Formatter('%(message)s')) + logging.getLogger('').addHandler(console) EXAMPLES = """ex. Basic example: @@ -189,6 +298,7 @@ DESCRIPTION = f"** {Fore.LIGHTBLUE_EX}CogVLM captioning script{Style.RESET_ALL} if __name__ == "__main__": argparser = argparse.ArgumentParser() + argparser.add_argument("--debug", action="store_true", help="Enable debug logging") argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.") argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling") argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search, default 1 (off)") @@ -206,26 +316,34 @@ if __name__ == "__main__": argparser.add_argument("--force_words", type=str, default=None, help="Forces the model to include these words in the caption, use CSV format.") argparser.add_argument("--bad_words", type=str, default=None, help="Words that will not be allowed, use CSV format.") argparser.add_argument("--append", type=str, default=None, help="Extra string to append to all captions. ex. 'painted by John Doe'") + argparser.add_argument("--no_recurse", action="store_true", help="Do not recurse into subdirectories.") + argparser.add_argument("--prompt_plugin", type=str, default=None, help="Function name to modify prompt, edit code to add plugins.") + argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.") + argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.") + argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.") args = argparser.parse_args() + configure_logging(args) + print(DESCRIPTION) print(EXAMPLES) - if args.top_k is not None or args.top_p is not None or args.temp is not None: - print(f"** Sampling enabled.") - args.sampling = True - args.top_k = args.top_k or 50 - args.top_p = args.top_p or 1.0 - args.temp = args.temp or 1.0 - if args.image_dir is None: - print(f"** {Fore.RED}Error: image_dir is required.{Style.RESET_ALL}") + logging.error(f"** {Fore.RED}Error: image_dir is required.{Style.RESET_ALL}") exit(1) if not os.path.exists(args.image_dir): - print(f"** {Fore.RED}Error: image_dir {args.image_dir} does not exist.{Style.RESET_ALL}") + logging.error(f"** {Fore.RED}Error: image_dir {args.image_dir} does not exist.{Style.RESET_ALL}") exit(1) - print(f"** Running: {args.image_dir} with prompt '{args.prompt}'") + startprint = f"** Running: {args.image_dir} with prompt '{args.prompt}" + if args.starts_with is not None: + startprint += f" {args.starts_with}'" + else: + startprint += "'" + startprint += f" " + if args.append is not None: + startprint += f", and appending: {args.append}" + logging.info(startprint) main(args) diff --git a/plugins/__init__.py b/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/caption_plugins.py b/plugins/caption_plugins.py new file mode 100644 index 0000000..f1e0c86 --- /dev/null +++ b/plugins/caption_plugins.py @@ -0,0 +1,307 @@ +from argparse import Namespace +from typing import List +import os +import re +import json +import logging +from colorama import Fore, Style +import importlib, pkgutil + +class TestBase(): + def __init__(self): + self.a = 1 + + def __repr__(self) -> str: + return f"TestBase: {self.a}" + +class TestSub(TestBase): + def __init__(self): + super().__init__() + self.b = 2 + + def __repr__(self) -> str: + return f"TestSub: {self.a}, {self.b}" + +class PromptIdentityPlugin(): + """ + Base class for prompt alternation plugins, useful for captioning, etc. + """ + def __init__(self, description: str="identity", key: str="indentity_plugin", fn: callable=None, args: Namespace=None): + self.description = description + #print(f"PromptIdentityPlugin: __init__ with fn: {fn}") + if fn is None: + fn = self._prompt_identity_from_args + #print(f"{self.__class__}: fn is None, setting to self._prompt_identity_from_args") + self.fn = fn + self._key = key + self.args = args + #print(f"self._key: {self._key}") + + @property + def key(self) -> str: + return self._key + + def _prompt_identity_from_args(self, args: Namespace) -> str: + #print("Wat") + if "prompt" not in args: + raise ValueError(f"prompt is required for prompt_identity_from_args") + #print(f"prompt: {args.prompt}") + #print(f"{type(args)}, type(prompt): {type(args.prompt)}") + return args.prompt + + def __repr__(self) -> str: + return f"Plugin Function: \"{self.key}\" - {self.description}" + + def __str__(self) -> str: + return self.__repr__() + + def __call__(self, image_path, args: Namespace) -> str: + #print(f"Calling {self.key} with image_path: {image_path}, args: {args}") + args.image_path = image_path + return self.fn(args) + + @staticmethod + def _add_hint_to_prompt(hint: str, prompt: str) -> str: + if "\{hint\}" in prompt: + prompt = prompt.replace("\{hint\}", hint) + else: + prompt = f"Hint: {hint}\n{prompt}" + return prompt + +class HintFromFilename(PromptIdentityPlugin): + def __init__(self, args:Namespace=None): + super().__init__(key="hint_from_filename", + description="Add a hint to the prompt using the filename of the image (without extension)", + fn=self._from_filename, + args=args) + + def _from_filename(self, args: Namespace) -> str: + image_path = args.get("image_path", "") + filename = os.path.splitext(image_path)[0] + prompt = self._add_hint_to_prompt(filename, prompt) + return prompt + +class RemoveUsingCSV(PromptIdentityPlugin): + def __init__(self, args:Namespace=None): + super().__init__(key="remove_using_csv", + description="Removes whole word matches of the csv passed in from the prompt", + fn=self._remove_using_csv, + args=args) + + def _filter_logic(self, prompt: str, filters: List[str]) -> str: + # word boundary filter + pattern = r'\b(?:' + '|'.join([re.escape(word) for word in filters]) + r')\b' + + result = re.sub(pattern, '', prompt) + + # fix up extra space and punctuation + result = re.sub(r'\s{2,}', ' ', result) # Remove extra spaces + result = re.sub(r'\s([,.!?;])', r'\1', result) # Fix punctuation and spaces + + return result.strip() + + def _remove_using_csv(self, args: Namespace) -> str: + prompt = args.prompt + csv = args.csv + if len(csv) == 0: + logging.error(f"** {Fore.RED}Error: csv is required for remove_using_csv{Style.RESET_ALL}") + else: + words = csv.split(",") + for word in words: + prompt = self._filter_logic(prompt, [word]) + return prompt + +class HintFromLeafDirectory(PromptIdentityPlugin): + def __init__(self, args:Namespace=None): + super().__init__(key="from_leaf_directory", + description="Adds a hint to the prompt using the leaf directory name (last folder in path)", + fn=self._from_leaf_directory, + args=args) + + def _from_leaf_directory(self, args:Namespace) -> str: + image_path = args.image_path + prompt = args.prompt + leaf_folder_of_image = os.path.basename(os.path.dirname(image_path)) + return self._add_hint_to_prompt(leaf_folder_of_image, prompt) + +class MetadataProvider(): + """ provides and caches metadata""" + def __init__(self): + self._datadict = {} + + def _from_metadata(self, args) -> dict: + image_path = args.get("image_path", "") + prompt = args.get("prompt", "") + metadata = self._get_metadata_dict(image_path) + return f"metadata: {metadata}\n{prompt}" + + def _get_metadata_dict(self, metadata_path: str) -> dict: + if not self.loaded and not metadata_path in self.cache: + metadata_dirname = os.path.dirname(metadata_path) + if not os.path.exists(metadata_path): + logging.warning(f" metadata.json not found in {metadata_dirname}, ignoring{Style.RESET_ALL}") + self._datadict[metadata_path] = {} + with open(metadata_path, "r") as f: + metadata = json.load(f) + self._datadict[metadata_path] = metadata + + return self.dict[metadata_path] + +class FromFolderMetadataJson(PromptIdentityPlugin): + def __init__(self, args:Namespace=None): + super().__init__(key="from_folder_metadata", + description="Looks for metadata.json in the folder of the images", + fn=self._from_metadata_json, + args=args) + self.metadata_provider = MetadataProvider() + + def _from_metadata_json(self, args:Namespace) -> dict: + image_path = args.image_path + image_dir = os.path.dirname(image_path) + metadata_json_path = os.path.join(image_dir, "metadata.json") + self.metadata_provider._get_metadata_dict(metadata_json_path) + + return "" + +class TagsFromFolderMetadataJson(PromptIdentityPlugin): + def __init__(self, args:Namespace=None): + self.cache = {} + super().__init__(key = "tags_from_metadata_json", + description="Adds tags hint from metadata.json (in the samefolder as the image) to the prompt", + fn=self._tags_from_metadata_json, + args=args) + self.metadata_provider = MetadataProvider() + + def _tags_from_metadata_json(self, args:Namespace) -> str: + image_path = args.image_path + + current_dir = os.path.dirname(image_path) + metadata_json_path = os.path.join(current_dir, "metadata.json") + self.metadata_provider._get_metadata_dict(metadata_json_path).get("tags", []) + + prompt = args.prompt + if len(tags) > 0: + tags = ", ".join(tags) + return self._add_hint_to_prompt(f"tags: {tags}", prompt) + return prompt + +class TitleAndTagsFromFolderMetadataJson(PromptIdentityPlugin): + def __init__(self, args:Namespace=None): + self.cache = {} + super().__init__(key="title_and_tags_from_metadata_json", + description="Adds title and tags hint from metadata.json (in the samefolder as the image) to the prompt", + fn=self._title_and_tags_from_metadata_json, + args=args) + + def _title_and_tags_from_metadata_json(self, args:Namespace) -> str: + prompt = args.prompt + logging.debug(f" {self.key}: prompt before: {prompt}") + image_path = args.image_path + current_dir = os.path.dirname(image_path) + metadata_json_path = os.path.join(current_dir, "metadata.json") + + if metadata_json_path not in self.cache: + if not os.path.exists(metadata_json_path): + logging.error(f"** {Fore.RED}Error: metadata.json not found in {current_dir}, skippin prompt modification{Style.RESET_ALL}") + return prompt + with open(metadata_json_path, "r") as f: + metadata = json.load(f) + self.cache[metadata_json_path] = metadata + + title = self.cache[metadata_json_path].get("title", "").strip() + hint = f"title: {title}" if len(title) > 0 else "" + + tags = self.cache[metadata_json_path].get("tags", []) + tags = tags.split(",") if isinstance(tags, str) else tags # can be csv or list + if len(tags) > 0: + tags = ", ".join(tags) + hint += f", tags: {tags}" + + prompt = self._add_hint_to_prompt(hint, prompt) + logging.debug(f" {self.key}: prompt after: {prompt}") + return prompt + +class TitleAndTagsFromGlobalMetadataJson(PromptIdentityPlugin): + """ + Adds title and tags hint from global metadata json given by '--metadatafilename' + Note: you could just put your metadata in the prompt instead of using this plugin, but perhaps useful? + """ + def __init__(self, args:Namespace=None): + self.cache = {} + self.metadata_loaded = False + super().__init__(key="title_and_tags_from_global_metadata_json", + description="Adds title and tags hint from global metadata json given by '--metadatafilename mydata/somefile.json'", + fn=self._title_and_tags_from_global_metadata_json, + args=args) + + def _title_and_tags_from_global_metadata_json(self, image_path: str, **kwargs) -> str: + prompt = kwargs.get("prompt", "") + metadata_json_path = kwargs.get("metadata_json_path", "") + + if not self.metadata_loaded: # kinda sloppy but avoids me having to think about reworking init args + if not os.path.exists(metadata_json_path): + raise FileNotFoundError(f"metadata.json not found in {metadata_json_path}") + with open(metadata_json_path, "r") as f: + metadata = json.load(f) + self.cache[metadata_json_path] = metadata + self.metadata_loaded = True + + title = self.cache[metadata_json_path].get("title", "") + hint = f"title: {title}" + + tags = self.cache[metadata_json_path].get("tags", []) + if len(tags) > 0: + tags = ", ".join(tags) + hint += f", tags: {tags}" + + return self._add_hint_to_prompt(hint, prompt) + +def is_subclass_of_subclass(attribute, base_class, recursion_depth=5): + if attribute.__module__ == base_class.__module__: + if issubclass(attribute, base_class) and attribute is not base_class: + return True + + if recursion_depth == 0: + return False + recursion_depth -= 1 + for base in attribute.__bases__: + if is_subclass_of_subclass(base, base_class, recursion_depth): + return True + return False + +def get_prompt_alteration_plugin_list() -> list: + plugins = [] + + for finder, name, ispkg in pkgutil.iter_modules(["plugins"]): + plugins_module_name = f"plugins.{name}" + + if plugins_module_name == "plugins.caption_plugins": + module = importlib.import_module(plugins_module_name) + + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + + if isinstance(attribute, type) \ + and attribute.__module__ == module.__name__ \ + and is_subclass_of_subclass(attribute, PromptIdentityPlugin, recursion_depth=5) \ + and attribute is not PromptIdentityPlugin: + + plugins.append(attribute) + #print(f"done checking plugins_module_name: {plugins_module_name}") + return plugins + +def load_prompt_alteration_plugin(plugin_key: str, args) -> callable: + if plugin_key is not None: + prompt_alteration_plugins = get_prompt_alteration_plugin_list() + + for prompt_plugin_cls in prompt_alteration_plugins: + plugin_instance = prompt_plugin_cls(args) + #print(f"prompt_plugin_cls: {prompt_plugin_cls}") + #print(f"prompt_plugin_cls.key: {prompt_plugin_cls.key}") + if plugin_key == plugin_instance.key: + logging.info(f" **** Found plugin: {plugin_instance.key}") + return plugin_instance + raise ValueError(f"plugin_key: {plugin_key} not found in prompt_alteration_plugins") + else: + logging.info(f"No plugin specified") + return PromptIdentityPlugin(args=args)