Add static methods to ImageCaption for deriving captions from various sources

This commit is contained in:
Joel Holdbrooks 2023-01-22 16:13:50 -08:00
parent e4ed5ff063
commit a6cabe8d7d
1 changed files with 147 additions and 7 deletions

View File

@ -18,6 +18,8 @@ import logging
import math
import os
import random
import typing
import yaml
import PIL
import numpy as np
@ -25,6 +27,9 @@ from torchvision import transforms
_RANDOM_TRIM = 0.04
DEFAULT_MAX_CAPTION_LENGTH = 2048
OptionalImageCaption = typing.Optional['ImageCaption']
class ImageCaption:
"""
@ -60,13 +65,15 @@ class ImageCaption:
:param seed used to initialize the randomizer
:return: generated caption string
"""
max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
if self.__tags:
max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
if self.__use_weights:
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
else:
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
if self.__use_weights:
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
else:
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
return self.__main_prompt + ", " + tags_caption
return self.__main_prompt + ", " + tags_caption
def get_caption(self) -> str:
@ -91,7 +98,10 @@ class ImageCaption:
weights_copy.pop(pos)
tag = tags_copy.pop(pos)
caption += ", " + tag
if caption:
caption += ", "
caption += tag
return caption
@ -100,6 +110,136 @@ class ImageCaption:
random.Random(seed).shuffle(tags)
return ", ".join(tags)
@staticmethod
def parse(string: str) -> 'ImageCaption':
"""
Parses a string to get the caption.
:param string: String to parse.
:return: `ImageCaption` object.
"""
split_caption = list(map(str.strip, string.split(",")))
main_prompt = split_caption[0]
tags = split_caption[1:]
tag_weights = [1.0] * len(tags)
return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
@staticmethod
def from_file_name(file_path: str) -> 'ImageCaption':
"""
Parses the file name to get the caption.
:param file_path: Path to the image file.
:return: `ImageCaption` object.
"""
(file_name, _) = os.path.splitext(os.path.basename(file_path))
caption = file_name.split("_")[0]
return ImageCaption(caption, 1.0, [], [], DEFAULT_MAX_CAPTION_LENGTH, False)
@staticmethod
def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Parses a text file to get the caption. Returns the default caption if
the file does not exist or is invalid.
:param file_path: Path to the text file.
:param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
:return: `ImageCaption` object or `None`.
"""
try:
with open(file_path, encoding='utf-8', mode='r') as caption_file:
caption_text = caption_file.read()
return ImageCaption.parse(caption_text)
except:
logging.error(f" *** Error reading {file_path} to get caption")
return default_caption
@staticmethod
def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Parses a yaml file to get the caption. Returns the default caption if
the file does not exist or is invalid.
:param file_path: path to the yaml file
:param default_caption: caption to return if the file does not exist or is invalid
:return: `ImageCaption` object or `None`.
"""
try:
with open(file_path, "r") as stream:
file_content = yaml.safe_load(stream)
main_prompt = file_content.get("main_prompt", "")
rating = file_content.get("rating", 1.0)
unparsed_tags = file_content.get("tags", [])
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
tags = []
tag_weights = []
last_weight = None
weights_differ = False
for unparsed_tag in unparsed_tags:
tag = unparsed_tag.get("tag", "").strip()
if len(tag) == 0:
continue
tags.append(tag)
tag_weight = unparsed_tag.get("weight", 1.0)
tag_weights.append(tag_weight)
if last_weight is not None and weights_differ is False:
weights_differ = last_weight != tag_weight
last_weight = tag_weight
return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
except:
logging.error(f" *** Error reading {file_path} to get caption")
return default_caption
@staticmethod
def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Try to resolve a caption from a file path or return `default_caption`.
:string: The path to the file to parse.
:default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
:return: `ImageCaption` object or `None`.
"""
if os.path.exists(file_path):
(file_path_without_ext, ext) = os.path.splitext(file_path)
match ext:
case ".yaml" | ".yml":
return ImageCaption.from_yaml_file(file_path, default_caption)
case ".txt" | ".caption":
return ImageCaption.from_text_file(file_path, default_caption)
case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif':
for ext in [".yaml", ".yml", ".txt", ".caption"]:
file_path = file_path_without_ext + ext
image_caption = ImageCaption.from_file(file_path)
if image_caption is not None:
return image_caption
return ImageCaption.from_file_name(file_path)
case _:
return default_caption
else:
return default_caption
@staticmethod
def resolve(string: str) -> 'ImageCaption':
"""
Try to resolve a caption from a string. If the string is a file path,
the caption will be read from the file, otherwise the string will be
parsed as a caption.
:string: The string to resolve.
:return: `ImageCaption` object.
"""
return ImageCaption.from_file(string, None) or ImageCaption.parse(string)
class ImageTrainItem:
"""