Add static methods to ImageCaption for deriving captions from various sources
This commit is contained in:
parent
e4ed5ff063
commit
a6cabe8d7d
|
@ -18,6 +18,8 @@ import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import typing
|
||||||
|
import yaml
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -25,6 +27,9 @@ from torchvision import transforms
|
||||||
|
|
||||||
_RANDOM_TRIM = 0.04
|
_RANDOM_TRIM = 0.04
|
||||||
|
|
||||||
|
DEFAULT_MAX_CAPTION_LENGTH = 2048
|
||||||
|
|
||||||
|
OptionalImageCaption = typing.Optional['ImageCaption']
|
||||||
|
|
||||||
class ImageCaption:
|
class ImageCaption:
|
||||||
"""
|
"""
|
||||||
|
@ -60,6 +65,7 @@ class ImageCaption:
|
||||||
:param seed used to initialize the randomizer
|
:param seed used to initialize the randomizer
|
||||||
:return: generated caption string
|
:return: generated caption string
|
||||||
"""
|
"""
|
||||||
|
if self.__tags:
|
||||||
max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
|
max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
|
||||||
|
|
||||||
if self.__use_weights:
|
if self.__use_weights:
|
||||||
|
@ -68,6 +74,7 @@ class ImageCaption:
|
||||||
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
|
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
|
||||||
|
|
||||||
return self.__main_prompt + ", " + tags_caption
|
return self.__main_prompt + ", " + tags_caption
|
||||||
|
return self.__main_prompt + ", " + tags_caption
|
||||||
|
|
||||||
def get_caption(self) -> str:
    """
    Returns the caption with the tags in their stored (unshuffled) order.

    The tags are appended to the main prompt, separated by ", ". When there
    are no tags only the main prompt is returned, so the caption never ends
    with a dangling ", " separator.

    :return: caption string
    """
    if self.__tags:
        return self.__main_prompt + ", " + ", ".join(self.__tags)

    # No tags: do not emit the separator on its own.
    return self.__main_prompt
|
||||||
|
@ -91,7 +98,10 @@ class ImageCaption:
|
||||||
|
|
||||||
weights_copy.pop(pos)
|
weights_copy.pop(pos)
|
||||||
tag = tags_copy.pop(pos)
|
tag = tags_copy.pop(pos)
|
||||||
caption += ", " + tag
|
|
||||||
|
if caption:
|
||||||
|
caption += ", "
|
||||||
|
caption += tag
|
||||||
|
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
|
@ -100,6 +110,136 @@ class ImageCaption:
|
||||||
random.Random(seed).shuffle(tags)
|
random.Random(seed).shuffle(tags)
|
||||||
return ", ".join(tags)
|
return ", ".join(tags)
|
||||||
|
|
||||||
|
@staticmethod
def parse(string: str) -> 'ImageCaption':
    """
    Parses a comma-separated string to get the caption.

    The text before the first comma becomes the main prompt and every
    following non-empty segment becomes a tag with a weight of 1.0.
    Surrounding whitespace is stripped from each segment.

    :param string: String to parse.
    :return: `ImageCaption` object.
    """
    # `str.split` always yields at least one element, so the unpacking
    # below cannot fail, even for an empty string.
    main_prompt, *raw_tags = (segment.strip() for segment in string.split(","))

    # Drop blank segments (e.g. from a trailing or doubled comma) so that no
    # empty tags end up in the caption.
    tags = [tag for tag in raw_tags if tag]
    tag_weights = [1.0] * len(tags)

    return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
|
||||||
|
|
||||||
|
@staticmethod
def from_file_name(file_path: str) -> 'ImageCaption':
    """
    Derives the caption from the name of a file.

    The directory part and the extension are discarded; the caption is the
    part of the remaining name that precedes the first underscore (or the
    whole name if it contains no underscore). The caption has no tags.

    :param file_path: Path to the image file.
    :return: `ImageCaption` object.
    """
    base_name = os.path.basename(file_path)
    stem = os.path.splitext(base_name)[0]
    # Everything before the first "_" is the caption text.
    main_prompt, _, _ = stem.partition("_")
    return ImageCaption(main_prompt, 1.0, [], [], DEFAULT_MAX_CAPTION_LENGTH, False)
|
||||||
|
|
||||||
|
@staticmethod
def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
    """
    Parses a text file to get the caption. Returns the default caption if
    the file does not exist or is invalid.

    The file is read as UTF-8 and its whole content is handed to
    `ImageCaption.parse`.

    :param file_path: Path to the text file.
    :param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
    :return: `ImageCaption` object or `None`.
    """
    try:
        with open(file_path, encoding='utf-8', mode='r') as caption_file:
            caption_text = caption_file.read()
        return ImageCaption.parse(caption_text)
    except Exception:
        # Best effort: any read/decode/parse failure falls back to the
        # default. `Exception` (instead of a bare `except`) lets
        # KeyboardInterrupt and SystemExit propagate.
        logging.error(f" *** Error reading {file_path} to get caption")
        return default_caption
|
||||||
|
|
||||||
|
@staticmethod
def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
    """
    Parses a yaml file to get the caption. Returns the default caption if
    the file does not exist or is invalid.

    Keys read from the top-level mapping:

    - `main_prompt` (default: empty string)
    - `rating` (default: 1.0)
    - `max_caption_length` (default: `DEFAULT_MAX_CAPTION_LENGTH`)
    - `tags`: a list of mappings, each with a `tag` and an optional
      `weight` (default: 1.0). Entries whose tag is empty after stripping
      whitespace are skipped.

    The final argument passed to the `ImageCaption` constructor is `True`
    only when at least two kept tags have different weights.

    :param file_path: path to the yaml file
    :param default_caption: caption to return if the file does not exist or is invalid
    :return: `ImageCaption` object or `None`.
    """
    try:
        # Decode explicitly as UTF-8 (like the text-file loader) instead of
        # relying on the platform's default encoding.
        with open(file_path, encoding='utf-8', mode='r') as stream:
            file_content = yaml.safe_load(stream)

        main_prompt = file_content.get("main_prompt", "")
        rating = file_content.get("rating", 1.0)
        unparsed_tags = file_content.get("tags", [])
        max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)

        tags = []
        tag_weights = []
        for unparsed_tag in unparsed_tags:
            tag = unparsed_tag.get("tag", "").strip()
            if not tag:
                continue

            tags.append(tag)
            tag_weights.append(unparsed_tag.get("weight", 1.0))

        # True as soon as any two neighbouring weights are unequal.
        weights_differ = any(
            previous != current
            for previous, current in zip(tag_weights, tag_weights[1:])
        )

        return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
    except Exception:
        # Best effort: a missing file, malformed yaml or unexpected structure
        # falls back to the default. `Exception` (instead of a bare `except`)
        # lets KeyboardInterrupt and SystemExit propagate.
        logging.error(f" *** Error reading {file_path} to get caption")
        return default_caption
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
|
||||||
|
"""
|
||||||
|
Try to resolve a caption from a file path or return `default_caption`.
|
||||||
|
|
||||||
|
:string: The path to the file to parse.
|
||||||
|
:default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
|
||||||
|
:return: `ImageCaption` object or `None`.
|
||||||
|
"""
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
(file_path_without_ext, ext) = os.path.splitext(file_path)
|
||||||
|
match ext:
|
||||||
|
case ".yaml" | ".yml":
|
||||||
|
return ImageCaption.from_yaml_file(file_path, default_caption)
|
||||||
|
|
||||||
|
case ".txt" | ".caption":
|
||||||
|
return ImageCaption.from_text_file(file_path, default_caption)
|
||||||
|
|
||||||
|
case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif':
|
||||||
|
for ext in [".yaml", ".yml", ".txt", ".caption"]:
|
||||||
|
file_path = file_path_without_ext + ext
|
||||||
|
image_caption = ImageCaption.from_file(file_path)
|
||||||
|
if image_caption is not None:
|
||||||
|
return image_caption
|
||||||
|
return ImageCaption.from_file_name(file_path)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return default_caption
|
||||||
|
else:
|
||||||
|
return default_caption
|
||||||
|
|
||||||
|
@staticmethod
def resolve(string: str) -> 'ImageCaption':
    """
    Try to resolve a caption from a string. If the string is a file path,
    the caption will be read from the file, otherwise the string will be
    parsed as a caption.

    :param string: The string to resolve.
    :return: `ImageCaption` object.
    """
    caption_from_file = ImageCaption.from_file(string, None)
    # Test for `None` explicitly so the fallback does not depend on the
    # truthiness of the returned object.
    if caption_from_file is not None:
        return caption_from_file

    return ImageCaption.parse(string)
|
||||||
|
|
||||||
|
|
||||||
class ImageTrainItem:
|
class ImageTrainItem:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue