Add static methods to ImageCaption for deriving captions from various sources
This commit is contained in:
parent
e4ed5ff063
commit
a6cabe8d7d
|
@ -18,6 +18,8 @@ import logging
|
|||
import math
|
||||
import os
|
||||
import random
|
||||
import typing
|
||||
import yaml
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
@ -25,6 +27,9 @@ from torchvision import transforms
|
|||
|
||||
_RANDOM_TRIM = 0.04
|
||||
|
||||
DEFAULT_MAX_CAPTION_LENGTH = 2048
|
||||
|
||||
OptionalImageCaption = typing.Optional['ImageCaption']
|
||||
|
||||
class ImageCaption:
|
||||
"""
|
||||
|
@ -60,13 +65,15 @@ class ImageCaption:
|
|||
:param seed used to initialize the randomizer
|
||||
:return: generated caption string
|
||||
"""
|
||||
max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
|
||||
if self.__tags:
|
||||
max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
|
||||
|
||||
if self.__use_weights:
|
||||
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
|
||||
else:
|
||||
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
|
||||
if self.__use_weights:
|
||||
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
|
||||
else:
|
||||
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
|
||||
|
||||
return self.__main_prompt + ", " + tags_caption
|
||||
return self.__main_prompt + ", " + tags_caption
|
||||
|
||||
def get_caption(self) -> str:
|
||||
|
@ -91,7 +98,10 @@ class ImageCaption:
|
|||
|
||||
weights_copy.pop(pos)
|
||||
tag = tags_copy.pop(pos)
|
||||
caption += ", " + tag
|
||||
|
||||
if caption:
|
||||
caption += ", "
|
||||
caption += tag
|
||||
|
||||
return caption
|
||||
|
||||
|
@ -100,6 +110,136 @@ class ImageCaption:
|
|||
random.Random(seed).shuffle(tags)
|
||||
return ", ".join(tags)
|
||||
|
||||
@staticmethod
|
||||
def parse(string: str) -> 'ImageCaption':
|
||||
"""
|
||||
Parses a string to get the caption.
|
||||
|
||||
:param string: String to parse.
|
||||
:return: `ImageCaption` object.
|
||||
"""
|
||||
split_caption = list(map(str.strip, string.split(",")))
|
||||
main_prompt = split_caption[0]
|
||||
tags = split_caption[1:]
|
||||
tag_weights = [1.0] * len(tags)
|
||||
|
||||
return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
|
||||
|
||||
@staticmethod
|
||||
def from_file_name(file_path: str) -> 'ImageCaption':
|
||||
"""
|
||||
Parses the file name to get the caption.
|
||||
|
||||
:param file_path: Path to the image file.
|
||||
:return: `ImageCaption` object.
|
||||
"""
|
||||
(file_name, _) = os.path.splitext(os.path.basename(file_path))
|
||||
caption = file_name.split("_")[0]
|
||||
return ImageCaption(caption, 1.0, [], [], DEFAULT_MAX_CAPTION_LENGTH, False)
|
||||
|
||||
@staticmethod
|
||||
def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
|
||||
"""
|
||||
Parses a text file to get the caption. Returns the default caption if
|
||||
the file does not exist or is invalid.
|
||||
|
||||
:param file_path: Path to the text file.
|
||||
:param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
|
||||
:return: `ImageCaption` object or `None`.
|
||||
"""
|
||||
try:
|
||||
with open(file_path, encoding='utf-8', mode='r') as caption_file:
|
||||
caption_text = caption_file.read()
|
||||
return ImageCaption.parse(caption_text)
|
||||
except:
|
||||
logging.error(f" *** Error reading {file_path} to get caption")
|
||||
return default_caption
|
||||
|
||||
@staticmethod
|
||||
def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
|
||||
"""
|
||||
Parses a yaml file to get the caption. Returns the default caption if
|
||||
the file does not exist or is invalid.
|
||||
|
||||
:param file_path: path to the yaml file
|
||||
:param default_caption: caption to return if the file does not exist or is invalid
|
||||
:return: `ImageCaption` object or `None`.
|
||||
"""
|
||||
try:
|
||||
with open(file_path, "r") as stream:
|
||||
file_content = yaml.safe_load(stream)
|
||||
main_prompt = file_content.get("main_prompt", "")
|
||||
rating = file_content.get("rating", 1.0)
|
||||
unparsed_tags = file_content.get("tags", [])
|
||||
|
||||
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
|
||||
|
||||
tags = []
|
||||
tag_weights = []
|
||||
last_weight = None
|
||||
weights_differ = False
|
||||
for unparsed_tag in unparsed_tags:
|
||||
tag = unparsed_tag.get("tag", "").strip()
|
||||
if len(tag) == 0:
|
||||
continue
|
||||
|
||||
tags.append(tag)
|
||||
tag_weight = unparsed_tag.get("weight", 1.0)
|
||||
tag_weights.append(tag_weight)
|
||||
|
||||
if last_weight is not None and weights_differ is False:
|
||||
weights_differ = last_weight != tag_weight
|
||||
|
||||
last_weight = tag_weight
|
||||
|
||||
return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
|
||||
except:
|
||||
logging.error(f" *** Error reading {file_path} to get caption")
|
||||
return default_caption
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
|
||||
"""
|
||||
Try to resolve a caption from a file path or return `default_caption`.
|
||||
|
||||
:string: The path to the file to parse.
|
||||
:default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
|
||||
:return: `ImageCaption` object or `None`.
|
||||
"""
|
||||
if os.path.exists(file_path):
|
||||
(file_path_without_ext, ext) = os.path.splitext(file_path)
|
||||
match ext:
|
||||
case ".yaml" | ".yml":
|
||||
return ImageCaption.from_yaml_file(file_path, default_caption)
|
||||
|
||||
case ".txt" | ".caption":
|
||||
return ImageCaption.from_text_file(file_path, default_caption)
|
||||
|
||||
case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif':
|
||||
for ext in [".yaml", ".yml", ".txt", ".caption"]:
|
||||
file_path = file_path_without_ext + ext
|
||||
image_caption = ImageCaption.from_file(file_path)
|
||||
if image_caption is not None:
|
||||
return image_caption
|
||||
return ImageCaption.from_file_name(file_path)
|
||||
|
||||
case _:
|
||||
return default_caption
|
||||
else:
|
||||
return default_caption
|
||||
|
||||
@staticmethod
|
||||
def resolve(string: str) -> 'ImageCaption':
|
||||
"""
|
||||
Try to resolve a caption from a string. If the string is a file path,
|
||||
the caption will be read from the file, otherwise the string will be
|
||||
parsed as a caption.
|
||||
|
||||
:string: The string to resolve.
|
||||
:return: `ImageCaption` object.
|
||||
"""
|
||||
return ImageCaption.from_file(string, None) or ImageCaption.parse(string)
|
||||
|
||||
|
||||
class ImageTrainItem:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue