early work on shuffle_tags.txt and add try around trimming
This commit is contained in:
parent
7dcfa7acbf
commit
56deb26a59
|
@ -55,6 +55,7 @@ class ImageConfig:
|
|||
multiply: float = None
|
||||
cond_dropout: float = None
|
||||
flip_p: float = None
|
||||
shuffle_tags: bool = False
|
||||
|
||||
def merge(self, other):
|
||||
if other is None:
|
||||
|
@ -68,6 +69,7 @@ class ImageConfig:
|
|||
multiply=overlay(other.multiply, self.multiply),
|
||||
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
||||
flip_p=overlay(other.flip_p, self.flip_p),
|
||||
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -81,6 +83,7 @@ class ImageConfig:
|
|||
multiply=data.get("multiply"),
|
||||
cond_dropout=data.get("cond_dropout"),
|
||||
flip_p=data.get("flip_p"),
|
||||
shuffle_tags=data.get("shuffle_tags"),
|
||||
)
|
||||
|
||||
# Alternatively parse from dedicated `caption` attribute
|
||||
|
@ -94,6 +97,9 @@ class ImageConfig:
|
|||
acc = ImageConfig()
|
||||
for cfg in configs:
|
||||
acc = acc.merge(cfg)
|
||||
|
||||
acc.shuffle_tags = any(cfg.shuffle_tags for cfg in configs)
|
||||
print(f"accum shuffle:{acc.shuffle_tags}")
|
||||
return acc
|
||||
|
||||
def ensure_caption(self):
|
||||
|
@ -151,6 +157,7 @@ class Dataset:
|
|||
|
||||
def __local_cfg(fileset):
|
||||
cfgs = []
|
||||
|
||||
if 'multiply.txt' in fileset:
|
||||
cfgs.append(ImageConfig(multiply=read_float(fileset['multiply.txt'])))
|
||||
if 'cond_dropout.txt' in fileset:
|
||||
|
@ -161,7 +168,12 @@ class Dataset:
|
|||
cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
|
||||
if 'local.yml' in fileset:
|
||||
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
|
||||
return ImageConfig.fold(cfgs)
|
||||
|
||||
result = ImageConfig.fold(cfgs)
|
||||
if 'shuffle_tags.txt' in fileset:
|
||||
result.shuffle_tags = True
|
||||
|
||||
return result
|
||||
|
||||
def __sidecar_cfg(imagepath, fileset):
|
||||
cfgs = []
|
||||
|
@ -195,6 +207,8 @@ class Dataset:
|
|||
return global_cfg
|
||||
|
||||
walk_and_visit(data_root, process_dir, ImageConfig())
|
||||
for img in image_configs:
|
||||
print(f" *** {img}: {image_configs[img]}")
|
||||
return Dataset(image_configs)
|
||||
|
||||
@classmethod
|
||||
|
@ -217,6 +231,7 @@ class Dataset:
|
|||
items = []
|
||||
for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
|
||||
config = self.image_configs[image]
|
||||
print(f" ********* shuffle: {config.shuffle_tags}")
|
||||
|
||||
if len(config.main_prompts) > 1:
|
||||
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
|
||||
|
@ -247,7 +262,8 @@ class Dataset:
|
|||
pathname=os.path.abspath(image),
|
||||
flip_p=config.flip_p or 0.0,
|
||||
multiplier=config.multiply or 1.0,
|
||||
cond_dropout=config.cond_dropout
|
||||
cond_dropout=config.cond_dropout,
|
||||
shuffle_tags=config.shuffle_tags,
|
||||
)
|
||||
items.append(item)
|
||||
except Exception as e:
|
||||
|
|
|
@ -40,7 +40,6 @@ class EveryDreamBatch(Dataset):
|
|||
crop_jitter=0.02,
|
||||
seed=555,
|
||||
tokenizer=None,
|
||||
retain_contrast=False,
|
||||
shuffle_tags=False,
|
||||
rated_dataset=False,
|
||||
rated_dataset_dropout_target=0.5,
|
||||
|
@ -54,7 +53,6 @@ class EveryDreamBatch(Dataset):
|
|||
self.unloaded_to_idx = 0
|
||||
self.tokenizer = tokenizer
|
||||
self.max_token_length = self.tokenizer.model_max_length
|
||||
self.retain_contrast = retain_contrast
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.seed = seed
|
||||
self.rated_dataset = rated_dataset
|
||||
|
@ -85,12 +83,8 @@ class EveryDreamBatch(Dataset):
|
|||
|
||||
train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level)
|
||||
|
||||
if self.retain_contrast:
|
||||
std_dev = 1.0
|
||||
mean = 0.0
|
||||
else:
|
||||
std_dev = 0.5
|
||||
mean = 0.5
|
||||
std_dev = 0.5
|
||||
mean = 0.5
|
||||
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
|
@ -99,7 +93,7 @@ class EveryDreamBatch(Dataset):
|
|||
]
|
||||
)
|
||||
|
||||
if self.shuffle_tags:
|
||||
if self.shuffle_tags or train_item["shuffle_tags"]:
|
||||
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
|
||||
else:
|
||||
example["caption"] = train_item["caption"].get_caption()
|
||||
|
@ -137,6 +131,7 @@ class EveryDreamBatch(Dataset):
|
|||
if image_train_tmp.cond_dropout is not None:
|
||||
example["cond_dropout"] = image_train_tmp.cond_dropout
|
||||
example["runt_size"] = image_train_tmp.runt_size
|
||||
example["shuffle_tags"] = image_train_tmp.shuffle_tags
|
||||
|
||||
return example
|
||||
|
||||
|
|
|
@ -124,7 +124,16 @@ class ImageTrainItem:
|
|||
flip_p: probability of flipping image (0.0 to 1.0)
|
||||
rating: the relative rating of the images. The rating is measured in comparison to the other images.
|
||||
"""
|
||||
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0, cond_dropout=None):
|
||||
def __init__(self,
|
||||
image: PIL.Image,
|
||||
caption: ImageCaption,
|
||||
aspects: list[float],
|
||||
pathname: str,
|
||||
flip_p=0.0,
|
||||
multiplier: float=1.0,
|
||||
cond_dropout=None,
|
||||
shuffle_tags=False,
|
||||
):
|
||||
self.caption = caption
|
||||
self.aspects = aspects
|
||||
self.pathname = pathname
|
||||
|
@ -133,6 +142,7 @@ class ImageTrainItem:
|
|||
self.runt_size = 0
|
||||
self.multiplier = multiplier
|
||||
self.cond_dropout = cond_dropout
|
||||
self.shuffle_tags = shuffle_tags
|
||||
|
||||
self.image_size = None
|
||||
if image is None or len(image) == 0:
|
||||
|
@ -197,14 +207,12 @@ class ImageTrainItem:
|
|||
top_crop_pixels = random.uniform(0, max_crop_pixels)
|
||||
bottom_crop_pixels = random.uniform(0, max_crop_pixels)
|
||||
|
||||
# Calculate the cropping coordinates
|
||||
left = left_crop_pixels
|
||||
right = width - right_crop_pixels
|
||||
top = top_crop_pixels
|
||||
bottom = height - bottom_crop_pixels
|
||||
#print(f"\n *** jitter l: {left}, t: {top}, r: {right}, b: {bottom}, orig w: {width}, h: {height}, max_crop_pixels: {max_crop_pixels}")
|
||||
|
||||
# Crop the image
|
||||
cropped_image = image.crop((left, top, right, bottom))
|
||||
|
||||
cropped_width = width - int(left_crop_pixels + right_crop_pixels)
|
||||
|
@ -212,7 +220,6 @@ class ImageTrainItem:
|
|||
|
||||
cropped_aspect_ratio = cropped_width / cropped_height
|
||||
|
||||
# Resize the cropped image to maintain square pixels
|
||||
if cropped_aspect_ratio > 1:
|
||||
new_width = cropped_width
|
||||
new_height = int(cropped_width / cropped_aspect_ratio)
|
||||
|
@ -220,7 +227,6 @@ class ImageTrainItem:
|
|||
new_width = int(cropped_height * cropped_aspect_ratio)
|
||||
new_height = cropped_height
|
||||
|
||||
#print(f" *** postsquarefix new w: {new_width}, h: {new_height}")
|
||||
cropped_image = cropped_image.resize((new_width, new_height))
|
||||
|
||||
return cropped_image
|
||||
|
@ -241,34 +247,33 @@ class ImageTrainItem:
|
|||
pass
|
||||
|
||||
def _trim_to_aspect(self, image, target_wh):
|
||||
width, height = image.size
|
||||
target_aspect = target_wh[0] / target_wh[1] # 0.60
|
||||
image_aspect = width / height # 0.5865
|
||||
#self._debug_save_image(image, "precrop")
|
||||
if image_aspect > target_aspect:
|
||||
target_width = int(height * target_aspect)
|
||||
overwidth = width - target_width
|
||||
l = random.normalvariate(overwidth/2, overwidth/2)
|
||||
l = max(0, l)
|
||||
l = min(l, overwidth)
|
||||
r = width - int(overwidth) - l
|
||||
image = image.crop((l, 0, r, height))
|
||||
elif target_aspect > image_aspect:
|
||||
target_height = int(width / target_aspect)
|
||||
overheight = height - target_height
|
||||
image = image.crop((0, int(overheight/2), width, height-int(overheight/2)))
|
||||
try:
|
||||
width, height = image.size
|
||||
target_aspect = target_wh[0] / target_wh[1] # 0.60
|
||||
image_aspect = width / height # 0.5865
|
||||
#self._debug_save_image(image, "precrop")
|
||||
if image_aspect > target_aspect:
|
||||
target_width = int(height * target_aspect)
|
||||
overwidth = width - target_width
|
||||
l = random.normalvariate(overwidth/2, overwidth/2)
|
||||
l = max(0, l)
|
||||
l = min(l, overwidth)
|
||||
r = width - int(overwidth) - l
|
||||
image = image.crop((l, 0, r, height))
|
||||
elif target_aspect > image_aspect:
|
||||
target_height = int(width / target_aspect)
|
||||
overheight = height - target_height
|
||||
image = image.crop((0, int(overheight/2), width, height-int(overheight/2)))
|
||||
except Exception as e:
|
||||
print(f"error trimming image {self.pathname}: {e}")
|
||||
pass
|
||||
|
||||
def hydrate(self, save=False, crop_jitter=0.02):
|
||||
"""
|
||||
save: save the cropped image to disk, for manual inspection of resize/crop
|
||||
"""
|
||||
# print(self.pathname, self.image)
|
||||
# try:
|
||||
# if not hasattr(self, 'image'):
|
||||
image = self.load_image()
|
||||
|
||||
#print(f"** jittering: {self.pathname}")
|
||||
|
||||
width, height = image.size
|
||||
|
||||
img_jitter = min((width-self.target_wh[0])/self.target_wh[0], (height-self.target_wh[1])/self.target_wh[1])
|
||||
|
@ -283,6 +288,7 @@ class ImageTrainItem:
|
|||
self.image = image.resize(self.target_wh)
|
||||
|
||||
self.image = self.flip(self.image)
|
||||
# Remove comment here to view image cropping outputs
|
||||
# self._debug_save_image(self.image, "final")
|
||||
|
||||
self.image = np.array(self.image).astype(np.uint8)
|
||||
|
@ -293,8 +299,7 @@ class ImageTrainItem:
|
|||
self.target_wh = None
|
||||
try:
|
||||
with PIL.Image.open(self.pathname) as image:
|
||||
needs_transpose = self._needs_transpose(image)
|
||||
if needs_transpose:
|
||||
if self._needs_transpose(image):
|
||||
height, width = image.size
|
||||
else:
|
||||
width, height = image.size
|
||||
|
|
Loading…
Reference in New Issue