Early work on shuffle_tags.txt support; wrap image trimming in try/except

This commit is contained in:
Victor Hall 2023-06-01 16:18:21 -04:00
parent 7dcfa7acbf
commit 56deb26a59
3 changed files with 55 additions and 39 deletions

View File

@ -55,6 +55,7 @@ class ImageConfig:
multiply: float = None
cond_dropout: float = None
flip_p: float = None
shuffle_tags: bool = False
def merge(self, other):
if other is None:
@ -68,6 +69,7 @@ class ImageConfig:
multiply=overlay(other.multiply, self.multiply),
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
flip_p=overlay(other.flip_p, self.flip_p),
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
)
@classmethod
@ -81,6 +83,7 @@ class ImageConfig:
multiply=data.get("multiply"),
cond_dropout=data.get("cond_dropout"),
flip_p=data.get("flip_p"),
shuffle_tags=data.get("shuffle_tags"),
)
# Alternatively parse from dedicated `caption` attribute
@ -94,6 +97,9 @@ class ImageConfig:
acc = ImageConfig()
for cfg in configs:
acc = acc.merge(cfg)
acc.shuffle_tags = any(cfg.shuffle_tags for cfg in configs)
print(f"accum shuffle:{acc.shuffle_tags}")
return acc
def ensure_caption(self):
@ -151,6 +157,7 @@ class Dataset:
def __local_cfg(fileset):
cfgs = []
if 'multiply.txt' in fileset:
cfgs.append(ImageConfig(multiply=read_float(fileset['multiply.txt'])))
if 'cond_dropout.txt' in fileset:
@ -161,7 +168,12 @@ class Dataset:
cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
if 'local.yml' in fileset:
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
return ImageConfig.fold(cfgs)
result = ImageConfig.fold(cfgs)
if 'shuffle_tags.txt' in fileset:
result.shuffle_tags = True
return result
def __sidecar_cfg(imagepath, fileset):
cfgs = []
@ -195,6 +207,8 @@ class Dataset:
return global_cfg
walk_and_visit(data_root, process_dir, ImageConfig())
for img in image_configs:
print(f" *** {img}: {image_configs[img]}")
return Dataset(image_configs)
@classmethod
@ -217,6 +231,7 @@ class Dataset:
items = []
for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
config = self.image_configs[image]
print(f" ********* shuffle: {config.shuffle_tags}")
if len(config.main_prompts) > 1:
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
@ -247,7 +262,8 @@ class Dataset:
pathname=os.path.abspath(image),
flip_p=config.flip_p or 0.0,
multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout
cond_dropout=config.cond_dropout,
shuffle_tags=config.shuffle_tags,
)
items.append(item)
except Exception as e:

View File

@ -40,7 +40,6 @@ class EveryDreamBatch(Dataset):
crop_jitter=0.02,
seed=555,
tokenizer=None,
retain_contrast=False,
shuffle_tags=False,
rated_dataset=False,
rated_dataset_dropout_target=0.5,
@ -54,7 +53,6 @@ class EveryDreamBatch(Dataset):
self.unloaded_to_idx = 0
self.tokenizer = tokenizer
self.max_token_length = self.tokenizer.model_max_length
self.retain_contrast = retain_contrast
self.shuffle_tags = shuffle_tags
self.seed = seed
self.rated_dataset = rated_dataset
@ -85,12 +83,8 @@ class EveryDreamBatch(Dataset):
train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level)
if self.retain_contrast:
std_dev = 1.0
mean = 0.0
else:
std_dev = 0.5
mean = 0.5
std_dev = 0.5
mean = 0.5
image_transforms = transforms.Compose(
[
@ -99,7 +93,7 @@ class EveryDreamBatch(Dataset):
]
)
if self.shuffle_tags:
if self.shuffle_tags or train_item["shuffle_tags"]:
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
else:
example["caption"] = train_item["caption"].get_caption()
@ -137,6 +131,7 @@ class EveryDreamBatch(Dataset):
if image_train_tmp.cond_dropout is not None:
example["cond_dropout"] = image_train_tmp.cond_dropout
example["runt_size"] = image_train_tmp.runt_size
example["shuffle_tags"] = image_train_tmp.shuffle_tags
return example

View File

@ -124,7 +124,16 @@ class ImageTrainItem:
flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images.
"""
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0, cond_dropout=None):
def __init__(self,
image: PIL.Image,
caption: ImageCaption,
aspects: list[float],
pathname: str,
flip_p=0.0,
multiplier: float=1.0,
cond_dropout=None,
shuffle_tags=False,
):
self.caption = caption
self.aspects = aspects
self.pathname = pathname
@ -133,6 +142,7 @@ class ImageTrainItem:
self.runt_size = 0
self.multiplier = multiplier
self.cond_dropout = cond_dropout
self.shuffle_tags = shuffle_tags
self.image_size = None
if image is None or len(image) == 0:
@ -197,14 +207,12 @@ class ImageTrainItem:
top_crop_pixels = random.uniform(0, max_crop_pixels)
bottom_crop_pixels = random.uniform(0, max_crop_pixels)
# Calculate the cropping coordinates
left = left_crop_pixels
right = width - right_crop_pixels
top = top_crop_pixels
bottom = height - bottom_crop_pixels
#print(f"\n *** jitter l: {left}, t: {top}, r: {right}, b: {bottom}, orig w: {width}, h: {height}, max_crop_pixels: {max_crop_pixels}")
# Crop the image
cropped_image = image.crop((left, top, right, bottom))
cropped_width = width - int(left_crop_pixels + right_crop_pixels)
@ -212,7 +220,6 @@ class ImageTrainItem:
cropped_aspect_ratio = cropped_width / cropped_height
# Resize the cropped image to maintain square pixels
if cropped_aspect_ratio > 1:
new_width = cropped_width
new_height = int(cropped_width / cropped_aspect_ratio)
@ -220,7 +227,6 @@ class ImageTrainItem:
new_width = int(cropped_height * cropped_aspect_ratio)
new_height = cropped_height
#print(f" *** postsquarefix new w: {new_width}, h: {new_height}")
cropped_image = cropped_image.resize((new_width, new_height))
return cropped_image
@ -241,34 +247,33 @@ class ImageTrainItem:
pass
def _trim_to_aspect(self, image, target_wh):
width, height = image.size
target_aspect = target_wh[0] / target_wh[1] # 0.60
image_aspect = width / height # 0.5865
#self._debug_save_image(image, "precrop")
if image_aspect > target_aspect:
target_width = int(height * target_aspect)
overwidth = width - target_width
l = random.normalvariate(overwidth/2, overwidth/2)
l = max(0, l)
l = min(l, overwidth)
r = width - int(overwidth) - l
image = image.crop((l, 0, r, height))
elif target_aspect > image_aspect:
target_height = int(width / target_aspect)
overheight = height - target_height
image = image.crop((0, int(overheight/2), width, height-int(overheight/2)))
try:
width, height = image.size
target_aspect = target_wh[0] / target_wh[1] # 0.60
image_aspect = width / height # 0.5865
#self._debug_save_image(image, "precrop")
if image_aspect > target_aspect:
target_width = int(height * target_aspect)
overwidth = width - target_width
l = random.normalvariate(overwidth/2, overwidth/2)
l = max(0, l)
l = min(l, overwidth)
r = width - int(overwidth) - l
image = image.crop((l, 0, r, height))
elif target_aspect > image_aspect:
target_height = int(width / target_aspect)
overheight = height - target_height
image = image.crop((0, int(overheight/2), width, height-int(overheight/2)))
except Exception as e:
print(f"error trimming image {self.pathname}: {e}")
pass
def hydrate(self, save=False, crop_jitter=0.02):
"""
save: save the cropped image to disk, for manual inspection of resize/crop
"""
# print(self.pathname, self.image)
# try:
# if not hasattr(self, 'image'):
image = self.load_image()
#print(f"** jittering: {self.pathname}")
width, height = image.size
img_jitter = min((width-self.target_wh[0])/self.target_wh[0], (height-self.target_wh[1])/self.target_wh[1])
@ -283,6 +288,7 @@ class ImageTrainItem:
self.image = image.resize(self.target_wh)
self.image = self.flip(self.image)
# Remove comment here to view image cropping outputs
# self._debug_save_image(self.image, "final")
self.image = np.array(self.image).astype(np.uint8)
@ -293,8 +299,7 @@ class ImageTrainItem:
self.target_wh = None
try:
with PIL.Image.open(self.pathname) as image:
needs_transpose = self._needs_transpose(image)
if needs_transpose:
if self._needs_transpose(image):
height, width = image.size
else:
width, height = image.size