revert multiline txt for now due to bug

parent c446194599
commit 8abef6bc74
@@ -26,6 +26,7 @@ import torch.nn.functional as F
 class EveryDreamBatch(Dataset):
     """
     data_loader: `DataLoaderMultiAspect` object
+    debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
     conditional_dropout: probability of dropping the caption for a given image
     crop_jitter: number of pixels to jitter the crop by, only for non-square images
     seed: random seed
@@ -37,6 +38,7 @@ class EveryDreamBatch(Dataset):
                  crop_jitter=20,
                  seed=555,
                  tokenizer=None,
+                 retain_contrast=False,
                  shuffle_tags=False,
                  rated_dataset=False,
                  rated_dataset_dropout_target=0.5,
@@ -50,6 +52,7 @@ class EveryDreamBatch(Dataset):
         self.unloaded_to_idx = 0
         self.tokenizer = tokenizer
         self.max_token_length = self.tokenizer.model_max_length
+        self.retain_contrast = retain_contrast
         self.shuffle_tags = shuffle_tags
         self.seed = seed
         self.rated_dataset = rated_dataset
@@ -78,19 +81,26 @@ class EveryDreamBatch(Dataset):
     def __getitem__(self, i):
         example = {}

-        train_item = self.__get_image_for_trainer(self.image_train_items[i])
+        train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level)

+        if self.retain_contrast:
+            std_dev = 1.0
+            mean = 0.0
+        else:
+            std_dev = 0.5
+            mean = 0.5
+
         image_transforms = transforms.Compose(
             [
                 transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
+                transforms.Normalize([mean], [std_dev]),
             ]
         )

         if self.shuffle_tags:
             example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
         else:
-            example["caption"] = train_item["caption"].get_caption(self.seed)
+            example["caption"] = train_item["caption"].get_caption()

         example["image"] = image_transforms(train_item["image"])
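Note on the Normalize change above: with the restored retain_contrast flag the transform becomes effectively a no-op (mean 0.0, std 1.0), while the default 0.5/0.5 pair rescales a ToTensor output from [0, 1] to [-1, 1]. A minimal standalone sketch of the two parameter choices (not part of the commit):

    import torch
    from torchvision import transforms

    img = torch.rand(3, 4, 4)  # stand-in for a ToTensor() result in [0, 1]

    default_norm = transforms.Normalize([0.5], [0.5])    # retain_contrast=False: maps [0, 1] -> [-1, 1]
    identity_norm = transforms.Normalize([0.0], [1.0])   # retain_contrast=True: leaves values unchanged

    print(default_norm(img).min().item(), default_norm(img).max().item())  # roughly -1 and 1
    print(torch.allclose(identity_norm(img), img))                         # True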
@@ -113,10 +123,11 @@ class EveryDreamBatch(Dataset):

         return example

-    def __get_image_for_trainer(self, image_train_item: ImageTrainItem):
+    def __get_image_for_trainer(self, image_train_item: ImageTrainItem, debug_level=0):
         example = {}
+        save = debug_level > 2

-        image_train_tmp = image_train_item.hydrate(crop=False, crop_jitter=self.crop_jitter)
+        image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)

         example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
         image_train_tmp.image = None # hack for now to avoid memory leak
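The restored debug plumbing is small but worth noting: __getitem__ now forwards self.debug_level, and __get_image_for_trainer derives a boolean save flag from it that is passed straight to ImageTrainItem.hydrate. A tiny hypothetical helper (not in the repo) showing the same check:

    def wants_debug_save(debug_level: int) -> bool:
        # Mirrors the restored "save = debug_level > 2": only the most verbose
        # level (3 = save crops to disk, per the EveryDreamBatch docstring) writes images.
        return debug_level > 2

    assert wants_debug_save(3) is True
    assert wants_debug_save(2) is False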
@@ -37,7 +37,7 @@ class ImageCaption:
     Represents the various parts of an image caption
     """

-    def __init__(self, main_prompts: list[str], rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
+    def __init__(self, main_prompt: str, rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
         """
         :param main_prompt: The part of the caption which should always be included
         :param tags: list of tags to pick from to fill the caption
@@ -45,7 +45,7 @@ class ImageCaption:
         :param max_target_length: The desired maximum length of a generated caption
         :param use_weights: if ture, weights are considered when shuffling tags
         """
-        self.__main_prompts = main_prompts
+        self.__main_prompt = main_prompt
         self.__rating = rating
         self.__tags = tags
         self.__tag_weights = tag_weights
@@ -66,25 +66,21 @@ class ImageCaption:
         :param seed used to initialize the randomizer
         :return: generated caption string
         """
-        rng = random.Random(seed)
-        main_prompt = rng.choice(self.__main_prompts)
         if self.__tags:
-            max_target_tag_length = self.__max_target_length - len(main_prompt)
+            max_target_tag_length = self.__max_target_length - len(self.__main_prompt)

             if self.__use_weights:
                 tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
             else:
                 tags_caption = self.__get_shuffled_tags(seed, self.__tags)

-            return main_prompt + ", " + tags_caption
-        return main_prompt
+            return self.__main_prompt + ", " + tags_caption
+        return self.__main_prompt

-    def get_caption(self, seed) -> str:
-        rng = random.Random(seed)
-        main_prompt = rng.choice(self.__main_prompts)
-        if self.__tags:
-            return main_prompt + ", " + ", ".join(self.__tags)
-        return main_prompt
+    def get_caption(self) -> str:
+        if self.__tags:
+            return self.__main_prompt + ", " + ", ".join(self.__tags)
+        return self.__main_prompt

     @staticmethod
     def __get_weighted_shuffled_tags(seed: int, tags: list[str], weights: list[float], max_target_tag_length: int) -> str:
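For context, the two reverted accessors differ only in tag ordering: get_caption() keeps tags in their original order, while get_shuffled_caption(seed) shuffles them (optionally weighted) before appending. A standalone rendition of the restored get_caption() logic, with the private attributes turned into plain arguments so it runs outside the repo:

    def get_caption(main_prompt: str, tags: list[str]) -> str:
        # Restored behaviour: the main prompt always comes first, tags follow
        # in file order, comma-separated.
        if tags:
            return main_prompt + ", " + ", ".join(tags)
        return main_prompt

    print(get_caption("a photo of a dog", ["sitting", "park", "golden hour"]))
    # -> a photo of a dog, sitting, park, golden hour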
@@ -118,23 +114,19 @@ class ImageCaption:
         return ", ".join(tags)

     @staticmethod
-    def parse(lines: list[str]) -> 'ImageCaption':
+    def parse(string: str) -> 'ImageCaption':
         """
         Parses a string to get the caption.

         :param string: String to parse.
         :return: `ImageCaption` object.
         """
-        main_prompts = []
-        tags = []
-        for line in lines:
-            split_caption = list(map(str.strip, line.split(",")))
-            main_prompts.append(split_caption[0])
-            tags.extend(split_caption[1:])
-
+        split_caption = list(map(str.strip, string.split(",")))
+        main_prompt = split_caption[0]
+        tags = split_caption[1:]
         tag_weights = [1.0] * len(tags)

-        return ImageCaption(main_prompts, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
+        return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)

     @staticmethod
     def from_file_name(file_path: str) -> 'ImageCaption':
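This parse() hunk is the core of the revert: instead of collecting one main prompt per line of a multi-line caption file, the entire string is split once on commas, the first chunk becomes the single main prompt and the rest become tags. A small standalone sketch of the restored splitting (a plain function standing in for the static method):

    def parse_caption(string: str) -> tuple[str, list[str]]:
        # Same split as the restored ImageCaption.parse(): strip whitespace around
        # each comma-separated chunk; first chunk is the main prompt, rest are tags.
        split_caption = list(map(str.strip, string.split(",")))
        return split_caption[0], split_caption[1:]

    main_prompt, tags = parse_caption("a photo of a dog, sitting, park,  golden hour")
    print(main_prompt)  # a photo of a dog
    print(tags)         # ['sitting', 'park', 'golden hour']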
@@ -160,7 +152,7 @@ class ImageCaption:
         """
         try:
             with open(file_path, encoding='utf-8', mode='r') as caption_file:
-                caption_text = [line.rstrip() for line in caption_file]
+                caption_text = caption_file.read()
             return ImageCaption.parse(caption_text)
         except:
             logging.error(f" *** Error reading {file_path} to get caption")
@@ -282,7 +274,7 @@ class ImageTrainItem:
         self.error = None
         self.__compute_target_width_height()

-    def hydrate(self, crop=False, crop_jitter=20):
+    def hydrate(self, crop=False, save=False, crop_jitter=20):
         """
         crop: hard center crop to 512x512
         save: save the cropped image to disk, for manual inspection of resize/crop
@@ -344,8 +336,18 @@ class ImageTrainItem:
                 exit()

         if type(self.image) is not np.ndarray:
+            if save:
+                base_name = os.path.basename(self.pathname)
+                if not os.path.exists("test/output"):
+                    os.makedirs("test/output")
+                self.image.save(f"test/output/{base_name}")
+
             self.image = np.array(self.image).astype(np.uint8)

+            # self.image = (self.image / 127.5 - 1.0).astype(np.float32)
+
+            # print(self.image.shape)
+
         return self

     def __compute_target_width_height(self):
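The restored save branch in hydrate() runs while self.image is still a PIL image (before the np.array conversion) and writes the resized/cropped result to test/output/ under its original file name for manual inspection. A minimal isolated sketch of that branch (PIL only; ImageTrainItem itself is not imported, and the sample path is purely illustrative):

    import os
    from PIL import Image

    def save_debug_crop(image: Image.Image, pathname: str) -> None:
        # Mirrors the restored branch: keep the source file's base name, ensure the
        # debug directory exists, then write the crop there for inspection.
        base_name = os.path.basename(pathname)
        os.makedirs("test/output", exist_ok=True)
        image.save(f"test/output/{base_name}")

    # Example (illustrative path): save_debug_crop(Image.new("RGB", (512, 512)), "train/dog_001.jpg")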