From b1b73838f012c67a0c3d118ddbe6bb249339b627 Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Fri, 21 Jun 2024 15:27:19 -0400 Subject: [PATCH] add from_image_json caption plugin to add all metadata from image json to prompt --- plugins/caption_plugins.py | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/plugins/caption_plugins.py b/plugins/caption_plugins.py index b263346..1c90472 100644 --- a/plugins/caption_plugins.py +++ b/plugins/caption_plugins.py @@ -228,14 +228,40 @@ class TitleAndTagsFromImageJson(PromptIdentityBase): logging.debug(f" {self.key}: prompt after: {prompt}") return prompt -class VogueRunwayImageJson(PromptIdentityBase): +class AllMetadataFromImageJson(PromptIdentityBase): def __init__(self, args:Namespace=None): - super().__init__(key="vogue_runway_from_image_json", + super().__init__(key="from_image_json", description="Adds title and tags hint from metadata.json (in the samefolder as the image) to the prompt", fn=self._title_and_tags_from_metadata_json, args=args) - - def try_get_kvps(self, metadata, keys:list): + + def _title_and_tags_from_metadata_json(self, args:Namespace) -> str: + prompt = args.prompt + logging.debug(f" {self.key}: prompt before: {prompt}") + image_path = args.image_path + current_dir = os.path.dirname(image_path) + image_path_base = os.path.basename(image_path) + image_path_without_extension = os.path.splitext(image_path_base)[0] + candidate_json_path = os.path.join(current_dir, f"{image_path_without_extension}.json") + + if os.path.exists(candidate_json_path): + with open(candidate_json_path, "r") as f: + metadata = json.load(f) + + hint = json.dumps(metadata) + + prompt = self._add_hint_to_prompt(hint, prompt) + logging.debug(f" {self.key}: prompt after: {prompt}") + return prompt + +class VogueRunwayImageJson(PromptIdentityBase): + def __init__(self, args:Namespace=None): + super().__init__(key="vogue_runway_from_image_json", + description="Adds designer, season, category, year, and tags from [image_path_without_extension].json to the prompt", + fn=self._title_and_tags_from_metadata_json, + args=args) + + def _try_get_kvps(self, metadata, keys:list): values = [] for key in keys: val = metadata.get(key, "") @@ -261,10 +287,8 @@ class VogueRunwayImageJson(PromptIdentityBase): with open(candidate_json_path, "r") as f: metadata = json.load(f) - keys = ["designer","season","category","year"] - hint = "" - hint = self.try_get_kvps(metadata, keys) + hint = self._try_get_kvps(metadata, ["designer","season","category","year"]) tags = metadata.get("tags", []) tags = tags.split(",") if isinstance(tags, str) else tags # can be csv or list