Support more control over caption tag shuffling using yaml files

Jan Gerritsen 2023-01-07 17:29:09 +01:00
parent bf869db2e2
commit a3618409bc
3 changed files with 126 additions and 54 deletions

View File

@@ -18,7 +18,7 @@ import os
import logging
from PIL import Image
import random
from data.image_train_item import ImageTrainItem
from data.image_train_item import ImageTrainItem, ImageCaption
import data.aspects as aspects
from colorama import Fore, Style
import zipfile
@@ -76,17 +76,30 @@ class DataLoaderMultiAspect():
return self.image_caption_pairs
@staticmethod
def __read_caption_from_file(file_path, fallback_caption):
caption = fallback_caption
def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
try:
with open(file_path, encoding='utf-8', mode='r') as caption_file:
caption = caption_file.read()
caption_text = caption_file.read()
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
except:
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
caption = fallback_caption
pass
return caption
@staticmethod
def __split_caption_into_tags(caption_string: str) -> ImageCaption:
"""
Splits a string by "," into the main prompt and additional tags with equal weights
"""
split_caption = caption_string.split(",")
main_prompt = split_caption.pop(0).strip()
tags = []
for tag in split_caption:
tags.append(tag.strip())
return ImageCaption(main_prompt, tags, [1.0] * len(tags))
def __prescan_images(self, image_paths: list, flip_p=0.0):
"""
Create ImageTrainItem objects with metadata for hydration later
@@ -95,16 +108,15 @@ class DataLoaderMultiAspect():
for pathname in tqdm.tqdm(image_paths):
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
txt_file_path = os.path.splitext(pathname)[0] + ".txt"
caption_file_path = os.path.splitext(pathname)[0] + ".caption"
if os.path.exists(txt_file_path):
caption = self.__read_caption_from_file(txt_file_path, caption_from_filename)
caption = self.__read_caption_from_file(txt_file_path, caption)
elif os.path.exists(caption_file_path):
caption = self.__read_caption_from_file(caption_file_path, caption_from_filename)
else:
caption = caption_from_filename
caption = self.__read_caption_from_file(caption_file_path, caption)
try:
image = Image.open(pathname)
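For reference, the new __split_caption_into_tags helper treats everything before the first comma as the fixed main prompt and the remaining comma-separated pieces as equally weighted tags. A minimal standalone sketch of that behaviour, returning a plain tuple instead of the ImageCaption the real helper constructs (the example caption is invented for illustration):

def split_caption_into_tags(caption_string: str):
    # everything before the first comma is the main prompt, the rest are tags
    parts = caption_string.split(",")
    main_prompt = parts.pop(0).strip()
    tags = [tag.strip() for tag in parts]
    return main_prompt, tags, [1.0] * len(tags)  # every tag starts with weight 1.0

print(split_caption_into_tags("a portrait of alice, red dress, smiling"))
# ('a portrait of alice', ['red dress', 'smiling'], [1.0, 1.0])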

View File

@@ -134,16 +134,15 @@ class EveryDreamBatch(Dataset):
]
)
if self.shuffle_tags and "," in train_item['caption']:
tags = train_item["caption"].split(",")
random.Random(self.seed).shuffle(tags)
self.seed += 1
train_item["caption"] = ", ".join(tags)
if self.shuffle_tags:
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
else:
example["caption"] = train_item["caption"].get_caption()
example["image"] = image_transforms(train_item["image"])
if random.random() > self.conditional_dropout:
example["tokens"] = self.tokenizer(train_item["caption"],
example["tokens"] = self.tokenizer(example["caption"],
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
@@ -156,7 +155,7 @@ class EveryDreamBatch(Dataset):
).input_ids
example["tokens"] = torch.tensor(example["tokens"])
example["caption"] = train_item["caption"] # for sampling if needed
example["runt_size"] = train_item["runt_size"]
return example
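With this change the dataset no longer splits and shuffles the raw caption string itself; it asks the ImageCaption object for either a shuffled or a plain caption and tokenizes the result. A rough usage sketch, assuming the data.image_train_item module path from the first hunk and an invented example caption:

from data.image_train_item import ImageCaption

caption = ImageCaption("a portrait of alice", ["red dress", "smiling"], [1.0, 1.0])

shuffle_tags = True
seed = 555
if shuffle_tags:
    text = caption.get_shuffled_caption(seed)  # weighted, randomly ordered tag selection
else:
    text = caption.get_caption()               # main prompt followed by all tags in order
print(text)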

View File

@@ -13,25 +13,86 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import PIL
import numpy as np
from torchvision import transforms, utils
import random
import bisect
import logging
import math
import os
import logging
import random
import PIL
import numpy as np
from torchvision import transforms
_RANDOM_TRIM = 0.04
class ImageTrainItem():
class ImageCaption:
"""
Represents the various parts of an image caption
"""
def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float]):
"""
:param main_prompt: The part of the caption which should always be included
:param tags: list of tags to pick from to fill the caption
:param tag_weights: weights indicating which tags are more desirable and should be picked preferentially
"""
self.__main_prompt = main_prompt
self.__tags = tags
self.__tag_weights = tag_weights
if len(tags) > len(tag_weights):
self.__tag_weights.extend([1.0] * (len(tags) - len(tag_weights)))
def get_shuffled_caption(self, seed: int, target_length=150) -> str:
"""
returns the caption as a string with a random selection of the tags in random order
:param seed: seed used to initialize the randomizer
:param target_length: maximum desired length of the caption
:return: generated caption string
"""
target_tag_length = target_length - len(self.__main_prompt)
tags_caption = self.__get_tags_caption(seed, self.__tags, self.__tag_weights, target_tag_length)
return self.__main_prompt + tags_caption
def get_caption(self) -> str:
return ", ".join([self.__main_prompt] + self.__tags)
@staticmethod
def __get_tags_caption(seed: int, tags: list[str], weights: list[float], target_length: int) -> str:
caption = ""
picker = random.Random(seed)
tags_copy = tags.copy()
weights_copy = weights.copy()
while len(tags_copy) != 0 and len(caption) < target_length:
cum_weights = []
weight_sum = 0.0
for weight in weights_copy:
weight_sum += weight
cum_weights.append(weight_sum)
point = picker.uniform(0, weight_sum)
pos = bisect.bisect_left(cum_weights, point)
weights_copy.pop(pos)
tag = tags_copy.pop(pos)
caption += ", " + tag
return caption
class ImageTrainItem():
"""
image: PIL.Image
identifier: caption,
target_aspect: (width, height),
pathname: path to image file
flip_p: probability of flipping image (0.0 to 1.0)
"""
def __init__(self, image: PIL.Image, caption: str, target_wh: list, pathname: str, flip_p=0.0):
"""
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
self.caption = caption
self.target_wh = target_wh
self.pathname = pathname
@@ -50,50 +111,50 @@ class ImageTrainItem():
save: save the cropped image to disk, for manual inspection of resize/crop
crop_jitter: randomly shift crop by N pixels when using multiple aspect ratios to improve training quality
"""
#print(self.pathname, self.image)
# print(self.pathname, self.image)
try:
#if not hasattr(self, 'image'):
# if not hasattr(self, 'image'):
self.image = PIL.Image.open(self.pathname).convert('RGB')
width, height = self.image.size
if crop:
if crop:
cropped_img = self.__autocrop(self.image)
self.image = cropped_img.resize((512,512), resample=PIL.Image.BICUBIC)
self.image = cropped_img.resize((512, 512), resample=PIL.Image.BICUBIC)
else:
width, height = self.image.size
jitter_amount = random.randint(0,crop_jitter)
jitter_amount = random.randint(0, crop_jitter)
if self.target_wh[0] == self.target_wh[1]:
if width > height:
left = random.randint(0, width - height)
self.image = self.image.crop((left, 0, height+left, height))
self.image = self.image.crop((left, 0, height + left, height))
width = height
elif height > width:
top = random.randint(0, height - width)
self.image = self.image.crop((0, top, width, width+top))
self.image = self.image.crop((0, top, width, width + top))
height = width
elif width > self.target_wh[0]:
slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width-self.target_wh[0])
slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width - self.target_wh[0])
slicew_ratio = random.random()
left = int(slice*slicew_ratio)
right = width-int(slice*(1-slicew_ratio))
left = int(slice * slicew_ratio)
right = width - int(slice * (1 - slicew_ratio))
sliceh_ratio = random.random()
top = int(slice*sliceh_ratio)
bottom = height- int(slice*(1-sliceh_ratio))
top = int(slice * sliceh_ratio)
bottom = height - int(slice * (1 - sliceh_ratio))
self.image = self.image.crop((left, top, right, bottom))
else:
image_aspect = width / height
else:
image_aspect = width / height
target_aspect = self.target_wh[0] / self.target_wh[1]
if image_aspect > target_aspect:
new_width = int(height * target_aspect)
jitter_amount = max(min(jitter_amount, int(abs(width-new_width)/2)), 0)
jitter_amount = max(min(jitter_amount, int(abs(width - new_width) / 2)), 0)
left = jitter_amount
right = left + new_width
self.image = self.image.crop((left, 0, right, height))
else:
new_height = int(width / target_aspect)
jitter_amount = max(min(jitter_amount, int(abs(height-new_height)/2)), 0)
jitter_amount = max(min(jitter_amount, int(abs(height - new_height) / 2)), 0)
top = jitter_amount
bottom = top + new_height
self.image = self.image.crop((0, top, width, bottom))
@@ -106,17 +167,17 @@ class ImageTrainItem():
exit()
if type(self.image) is not np.ndarray:
if save:
if save:
base_name = os.path.basename(self.pathname)
if not os.path.exists("test/output"):
os.makedirs("test/output")
self.image.save(f"test/output/{base_name}")
self.image = np.array(self.image).astype(np.uint8)
#self.image = (self.image / 127.5 - 1.0).astype(np.float32)
#print(self.image.shape)
# self.image = (self.image / 127.5 - 1.0).astype(np.float32)
# print(self.image.shape)
return self
@@ -128,25 +189,25 @@ class ImageTrainItem():
x, y = image.size
if x != y:
if (x>y):
rand_x = x-y
sigma = max(rand_x*q,1)
if (x > y):
rand_x = x - y
sigma = max(rand_x * q, 1)
else:
rand_y = y-x
sigma = max(rand_y*q,1)
rand_y = y - x
sigma = max(rand_y * q, 1)
if (x>y):
if (x > y):
x_crop_gauss = abs(random.gauss(0, sigma))
x_crop = min(x_crop_gauss,(x-y)/2)
x_crop = min(x_crop_gauss, (x - y) / 2)
x_crop = math.trunc(x_crop)
y_crop = 0
else:
y_crop_gauss = abs(random.gauss(0, sigma))
x_crop = 0
y_crop = min(y_crop_gauss,(y-x)/2)
y_crop = min(y_crop_gauss, (y - x) / 2)
y_crop = math.trunc(y_crop)
min_xy = min(x, y)
image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
return image
return image
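The shuffled caption is built by weighted sampling without replacement: each draw picks one of the remaining tags with probability proportional to its weight and appends it, until the tags run out or the caption exceeds the target_length budget (150 characters by default, minus the main prompt). A small usage sketch of that behaviour, again assuming the data.image_train_item module path; the tags and weights are invented for illustration:

from data.image_train_item import ImageCaption

caption = ImageCaption(
    main_prompt="a portrait of alice",
    tags=["red dress", "smiling", "outdoors", "golden hour"],
    tag_weights=[2.0, 1.0, 1.0, 0.5],  # "red dress" is the most likely tag to be picked first
)

# the same seed always yields the same tag order; a new seed reshuffles
print(caption.get_shuffled_caption(seed=42))
print(caption.get_shuffled_caption(seed=43, target_length=60))  # tighter caption budget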