revert multiline txt for now due to bug

Victor Hall 2023-02-28 21:14:19 -05:00
parent c446194599
commit 8abef6bc74
2 changed files with 42 additions and 29 deletions

@@ -26,6 +26,7 @@ import torch.nn.functional as F
 class EveryDreamBatch(Dataset):
     """
     data_loader: `DataLoaderMultiAspect` object
+    debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
     conditional_dropout: probability of dropping the caption for a given image
     crop_jitter: number of pixels to jitter the crop by, only for non-square images
     seed: random seed
@@ -37,6 +38,7 @@ class EveryDreamBatch(Dataset):
                  crop_jitter=20,
                  seed=555,
                  tokenizer=None,
+                 retain_contrast=False,
                  shuffle_tags=False,
                  rated_dataset=False,
                  rated_dataset_dropout_target=0.5,
@@ -50,6 +52,7 @@ class EveryDreamBatch(Dataset):
         self.unloaded_to_idx = 0
         self.tokenizer = tokenizer
         self.max_token_length = self.tokenizer.model_max_length
+        self.retain_contrast = retain_contrast
         self.shuffle_tags = shuffle_tags
         self.seed = seed
         self.rated_dataset = rated_dataset
@@ -78,19 +81,26 @@ class EveryDreamBatch(Dataset):
     def __getitem__(self, i):
         example = {}
-        train_item = self.__get_image_for_trainer(self.image_train_items[i])
+        train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level)
+        if self.retain_contrast:
+            std_dev = 1.0
+            mean = 0.0
+        else:
+            std_dev = 0.5
+            mean = 0.5
         image_transforms = transforms.Compose(
             [
                 transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
+                transforms.Normalize([mean], [std_dev]),
             ]
         )
         if self.shuffle_tags:
             example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
         else:
-            example["caption"] = train_item["caption"].get_caption(self.seed)
+            example["caption"] = train_item["caption"].get_caption()
         example["image"] = image_transforms(train_item["image"])
@@ -113,10 +123,11 @@ class EveryDreamBatch(Dataset):
         return example
-    def __get_image_for_trainer(self, image_train_item: ImageTrainItem):
+    def __get_image_for_trainer(self, image_train_item: ImageTrainItem, debug_level=0):
         example = {}
+        save = debug_level > 2
-        image_train_tmp = image_train_item.hydrate(crop=False, crop_jitter=self.crop_jitter)
+        image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)
         example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
         image_train_tmp.image = None # hack for now to avoid memory leak


@@ -37,7 +37,7 @@ class ImageCaption:
     Represents the various parts of an image caption
     """
-    def __init__(self, main_prompts: list[str], rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
+    def __init__(self, main_prompt: str, rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
         """
         :param main_prompt: The part of the caption which should always be included
         :param tags: list of tags to pick from to fill the caption
@@ -45,7 +45,7 @@ class ImageCaption:
         :param max_target_length: The desired maximum length of a generated caption
         :param use_weights: if true, weights are considered when shuffling tags
         """
-        self.__main_prompts = main_prompts
+        self.__main_prompt = main_prompt
         self.__rating = rating
         self.__tags = tags
         self.__tag_weights = tag_weights
@@ -66,25 +66,21 @@ class ImageCaption:
         :param seed used to initialize the randomizer
         :return: generated caption string
         """
-        rng = random.Random(seed)
-        main_prompt = rng.choice(self.__main_prompts)
         if self.__tags:
-            max_target_tag_length = self.__max_target_length - len(main_prompt)
+            max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
             if self.__use_weights:
                 tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
             else:
                 tags_caption = self.__get_shuffled_tags(seed, self.__tags)
-            return main_prompt + ", " + tags_caption
-        return main_prompt
+            return self.__main_prompt + ", " + tags_caption
+        return self.__main_prompt
-    def get_caption(self, seed) -> str:
-        rng = random.Random(seed)
-        main_prompt = rng.choice(self.__main_prompts)
-        if self.__tags:
-            return main_prompt + ", " + ", ".join(self.__tags)
-        return main_prompt
+    def get_caption(self) -> str:
+        if self.__tags:
+            return self.__main_prompt + ", " + ", ".join(self.__tags)
+        return self.__main_prompt
     @staticmethod
     def __get_weighted_shuffled_tags(seed: int, tags: list[str], weights: list[float], max_target_tag_length: int) -> str:
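
An aside, not part of the diff: the restored logic builds every caption as one fixed main prompt followed by tags, where get_shuffled_caption() shuffles the tags deterministically per seed and, on the weighted path, budgets the tag portion against max_target_length. Below is a rough standalone stand-in for that behaviour, with tag weights omitted for brevity; the function is illustrative, not the repo's API.

import random

def shuffled_caption(main_prompt: str, tags: list[str], max_target_length: int, seed: int) -> str:
    # main prompt always leads; shuffled tags fill the remaining character budget
    if not tags:
        return main_prompt
    budget = max_target_length - len(main_prompt)
    shuffled = list(tags)
    random.Random(seed).shuffle(shuffled)      # deterministic for a given seed
    picked, used = [], 0
    for tag in shuffled:
        cost = len(tag) + 2                    # ", " separator
        if used + cost > budget:
            break
        picked.append(tag)
        used += cost
    if not picked:
        return main_prompt
    return main_prompt + ", " + ", ".join(picked)

print(shuffled_caption("a photo of a dog", ["sitting", "outdoors", "studio lighting"], 2048, seed=555))
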
@@ -118,23 +114,19 @@ class ImageCaption:
         return ", ".join(tags)
     @staticmethod
-    def parse(lines: list[str]) -> 'ImageCaption':
+    def parse(string: str) -> 'ImageCaption':
         """
         Parses a string to get the caption.
         :param string: String to parse.
         :return: `ImageCaption` object.
         """
-        main_prompts = []
-        tags = []
-        for line in lines:
-            split_caption = list(map(str.strip, line.split(",")))
-            main_prompts.append(split_caption[0])
-            tags.extend(split_caption[1:])
+        split_caption = list(map(str.strip, string.split(",")))
+        main_prompt = split_caption[0]
+        tags = split_caption[1:]
         tag_weights = [1.0] * len(tags)
-        return ImageCaption(main_prompts, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
+        return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
     @staticmethod
     def from_file_name(file_path: str) -> 'ImageCaption':
@@ -160,7 +152,7 @@ class ImageCaption:
         """
         try:
             with open(file_path, encoding='utf-8', mode='r') as caption_file:
-                caption_text = [line.rstrip() for line in caption_file]
+                caption_text = caption_file.read()
                 return ImageCaption.parse(caption_text)
         except:
             logging.error(f" *** Error reading {file_path} to get caption")
@@ -282,7 +274,7 @@ class ImageTrainItem:
         self.error = None
         self.__compute_target_width_height()
-    def hydrate(self, crop=False, crop_jitter=20):
+    def hydrate(self, crop=False, save=False, crop_jitter=20):
         """
         crop: hard center crop to 512x512
         save: save the cropped image to disk, for manual inspection of resize/crop
@@ -344,8 +336,18 @@ class ImageTrainItem:
             exit()
         if type(self.image) is not np.ndarray:
+            if save:
+                base_name = os.path.basename(self.pathname)
+                if not os.path.exists("test/output"):
+                    os.makedirs("test/output")
+                self.image.save(f"test/output/{base_name}")
             self.image = np.array(self.image).astype(np.uint8)
             # self.image = (self.image / 127.5 - 1.0).astype(np.float32)
             # print(self.image.shape)
         return self
     def __compute_target_width_height(self):
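
Finally, a hedged usage sketch of the restored debug path: a debug_level above 2 makes __get_image_for_trainer() call hydrate(..., save=True), which writes each preprocessed image under test/output/ for inspection. The data_loader and tokenizer objects below are assumed to be built elsewhere by the trainer, and the keyword names mirror the class docstring above.

# Sketch only; data_loader and tokenizer come from the surrounding trainer setup.
train_batch = EveryDreamBatch(
    data_loader=data_loader,
    debug_level=3,     # > 2  ->  hydrate(save=True): images saved under test/output/
    tokenizer=tokenizer,
    seed=555,
)
example = train_batch[0]   # pulling an item triggers hydrate() and the debug save
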