Support more control over caption tag shuffling using YAML files
This commit is contained in:
parent
bf869db2e2
commit
a3618409bc
|
@ -18,7 +18,7 @@ import os
|
|||
import logging
|
||||
from PIL import Image
|
||||
import random
|
||||
from data.image_train_item import ImageTrainItem
|
||||
from data.image_train_item import ImageTrainItem, ImageCaption
|
||||
import data.aspects as aspects
|
||||
from colorama import Fore, Style
|
||||
import zipfile
|
||||
|
@ -76,17 +76,30 @@ class DataLoaderMultiAspect():
|
|||
return self.image_caption_pairs
|
||||
|
||||
@staticmethod
|
||||
def __read_caption_from_file(file_path, fallback_caption):
|
||||
caption = fallback_caption
|
||||
def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
|
||||
try:
|
||||
with open(file_path, encoding='utf-8', mode='r') as caption_file:
|
||||
caption = caption_file.read()
|
||||
caption_text = caption_file.read()
|
||||
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
|
||||
except:
|
||||
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
|
||||
caption = fallback_caption
|
||||
pass
|
||||
return caption
|
||||
|
||||
@staticmethod
def __split_caption_into_tags(caption_string: str) -> ImageCaption:
    """
    Break a comma separated caption into its main prompt (the text before
    the first ",") and the remaining tags; every tag gets a weight of 1.0
    """
    # str.split always yields at least one element, so the unpacking is safe
    first_part, *remaining_parts = caption_string.split(",")
    tags = [part.strip() for part in remaining_parts]

    return ImageCaption(first_part.strip(), tags, [1.0] * len(tags))
|
||||
|
||||
def __prescan_images(self, image_paths: list, flip_p=0.0):
|
||||
"""
|
||||
Create ImageTrainItem objects with metadata for hydration later
|
||||
|
@ -95,16 +108,15 @@ class DataLoaderMultiAspect():
|
|||
|
||||
for pathname in tqdm.tqdm(image_paths):
|
||||
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
|
||||
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
|
||||
|
||||
txt_file_path = os.path.splitext(pathname)[0] + ".txt"
|
||||
caption_file_path = os.path.splitext(pathname)[0] + ".caption"
|
||||
|
||||
if os.path.exists(txt_file_path):
|
||||
caption = self.__read_caption_from_file(txt_file_path, caption_from_filename)
|
||||
caption = self.__read_caption_from_file(txt_file_path, caption)
|
||||
elif os.path.exists(caption_file_path):
|
||||
caption = self.__read_caption_from_file(caption_file_path, caption_from_filename)
|
||||
else:
|
||||
caption = caption_from_filename
|
||||
caption = self.__read_caption_from_file(caption_file_path, caption)
|
||||
|
||||
try:
|
||||
image = Image.open(pathname)
|
||||
|
|
|
@ -134,16 +134,15 @@ class EveryDreamBatch(Dataset):
|
|||
]
|
||||
)
|
||||
|
||||
if self.shuffle_tags and "," in train_item['caption']:
|
||||
tags = train_item["caption"].split(",")
|
||||
random.Random(self.seed).shuffle(tags)
|
||||
self.seed += 1
|
||||
train_item["caption"] = ", ".join(tags)
|
||||
if self.shuffle_tags:
|
||||
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
|
||||
else:
|
||||
example["caption"] = train_item["caption"].get_caption()
|
||||
|
||||
example["image"] = image_transforms(train_item["image"])
|
||||
|
||||
if random.random() > self.conditional_dropout:
|
||||
example["tokens"] = self.tokenizer(train_item["caption"],
|
||||
example["tokens"] = self.tokenizer(example["caption"],
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
|
@ -156,7 +155,7 @@ class EveryDreamBatch(Dataset):
|
|||
).input_ids
|
||||
|
||||
example["tokens"] = torch.tensor(example["tokens"])
|
||||
example["caption"] = train_item["caption"] # for sampling if needed
|
||||
|
||||
example["runt_size"] = train_item["runt_size"]
|
||||
|
||||
return example
|
||||
|
|
|
@ -13,25 +13,86 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import PIL
|
||||
import numpy as np
|
||||
from torchvision import transforms, utils
|
||||
import random
|
||||
import bisect
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import random
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
|
||||
_RANDOM_TRIM = 0.04
|
||||
|
||||
class ImageTrainItem():
|
||||
|
||||
class ImageCaption:
    """
    Represents the various parts of an image caption: a main prompt that is
    always emitted, plus a list of weighted tags that can be appended either
    in their original order or as a weighted random selection.
    """

    def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float]):
        """
        :param main_prompt: The part of the caption which should always be included
        :param tags: list of tags to pick from to fill the caption
        :param tag_weights: weights to indicate which tags are more desired and should be picked preferably.
            Missing weights default to 1.0; weights without a matching tag are ignored.
        """
        self.__main_prompt = main_prompt
        # Copy both lists so later changes by the caller cannot affect this
        # caption, and so the caller's weight list is never mutated here.
        self.__tags = list(tags)
        weights = list(tag_weights)[:len(self.__tags)]
        weights.extend([1.0] * (len(self.__tags) - len(weights)))
        self.__tag_weights = weights

    def get_shuffled_caption(self, seed: int, target_length=150) -> str:
        """
        returns the caption as a string with a random selection of the tags in random order
        :param seed used to initialize the randomizer
        :param target_length: maximum desired length of the caption
        :return: generated caption string
        """
        target_tag_length = target_length - len(self.__main_prompt)
        tags_caption = self.__get_tags_caption(seed, self.__tags, self.__tag_weights, target_tag_length)

        return self.__main_prompt + tags_caption

    def get_caption(self) -> str:
        """
        returns the main prompt followed by all tags in their original order, separated by ", "
        :return: caption string; just the main prompt when there are no tags
        """
        if not self.__tags:
            return self.__main_prompt
        return self.__main_prompt + ", " + ", ".join(self.__tags)

    @staticmethod
    def __get_tags_caption(seed: int, tags: list[str], weights: list[float], target_length: int) -> str:
        """
        Draws tags without replacement, with probability proportional to their weight,
        until no tags are left or the accumulated text reaches target_length.
        The length is checked before each pick, so the result may exceed target_length
        by up to one tag.
        :return: the picked tags, each prefixed with ", " (empty string if nothing was picked)
        """
        picker = random.Random(seed)
        tags_copy = list(tags)
        weights_copy = list(weights)

        picked = []
        caption_length = 0

        while tags_copy and caption_length < target_length:
            # running totals of the remaining weights, used as selection buckets
            cum_weights = []
            weight_sum = 0.0
            for weight in weights_copy:
                weight_sum += weight
                cum_weights.append(weight_sum)

            point = picker.uniform(0, weight_sum)
            # clamp so a rounding artefact or mismatched list can never index past the tags
            pos = min(bisect.bisect_left(cum_weights, point), len(tags_copy) - 1)

            if pos < len(weights_copy):
                weights_copy.pop(pos)
            piece = ", " + tags_copy.pop(pos)
            picked.append(piece)
            caption_length += len(piece)

        return "".join(picked)
|
||||
|
||||
|
||||
class ImageTrainItem():
|
||||
"""
|
||||
image: PIL.Image
|
||||
identifier: caption,
|
||||
target_aspect: (width, height),
|
||||
pathname: path to image file
|
||||
flip_p: probability of flipping image (0.0 to 1.0)
|
||||
"""
|
||||
def __init__(self, image: PIL.Image, caption: str, target_wh: list, pathname: str, flip_p=0.0):
|
||||
"""
|
||||
|
||||
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
|
||||
self.caption = caption
|
||||
self.target_wh = target_wh
|
||||
self.pathname = pathname
|
||||
|
@ -50,50 +111,50 @@ class ImageTrainItem():
|
|||
save: save the cropped image to disk, for manual inspection of resize/crop
|
||||
crop_jitter: randomly shift the crop by N pixels when using multiple aspect ratios to improve training quality
|
||||
"""
|
||||
#print(self.pathname, self.image)
|
||||
# print(self.pathname, self.image)
|
||||
try:
|
||||
#if not hasattr(self, 'image'):
|
||||
# if not hasattr(self, 'image'):
|
||||
self.image = PIL.Image.open(self.pathname).convert('RGB')
|
||||
|
||||
width, height = self.image.size
|
||||
if crop:
|
||||
if crop:
|
||||
cropped_img = self.__autocrop(self.image)
|
||||
self.image = cropped_img.resize((512,512), resample=PIL.Image.BICUBIC)
|
||||
self.image = cropped_img.resize((512, 512), resample=PIL.Image.BICUBIC)
|
||||
else:
|
||||
width, height = self.image.size
|
||||
jitter_amount = random.randint(0,crop_jitter)
|
||||
jitter_amount = random.randint(0, crop_jitter)
|
||||
|
||||
if self.target_wh[0] == self.target_wh[1]:
|
||||
if width > height:
|
||||
left = random.randint(0, width - height)
|
||||
self.image = self.image.crop((left, 0, height+left, height))
|
||||
self.image = self.image.crop((left, 0, height + left, height))
|
||||
width = height
|
||||
elif height > width:
|
||||
top = random.randint(0, height - width)
|
||||
self.image = self.image.crop((0, top, width, width+top))
|
||||
self.image = self.image.crop((0, top, width, width + top))
|
||||
height = width
|
||||
elif width > self.target_wh[0]:
|
||||
slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width-self.target_wh[0])
|
||||
slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width - self.target_wh[0])
|
||||
slicew_ratio = random.random()
|
||||
left = int(slice*slicew_ratio)
|
||||
right = width-int(slice*(1-slicew_ratio))
|
||||
left = int(slice * slicew_ratio)
|
||||
right = width - int(slice * (1 - slicew_ratio))
|
||||
sliceh_ratio = random.random()
|
||||
top = int(slice*sliceh_ratio)
|
||||
bottom = height- int(slice*(1-sliceh_ratio))
|
||||
top = int(slice * sliceh_ratio)
|
||||
bottom = height - int(slice * (1 - sliceh_ratio))
|
||||
|
||||
self.image = self.image.crop((left, top, right, bottom))
|
||||
else:
|
||||
image_aspect = width / height
|
||||
else:
|
||||
image_aspect = width / height
|
||||
target_aspect = self.target_wh[0] / self.target_wh[1]
|
||||
if image_aspect > target_aspect:
|
||||
new_width = int(height * target_aspect)
|
||||
jitter_amount = max(min(jitter_amount, int(abs(width-new_width)/2)), 0)
|
||||
jitter_amount = max(min(jitter_amount, int(abs(width - new_width) / 2)), 0)
|
||||
left = jitter_amount
|
||||
right = left + new_width
|
||||
self.image = self.image.crop((left, 0, right, height))
|
||||
else:
|
||||
new_height = int(width / target_aspect)
|
||||
jitter_amount = max(min(jitter_amount, int(abs(height-new_height)/2)), 0)
|
||||
jitter_amount = max(min(jitter_amount, int(abs(height - new_height) / 2)), 0)
|
||||
top = jitter_amount
|
||||
bottom = top + new_height
|
||||
self.image = self.image.crop((0, top, width, bottom))
|
||||
|
@ -106,17 +167,17 @@ class ImageTrainItem():
|
|||
exit()
|
||||
|
||||
if type(self.image) is not np.ndarray:
|
||||
if save:
|
||||
if save:
|
||||
base_name = os.path.basename(self.pathname)
|
||||
if not os.path.exists("test/output"):
|
||||
os.makedirs("test/output")
|
||||
self.image.save(f"test/output/{base_name}")
|
||||
|
||||
|
||||
self.image = np.array(self.image).astype(np.uint8)
|
||||
|
||||
#self.image = (self.image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
#print(self.image.shape)
|
||||
# self.image = (self.image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
# print(self.image.shape)
|
||||
|
||||
return self
|
||||
|
||||
|
@ -128,25 +189,25 @@ class ImageTrainItem():
|
|||
x, y = image.size
|
||||
|
||||
if x != y:
|
||||
if (x>y):
|
||||
rand_x = x-y
|
||||
sigma = max(rand_x*q,1)
|
||||
if (x > y):
|
||||
rand_x = x - y
|
||||
sigma = max(rand_x * q, 1)
|
||||
else:
|
||||
rand_y = y-x
|
||||
sigma = max(rand_y*q,1)
|
||||
rand_y = y - x
|
||||
sigma = max(rand_y * q, 1)
|
||||
|
||||
if (x>y):
|
||||
if (x > y):
|
||||
x_crop_gauss = abs(random.gauss(0, sigma))
|
||||
x_crop = min(x_crop_gauss,(x-y)/2)
|
||||
x_crop = min(x_crop_gauss, (x - y) / 2)
|
||||
x_crop = math.trunc(x_crop)
|
||||
y_crop = 0
|
||||
else:
|
||||
y_crop_gauss = abs(random.gauss(0, sigma))
|
||||
x_crop = 0
|
||||
y_crop = min(y_crop_gauss,(y-x)/2)
|
||||
y_crop = min(y_crop_gauss, (y - x) / 2)
|
||||
y_crop = math.trunc(y_crop)
|
||||
|
||||
|
||||
min_xy = min(x, y)
|
||||
image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
|
||||
|
||||
return image
|
||||
return image
|
||||
|
|
Loading…
Reference in New Issue