merge conflict
This commit is contained in:
parent
be2cec7d3b
commit
d098223052
196
caption_cog.py
196
caption_cog.py
|
@ -18,24 +18,84 @@ import os
|
|||
import io
|
||||
import argparse
|
||||
import time
|
||||
from typing import Generator
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Literal
|
||||
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
from PIL import Image
|
||||
import PIL.ImageOps as ImageOps
|
||||
from pynvml import *
|
||||
|
||||
from transformers import AutoModelForCausalLM, LlamaTokenizer
|
||||
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer
|
||||
from colorama import Fore, Style
|
||||
|
||||
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
|
||||
from plugins.caption_plugins import load_prompt_alteration_plugin
|
||||
|
||||
def image_generator(image_dir) -> Generator[str, None, None]:
    """Yield the path of every supported image found under ``image_dir``.

    The tree rooted at ``image_dir`` is walked with ``os.walk``; a file is
    yielded when its name ends with one of the suffixes listed in
    ``SUPPORTED_EXT`` (the comparison is case-sensitive).
    """
    # str.endswith accepts a tuple, which replaces the per-suffix any() scan.
    suffixes = tuple(SUPPORTED_EXT)
    for folder, _subdirs, filenames in os.walk(image_dir):
        for filename in filenames:
            if filename.endswith(suffixes):
                yield os.path.join(folder, filename)
|
||||
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
|
||||
IMAGE_SIZE: int = 490
|
||||
PATCH_SIZE: int = 14
|
||||
|
||||
def build_conversation_input_ids(
    tokenizer: PreTrainedTokenizer,
    *,
    query: str,
    history: Optional[List[Tuple[str, str]]] = None,
    images: Optional[List[Image.Image]] = None,
    starts_with: Optional[str] = None,
):
    """Build the model inputs for a single "Question: ... Answer: " prompt.

    Based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py

    Args:
        tokenizer: Tokenizer providing ``bos_token_id``, ``pad_token_id`` and
            ``encode``.
        query: Text placed after ``Question:`` in the prompt.
        history: Accepted for signature compatibility; not used when building
            the prompt.
        images: ``None``, or a list holding at most one image. A single image
            is resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE``, converted to a
            tensor and normalized.
        starts_with: Optional text appended right after ``Answer: `` so that
            generation continues from it.

    Returns:
        A dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``
        as 1-D ``torch.long`` tensors, plus ``images`` (the transformed image
        list, or the value passed in when no single image was given).
    """
    assert images is None or len(images) <= 1, "not support multi images by now."

    prompt_text = f"Question: {query} Answer: "
    if starts_with is not None:
        prompt_text += starts_with

    ids = [tokenizer.bos_token_id]
    type_ids = [0]

    if images is not None and len(images) == 1:
        # vision
        preprocess = transforms.Compose(
            [
                transforms.Resize(
                    (IMAGE_SIZE, IMAGE_SIZE), interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        images = [preprocess(images[0])]

        # One placeholder per image patch plus two extra positions
        # (presumably begin/end-of-image markers; confirm against upstream).
        patches_per_side = IMAGE_SIZE // PATCH_SIZE
        vision_token_count = patches_per_side * patches_per_side + 2
        ids.extend([tokenizer.pad_token_id] * vision_token_count)
        type_ids.extend([1] * vision_token_count)  # 1 marks image positions

    encoded_text = tokenizer.encode(prompt_text, add_special_tokens=False)
    ids.extend(encoded_text)
    type_ids.extend([0] * len(encoded_text))  # 0 marks text positions

    mask = [1] * len(ids)

    return {
        'input_ids': torch.tensor(ids, dtype=torch.long),
        'token_type_ids': torch.tensor(type_ids, dtype=torch.long),
        'attention_mask': torch.tensor(mask, dtype=torch.long),
        'images': images,
    }
|
||||
|
||||
def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
    """Yield paths of supported images in ``image_dir``.

    Args:
        image_dir: Directory to scan.
        do_recurse: When true, walk the whole tree with ``os.walk``; when
            false, only look at the direct entries of ``image_dir``.

    Yields:
        ``os.path.join`` of the containing folder and each entry whose name
        ends with a suffix from ``SUPPORTED_EXT`` (case-sensitive).
    """
    suffixes = tuple(SUPPORTED_EXT)

    if not do_recurse:
        for entry in os.listdir(image_dir):
            if entry.endswith(suffixes):
                yield os.path.join(image_dir, entry)
        return

    for folder, _subdirs, filenames in os.walk(image_dir):
        for filename in filenames:
            if filename.endswith(suffixes):
                yield os.path.join(folder, filename)
|
||||
|
||||
def get_gpu_memory_map():
|
||||
nvmlInit()
|
||||
|
@ -44,13 +104,27 @@ def get_gpu_memory_map():
|
|||
nvmlShutdown()
|
||||
return info.used/1024/1024
|
||||
|
||||
def save_params(args, gen_kwargs):
    """Write the run parameters to ``caption_cog_params.txt`` in ``args.image_dir``.

    The file is created or overwritten and holds one JSON object, indented by
    four spaces, with two keys: ``"args"`` (``vars(args)``) and
    ``"gen_kwargs"`` (the dict passed in).
    """
    destination = os.path.join(args.image_dir, "caption_cog_params.txt")
    # Serialize before opening so a serialization error cannot leave a
    # truncated file behind.
    payload = json.dumps({"args": vars(args), "gen_kwargs": gen_kwargs}, indent=4)
    with open(destination, "w") as out_file:
        out_file.write(payload)
|
||||
|
||||
|
||||
def main(args):
|
||||
prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
'THUDM/cogvlm-chat-hf',
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor
|
||||
#revision=... # no one is actually doing this
|
||||
load_in_4bit=not args.disable_4bit,
|
||||
)
|
||||
|
||||
|
@ -61,8 +135,8 @@ def main(args):
|
|||
args.temp = args.temp or 1.0
|
||||
|
||||
args.append = args.append or ""
|
||||
if len(args.append) > 0 and not args.append.startswith(" "):
|
||||
args.append = " " + args.append
|
||||
if len(args.append) > 0:
|
||||
args.append = " " + args.append.strip()
|
||||
|
||||
gen_kwargs = {
|
||||
"max_length": args.max_length,
|
||||
|
@ -80,52 +154,61 @@ def main(args):
|
|||
}
|
||||
|
||||
if args.max_new_tokens is not None:
|
||||
print(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
|
||||
logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
|
||||
del gen_kwargs["max_length"]
|
||||
|
||||
if not do_sample:
|
||||
print(f"** Using greedy search instead sampling. Generated captions will be deterministic; meaning it will be the same even if you run this program multiple times.")
|
||||
logging.info(f"** Using greedy sampling")
|
||||
del gen_kwargs["top_k"]
|
||||
del gen_kwargs["top_p"]
|
||||
del gen_kwargs["temperature"]
|
||||
else:
|
||||
print(f"** Sampling enabled")
|
||||
logging.info(f"** Sampling enabled")
|
||||
|
||||
force_words_ids = None
|
||||
if args.force_words is not None:
|
||||
force_words = args.force_words.split(",") if args.force_words is not None else []
|
||||
print(f"** force_words: {Fore.LIGHTGREEN_EX}{force_words}{Style.RESET_ALL}")
|
||||
logging.info(f"** force_words: {Fore.LIGHTGREEN_EX}{force_words}{Style.RESET_ALL}")
|
||||
force_words_ids = tokenizer(force_words, add_special_tokens=False)["input_ids"] if force_words else []
|
||||
|
||||
bad_words_ids = None
|
||||
if args.bad_words is not None:
|
||||
bad_words = args.bad_words.split(",") if args.bad_words is not None else []
|
||||
print(f"** bad_words: {Fore.LIGHTGREEN_EX}{bad_words}{Style.RESET_ALL}")
|
||||
logging.info(f"** bad_words: {Fore.LIGHTGREEN_EX}{bad_words}{Style.RESET_ALL}")
|
||||
bad_words_ids = tokenizer(bad_words, add_special_tokens=False)["input_ids"] if bad_words else []
|
||||
|
||||
print(f"** gen_kwargs: \n{Fore.LIGHTGREEN_EX}{gen_kwargs}{Style.RESET_ALL}")
|
||||
logging.info(f"** gen_kwargs: \n{Fore.LIGHTGREEN_EX}{gen_kwargs}{Style.RESET_ALL}")
|
||||
|
||||
save_params(args, gen_kwargs)
|
||||
|
||||
total_start_time = time.time()
|
||||
i_processed = 0
|
||||
|
||||
for image_path in image_generator(args.image_dir):
|
||||
starts_with = args.starts_with.strip()
|
||||
|
||||
for i, image_path in enumerate(image_generator(args.image_dir, do_recurse=not args.no_recurse)):
|
||||
candidate_caption_path = image_path.replace(os.path.splitext(image_path)[-1], ".txt")
|
||||
|
||||
if args.no_overwrite and os.path.exists(candidate_caption_path):
|
||||
print(f"Skipping {image_path}, caption already exists.")
|
||||
logging.warning(f"Skipping {image_path}, caption already exists.")
|
||||
continue
|
||||
|
||||
start_time = time.time()
|
||||
cap_start_time = time.time()
|
||||
image = Image.open(image_path)
|
||||
|
||||
try:
|
||||
image = image.convert('RGB')
|
||||
image = ImageOps.exif_transpose(image)
|
||||
except Exception as e:
|
||||
print(f"Non-fatal error processing {image_path}: {e}")
|
||||
logging.warning(f"Non-fatal error processing {image_path}: {e}")
|
||||
continue
|
||||
|
||||
logging.debug(f" __ Prompt before plugin: {Fore.LIGHTGREEN_EX}{args.prompt}{Style.RESET_ALL}")
|
||||
prompt = prompt_plugin_fn(image_path, args=args)
|
||||
logging.debug(f" __ Modified prompt after plugin: {Fore.LIGHTGREEN_EX}{prompt}{Style.RESET_ALL}")
|
||||
|
||||
inputs = build_conversation_input_ids(tokenizer, query=prompt, history=[], images=[image], starts_with=args.starts_with) # chat mode
|
||||
|
||||
inputs = model.build_conversation_input_ids(tokenizer, query=args.prompt, history=[], images=[image]) # chat mode
|
||||
inputs = {
|
||||
'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
|
||||
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
|
||||
|
@ -134,27 +217,53 @@ def main(args):
|
|||
}
|
||||
|
||||
with torch.no_grad():
|
||||
#input_decoded = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
||||
#logging.debug(f"inputs decoded: {input_decoded}")
|
||||
outputs = model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
|
||||
outputs_without_prompt = outputs[:, inputs['input_ids'].shape[1]:]
|
||||
|
||||
len_inputs = inputs['input_ids'].shape[1]
|
||||
outputs_without_prompt = outputs[:, len_inputs:]
|
||||
|
||||
caption = tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
|
||||
if not args.remove_starts_with:
|
||||
# deal with caption starting with comma, etc
|
||||
if not re.match(r"^\W", caption):
|
||||
caption = starts_with + " " + caption
|
||||
else:
|
||||
caption = starts_with + caption
|
||||
|
||||
caption += args.append
|
||||
|
||||
with open(candidate_caption_path, "w", encoding="utf-8") as f:
|
||||
with open(candidate_caption_path, "w") as f:
|
||||
f.write(caption)
|
||||
vram_gb = get_gpu_memory_map()
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"VRAM: {Fore.LIGHTYELLOW_EX}{vram_gb:0.1f} GB{Style.RESET_ALL}, elapsed: {Fore.LIGHTYELLOW_EX}{elapsed_time:0.1f}{Style.RESET_ALL} sec, Captioned {Fore.LIGHTYELLOW_EX}{image_path}{Style.RESET_ALL}: ")
|
||||
print(f"{Fore.LIGHTCYAN_EX}{caption}{Style.RESET_ALL}")
|
||||
elapsed_time = time.time() - cap_start_time
|
||||
logging.info(f"n:{i:05}, VRAM: {Fore.LIGHTYELLOW_EX}{vram_gb:0.1f} GB{Style.RESET_ALL}, elapsed: {Fore.LIGHTYELLOW_EX}{elapsed_time:0.1f}{Style.RESET_ALL} sec, Captioned {Fore.LIGHTYELLOW_EX}{image_path}{Style.RESET_ALL}: ")
|
||||
logging.info(f"{Fore.LIGHTCYAN_EX}{caption}{Style.RESET_ALL}")
|
||||
i_processed += 1
|
||||
|
||||
if i_processed == 0:
|
||||
print(f"** No images found in {args.image_dir} with extension in {SUPPORTED_EXT} OR no images left to caption (did you use --no_overwrite?)")
|
||||
logging.info(f"** No images found in {args.image_dir} with extension in {SUPPORTED_EXT} OR no images left to caption (did you use --no_overwrite?)")
|
||||
exit(1)
|
||||
|
||||
total_elapsed_time = time.time() - total_start_time
|
||||
avg_time = total_elapsed_time / i_processed
|
||||
hh_mm_ss = time.strftime("%H:%M:%S", time.gmtime(total_elapsed_time))
|
||||
print(f"** Done captioning {args.image_dir} with prompt '{args.prompt}', total elapsed: {hh_mm_ss} (hh_mm_ss), avg: {avg_time:0.1f} sec/image")
|
||||
logging.info(f"** Done captioning {args.image_dir} with prompt '{prompt}', total elapsed: {hh_mm_ss} (hh_mm_ss), avg: {avg_time:0.1f} sec/image")
|
||||
|
||||
|
||||
def configure_logging(args: argparse.Namespace):
    """Set up logging to ``caption_cog.log`` and mirror messages to the console.

    Args:
        args: Parsed CLI arguments. ``args.debug`` selects DEBUG instead of
            INFO; ``args.append_log`` opens the log file in append mode
            instead of overwriting it.
    """
    verbosity = logging.DEBUG if args.debug else logging.INFO

    logging.basicConfig(
        filename="caption_cog.log",
        level=verbosity,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        filemode="a" if args.append_log else "w",
    )

    # Second handler: bare messages on the console at the same level.
    echo_handler = logging.StreamHandler()
    echo_handler.setLevel(verbosity)
    echo_handler.setFormatter(logging.Formatter('%(message)s'))
    logging.getLogger('').addHandler(echo_handler)
|
||||
|
||||
EXAMPLES = """ex.
|
||||
Basic example:
|
||||
|
@ -189,6 +298,7 @@ DESCRIPTION = f"** {Fore.LIGHTBLUE_EX}CogVLM captioning script{Style.RESET_ALL}
|
|||
|
||||
if __name__ == "__main__":
|
||||
argparser = argparse.ArgumentParser()
|
||||
argparser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
||||
argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.")
|
||||
argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling")
|
||||
argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search, default 1 (off)")
|
||||
|
@ -206,26 +316,34 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--force_words", type=str, default=None, help="Forces the model to include these words in the caption, use CSV format.")
|
||||
argparser.add_argument("--bad_words", type=str, default=None, help="Words that will not be allowed, use CSV format.")
|
||||
argparser.add_argument("--append", type=str, default=None, help="Extra string to append to all captions. ex. 'painted by John Doe'")
|
||||
argparser.add_argument("--no_recurse", action="store_true", help="Do not recurse into subdirectories.")
|
||||
argparser.add_argument("--prompt_plugin", type=str, default=None, help="Function name to modify prompt, edit code to add plugins.")
|
||||
argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
|
||||
argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
|
||||
argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
|
||||
args = argparser.parse_args()
|
||||
|
||||
configure_logging(args)
|
||||
|
||||
print(DESCRIPTION)
|
||||
print(EXAMPLES)
|
||||
|
||||
if args.top_k is not None or args.top_p is not None or args.temp is not None:
|
||||
print(f"** Sampling enabled.")
|
||||
args.sampling = True
|
||||
args.top_k = args.top_k or 50
|
||||
args.top_p = args.top_p or 1.0
|
||||
args.temp = args.temp or 1.0
|
||||
|
||||
if args.image_dir is None:
|
||||
print(f"** {Fore.RED}Error: image_dir is required.{Style.RESET_ALL}")
|
||||
logging.error(f"** {Fore.RED}Error: image_dir is required.{Style.RESET_ALL}")
|
||||
exit(1)
|
||||
|
||||
if not os.path.exists(args.image_dir):
|
||||
print(f"** {Fore.RED}Error: image_dir {args.image_dir} does not exist.{Style.RESET_ALL}")
|
||||
logging.error(f"** {Fore.RED}Error: image_dir {args.image_dir} does not exist.{Style.RESET_ALL}")
|
||||
exit(1)
|
||||
|
||||
print(f"** Running: {args.image_dir} with prompt '{args.prompt}'")
|
||||
startprint = f"** Running: {args.image_dir} with prompt '{args.prompt}"
|
||||
if args.starts_with is not None:
|
||||
startprint += f" {args.starts_with}'"
|
||||
else:
|
||||
startprint += "'"
|
||||
startprint += f" <caption>"
|
||||
if args.append is not None:
|
||||
startprint += f", and appending: {args.append}"
|
||||
logging.info(startprint)
|
||||
|
||||
main(args)
|
||||
|
|
|
@ -0,0 +1,307 @@
|
|||
from argparse import Namespace
|
||||
from typing import List
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
from colorama import Fore, Style
|
||||
import importlib, pkgutil
|
||||
|
||||
class TestBase():
    """Minimal sample base class that stores a single attribute ``a``."""

    def __init__(self):
        self.a = 1

    def __repr__(self) -> str:
        return "TestBase: {}".format(self.a)
|
||||
|
||||
class TestSub(TestBase):
    """Sample subclass of ``TestBase`` that adds a second attribute ``b``."""

    def __init__(self):
        super().__init__()
        self.b = 2

    def __repr__(self) -> str:
        return "TestSub: {}, {}".format(self.a, self.b)
|
||||
|
||||
class PromptIdentityPlugin():
    """
    Base class for prompt alteration plugins, useful for captioning, etc.

    An instance is callable: ``plugin(image_path, args=args)`` stores
    ``image_path`` on ``args`` and returns whatever ``fn(args)`` returns.
    Without a custom ``fn`` the prompt in ``args.prompt`` is returned as-is.
    """

    # Matches the hint placeholder in either spelling: the backslash-escaped
    # form ``\{hint\}`` that was historically required, and plain ``{hint}``.
    _HINT_PLACEHOLDER = re.compile(r"\\\{hint\\\}|\{hint\}")

    def __init__(self, description: str="identity", key: str="indentity_plugin", fn: callable=None, args: Namespace=None):
        """
        Args:
            description: Human readable summary shown by ``repr``.
            key: Identifier used to select the plugin by name.
            fn: Callable taking the args namespace and returning the prompt;
                defaults to returning ``args.prompt`` unchanged.
            args: Namespace kept on the instance as ``self.args``.
        """
        self.description = description
        if fn is None:
            fn = self._prompt_identity_from_args
        self.fn = fn
        self._key = key
        self.args = args

    @property
    def key(self) -> str:
        """Identifier of this plugin (read-only)."""
        return self._key

    def _prompt_identity_from_args(self, args: Namespace) -> str:
        """Return ``args.prompt`` unchanged.

        Raises:
            ValueError: if ``args`` has no ``prompt`` attribute.
        """
        if "prompt" not in args:
            raise ValueError("prompt is required for prompt_identity_from_args")
        return args.prompt

    def __repr__(self) -> str:
        return f"Plugin Function: \"{self.key}\" - {self.description}"

    def __str__(self) -> str:
        return self.__repr__()

    def __call__(self, image_path, args: Namespace) -> str:
        """Record ``image_path`` on ``args`` and return the plugin's prompt."""
        # Plugins read the current image from the namespace, so it is
        # attached here before delegating.
        args.image_path = image_path
        return self.fn(args)

    @staticmethod
    def _add_hint_to_prompt(hint: str, prompt: str) -> str:
        """Insert ``hint`` into ``prompt``.

        Every ``{hint}`` (or ``\\{hint\\}``) placeholder in ``prompt`` is
        replaced by ``hint``. When the prompt has no placeholder the hint is
        prepended as a ``Hint: ...`` line instead.
        """
        # A function replacement keeps backslashes in ``hint`` literal and
        # does the substitution in a single pass over the prompt.
        substituted, count = PromptIdentityPlugin._HINT_PLACEHOLDER.subn(lambda _match: hint, prompt)
        if count:
            return substituted
        return f"Hint: {hint}\n{prompt}"
|
||||
|
||||
class HintFromFilename(PromptIdentityPlugin):
    """Plugin that hints the prompt with the image's file name (no extension)."""

    def __init__(self, args:Namespace=None):
        super().__init__(key="hint_from_filename",
                         description="Add a hint to the prompt using the filename of the image (without extension)",
                         fn=self._from_filename,
                         args=args)

    def _from_filename(self, args: Namespace) -> str:
        """Return ``args.prompt`` with the image's base name added as a hint.

        ``args.image_path`` is set by ``__call__`` before this runs; a missing
        attribute is treated as an empty path.
        """
        # Namespace has no .get(); read attributes instead.
        image_path = getattr(args, "image_path", "")
        # Drop both the directory part and the extension so that only the
        # bare file name is used, as the plugin description promises.
        filename = os.path.splitext(os.path.basename(image_path))[0]
        return self._add_hint_to_prompt(filename, args.prompt)
|
||||
|
||||
class RemoveUsingCSV(PromptIdentityPlugin):
    """Plugin that deletes whole-word matches of ``args.csv`` entries from the prompt."""

    def __init__(self, args:Namespace=None):
        super().__init__(key="remove_using_csv",
                         description="Removes whole word matches of the csv passed in from the prompt",
                         fn=self._remove_using_csv,
                         args=args)

    def _filter_logic(self, prompt: str, filters: List[str]) -> str:
        """Remove every entry of ``filters`` from ``prompt`` and tidy the result.

        Matching is on word boundaries. Afterwards runs of whitespace are
        collapsed, whitespace before ``, . ! ? ;`` is removed and the result
        is stripped.
        """
        # An empty alternative would match at every word boundary; skip them.
        usable = [re.escape(word) for word in filters if word]
        result = prompt
        if usable:
            # word boundary filter
            pattern = r'\b(?:' + '|'.join(usable) + r')\b'
            result = re.sub(pattern, '', result)

        # fix up extra space and punctuation
        result = re.sub(r'\s{2,}', ' ', result)  # Remove extra spaces
        result = re.sub(r'\s([,.!?;])', r'\1', result)  # Fix punctuation and spaces

        return result.strip()

    def _remove_using_csv(self, args: Namespace) -> str:
        """Return ``args.prompt`` with each comma-separated word of ``args.csv`` removed.

        When ``args.csv`` is missing, ``None`` or empty an error is logged and
        the prompt is returned unchanged.
        """
        prompt = args.prompt
        csv = getattr(args, "csv", None)
        if not csv:
            logging.error(f"** {Fore.RED}Error: csv is required for remove_using_csv{Style.RESET_ALL}")
            return prompt

        # Strip each entry so "a, b" removes "b" rather than " b", which a
        # word-boundary pattern cannot match after punctuation.
        for word in (entry.strip() for entry in csv.split(",")):
            if word:
                prompt = self._filter_logic(prompt, [word])
        return prompt
|
||||
|
||||
class HintFromLeafDirectory(PromptIdentityPlugin):
    """Plugin that hints the prompt with the name of the folder holding the image."""

    def __init__(self, args:Namespace=None):
        super().__init__(
            key="from_leaf_directory",
            description="Adds a hint to the prompt using the leaf directory name (last folder in path)",
            fn=self._from_leaf_directory,
            args=args,
        )

    def _from_leaf_directory(self, args:Namespace) -> str:
        """Return ``args.prompt`` with the image's parent folder name as hint."""
        containing_dir = os.path.dirname(args.image_path)
        base_prompt = args.prompt
        return self._add_hint_to_prompt(os.path.basename(containing_dir), base_prompt)
|
||||
|
||||
class MetadataProvider():
    """ provides and caches metadata

    Parsed JSON files are kept in ``self._datadict`` keyed by file path, so
    each file is read at most once per provider instance.
    """
    def __init__(self):
        self._datadict = {}

    def _from_metadata(self, args) -> str:
        """Return the prompt prefixed with the metadata of the image's folder.

        ``args`` may be a dict or an attribute container such as
        ``argparse.Namespace``; ``image_path`` and ``prompt`` default to ``""``.
        The metadata is read from ``metadata.json`` located next to the image.
        """
        if isinstance(args, dict):
            image_path = args.get("image_path", "")
            prompt = args.get("prompt", "")
        else:
            image_path = getattr(args, "image_path", "")
            prompt = getattr(args, "prompt", "")

        # The image itself is not JSON; look up the sidecar file in its folder.
        metadata_path = os.path.join(os.path.dirname(image_path), "metadata.json")
        metadata = self._get_metadata_dict(metadata_path)
        return f"metadata: {metadata}\n{prompt}"

    def _get_metadata_dict(self, metadata_path: str) -> dict:
        """Return the parsed JSON at ``metadata_path``, loading it on first use.

        A missing file is logged as a warning and cached as an empty dict so
        the warning is emitted only once per path.
        """
        if metadata_path not in self._datadict:
            if not os.path.exists(metadata_path):
                metadata_dirname = os.path.dirname(metadata_path)
                logging.warning(f" metadata.json not found in {metadata_dirname}, ignoring{Style.RESET_ALL}")
                self._datadict[metadata_path] = {}
            else:
                with open(metadata_path, "r", encoding="utf-8") as f:
                    self._datadict[metadata_path] = json.load(f)

        return self._datadict[metadata_path]
|
||||
|
||||
class FromFolderMetadataJson(PromptIdentityPlugin):
    """Plugin that prefixes the prompt with the folder's ``metadata.json`` content."""

    def __init__(self, args:Namespace=None):
        super().__init__(key="from_folder_metadata",
                         description="Looks for metadata.json in the folder of the images",
                         fn=self._from_metadata_json,
                         args=args)
        self.metadata_provider = MetadataProvider()

    def _from_metadata_json(self, args:Namespace) -> str:
        """Return ``args.prompt`` prefixed with the metadata found next to the image.

        The metadata is looked up in ``metadata.json`` inside the directory of
        ``args.image_path``. When nothing is found the prompt is returned
        unchanged, so a missing file never wipes the prompt.
        """
        image_dir = os.path.dirname(args.image_path)
        metadata_json_path = os.path.join(image_dir, "metadata.json")
        metadata = self.metadata_provider._get_metadata_dict(metadata_json_path)

        prompt = args.prompt
        if not metadata:
            return prompt
        # Same "metadata: ..." prefix format MetadataProvider._from_metadata uses.
        return f"metadata: {metadata}\n{prompt}"
|
||||
|
||||
class TagsFromFolderMetadataJson(PromptIdentityPlugin):
    """Plugin that hints the prompt with the ``tags`` entry of the folder's ``metadata.json``."""

    def __init__(self, args:Namespace=None):
        self.cache = {}
        super().__init__(key = "tags_from_metadata_json",
                         description="Adds tags hint from metadata.json (in the samefolder as the image) to the prompt",
                         fn=self._tags_from_metadata_json,
                         args=args)
        self.metadata_provider = MetadataProvider()

    def _tags_from_metadata_json(self, args:Namespace) -> str:
        """Return ``args.prompt`` with a ``tags: ...`` hint, when tags exist.

        Tags are read from ``metadata.json`` in the directory of
        ``args.image_path`` and may be a list or a comma-separated string.
        Without tags the prompt is returned unchanged.
        """
        current_dir = os.path.dirname(args.image_path)
        metadata_json_path = os.path.join(current_dir, "metadata.json")
        tags = self.metadata_provider._get_metadata_dict(metadata_json_path).get("tags", [])
        # can be csv or list; joining a bare string would split it into characters
        if isinstance(tags, str):
            tags = tags.split(",")
        tags = [tag.strip() for tag in tags if tag.strip()]

        prompt = args.prompt
        if len(tags) > 0:
            tags = ", ".join(tags)
            return self._add_hint_to_prompt(f"tags: {tags}", prompt)
        return prompt
|
||||
|
||||
class TitleAndTagsFromFolderMetadataJson(PromptIdentityPlugin):
    """Plugin that hints the prompt with ``title`` and ``tags`` from the folder's ``metadata.json``."""

    def __init__(self, args:Namespace=None):
        self.cache = {}
        super().__init__(key="title_and_tags_from_metadata_json",
                         description="Adds title and tags hint from metadata.json (in the samefolder as the image) to the prompt",
                         fn=self._title_and_tags_from_metadata_json,
                         args=args)

    def _title_and_tags_from_metadata_json(self, args:Namespace) -> str:
        """Return ``args.prompt`` with a ``title: ..., tags: ...`` hint.

        The metadata file is ``metadata.json`` in the directory of
        ``args.image_path``; parsed files are cached in ``self.cache`` by
        path. When the file does not exist an error is logged and the prompt
        is returned unchanged.
        """
        prompt = args.prompt
        logging.debug(f" {self.key}: prompt before: {prompt}")
        image_path = args.image_path
        current_dir = os.path.dirname(image_path)
        metadata_json_path = os.path.join(current_dir, "metadata.json")

        if metadata_json_path not in self.cache:
            if not os.path.exists(metadata_json_path):
                logging.error(f"** {Fore.RED}Error: metadata.json not found in {current_dir}, skippin prompt modification{Style.RESET_ALL}")
                return prompt
            with open(metadata_json_path, "r", encoding="utf-8") as f:
                self.cache[metadata_json_path] = json.load(f)

        metadata = self.cache[metadata_json_path]

        # Collect the parts separately so an empty title does not leave a
        # dangling ", " in front of the tags.
        hint_parts = []
        title = metadata.get("title", "").strip()
        if len(title) > 0:
            hint_parts.append(f"title: {title}")

        tags = metadata.get("tags", [])
        tags = tags.split(",") if isinstance(tags, str) else tags  # can be csv or list
        tags = [tag.strip() for tag in tags if tag.strip()]
        if len(tags) > 0:
            hint_parts.append(f"tags: {', '.join(tags)}")

        prompt = self._add_hint_to_prompt(", ".join(hint_parts), prompt)
        logging.debug(f" {self.key}: prompt after: {prompt}")
        return prompt
|
||||
|
||||
class TitleAndTagsFromGlobalMetadataJson(PromptIdentityPlugin):
    """
    Adds title and tags hint from global metadata json given by '--metadatafilename'
    Note: you could just put your metadata in the prompt instead of using this plugin, but perhaps useful?
    """
    def __init__(self, args:Namespace=None):
        self.cache = {}
        self.metadata_loaded = False
        super().__init__(key="title_and_tags_from_global_metadata_json",
                         description="Adds title and tags hint from global metadata json given by '--metadatafilename mydata/somefile.json'",
                         fn=self._title_and_tags_from_global_metadata_json,
                         args=args)

    def _title_and_tags_from_global_metadata_json(self, image_path, **kwargs) -> str:
        """Return the prompt with a ``title: ..., tags: ...`` hint from one global JSON file.

        The plugin machinery invokes this as ``fn(args)``, so the first
        positional argument is normally the args ``Namespace``; in that case
        the prompt is read from ``args.prompt`` and the file path from
        ``args.metadatafilename`` (falling back to ``args.metadata_json_path``).
        The keyword form (``prompt=...``, ``metadata_json_path=...``) keeps
        working for direct callers.

        Raises:
            FileNotFoundError: if the metadata file does not exist.
        """
        if isinstance(image_path, Namespace):
            plugin_args = image_path
            prompt = getattr(plugin_args, "prompt", "")
            metadata_json_path = (
                getattr(plugin_args, "metadatafilename", None)
                or getattr(plugin_args, "metadata_json_path", None)
                or ""
            )
        else:
            prompt = kwargs.get("prompt", "")
            metadata_json_path = kwargs.get("metadata_json_path", "")

        # Keyed on the path rather than a single boolean, so a different path
        # on a later call is loaded instead of raising KeyError.
        if metadata_json_path not in self.cache:
            if not os.path.exists(metadata_json_path):
                raise FileNotFoundError(f"metadata.json not found in {metadata_json_path}")
            with open(metadata_json_path, "r", encoding="utf-8") as f:
                self.cache[metadata_json_path] = json.load(f)
            self.metadata_loaded = True

        metadata = self.cache[metadata_json_path]
        title = metadata.get("title", "")
        hint = f"title: {title}"

        tags = metadata.get("tags", [])
        tags = tags.split(",") if isinstance(tags, str) else tags  # can be csv or list
        if len(tags) > 0:
            tags = ", ".join(tags)
            hint += f", tags: {tags}"

        return self._add_hint_to_prompt(hint, prompt)
|
||||
|
||||
def is_subclass_of_subclass(attribute, base_class, recursion_depth=5):
    """Tell whether ``attribute`` or one of its ancestors is a proper subclass
    of ``base_class`` defined in ``base_class``'s own module.

    A class qualifies directly when it lives in the same module as
    ``base_class``, is a subclass of it and is not ``base_class`` itself.
    Otherwise its ``__bases__`` are examined the same way, following at most
    ``recursion_depth`` further levels.
    """
    if attribute.__module__ == base_class.__module__:
        if issubclass(attribute, base_class) and attribute is not base_class:
            return True

    if recursion_depth == 0:
        return False

    remaining_depth = recursion_depth - 1
    return any(
        is_subclass_of_subclass(parent, base_class, remaining_depth)
        for parent in attribute.__bases__
    )
|
||||
|
||||
def get_prompt_alteration_plugin_list() -> list:
    """Collect the plugin classes defined in ``plugins.caption_plugins``.

    Modules found in the ``plugins`` directory are enumerated, but only
    ``plugins.caption_plugins`` is imported. From it, every class that is
    defined in that module and passes ``is_subclass_of_subclass`` against
    ``PromptIdentityPlugin`` (the base class itself excluded) is returned.
    """
    discovered = []

    for _finder, short_name, _is_pkg in pkgutil.iter_modules(["plugins"]):
        qualified_name = f"plugins.{short_name}"
        if qualified_name != "plugins.caption_plugins":
            continue

        module = importlib.import_module(qualified_name)
        for member_name in dir(module):
            member = getattr(module, member_name)

            if not isinstance(member, type):
                continue
            if member.__module__ != module.__name__:
                continue  # imported into the module, not defined there
            if member is PromptIdentityPlugin:
                continue
            if is_subclass_of_subclass(member, PromptIdentityPlugin, recursion_depth=5):
                discovered.append(member)

    return discovered
|
||||
|
||||
def load_prompt_alteration_plugin(plugin_key: str, args) -> callable:
    """Return the plugin instance whose ``key`` equals ``plugin_key``.

    Args:
        plugin_key: Key of the wanted plugin, or ``None`` for the identity
            plugin that leaves the prompt untouched.
        args: Passed to each plugin constructor.

    Raises:
        ValueError: if ``plugin_key`` is given but no plugin has that key.
    """
    if plugin_key is None:
        logging.info(f"No plugin specified")
        return PromptIdentityPlugin(args=args)

    # Keys live on instances, so each candidate class is instantiated
    # before its key can be compared.
    for plugin_cls in get_prompt_alteration_plugin_list():
        candidate = plugin_cls(args)
        if plugin_key == candidate.key:
            logging.info(f" **** Found plugin: {candidate.key}")
            return candidate

    raise ValueError(f"plugin_key: {plugin_key} not found in prompt_alteration_plugins")
|
Loading…
Reference in New Issue