Add a batch_id.txt file to a subfolder, or a batch_id key to its local yaml, to force the images in that folder to be processed in the same batch

This commit is contained in:
Damian Stewart 2023-06-05 01:01:59 +02:00
parent 7e09b6dc29
commit 53d0686086
5 changed files with 51 additions and 18 deletions

View File

@@ -248,7 +248,7 @@ def __get_all_aspects():
]
def get_rational_aspect_ratio(bucket_wh: Tuple[int]) -> Tuple[int]:
def get_rational_aspect_ratio(bucket_wh: Tuple[int, int]) -> Tuple[int]:
def farey_aspect_ratio_pair(x: float, max_denominator_value: int):
if x <= 1:
return farey_aspect_ratio_pair_lt1(x, max_denominator_value)

View File

@@ -21,7 +21,10 @@ import math
import copy
import random
from data.image_train_item import ImageTrainItem
from itertools import groupby
from typing import Tuple, List
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@@ -103,15 +106,29 @@ class DataLoaderMultiAspect():
buckets = {}
batch_size = self.batch_size
def append_to_bucket(item, batch_id):
    # Group items by (batch id, target width, target height): images that
    # share a batch id and an aspect bucket must land in the same list.
    width, height = item.target_wh[0], item.target_wh[1]
    buckets.setdefault((batch_id, width, height), []).append(item)
for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0
target_wh = image_caption_pair.target_wh
batch_id = image_caption_pair.batch_id
append_to_bucket(image_caption_pair, batch_id)
if (target_wh[0],target_wh[1]) not in buckets:
buckets[(target_wh[0],target_wh[1])] = []
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
# shunt any runts from "named" buckets into the appropriate "general" buckets
for bucket in [b for b in buckets if b[0] != DEFAULT_BATCH_ID]:
truncate_count = len(buckets[bucket]) % batch_size
for runt in buckets[bucket][-truncate_count:]:
append_to_bucket(runt, DEFAULT_BATCH_ID)
del buckets[bucket][-truncate_count:]
if len(buckets[bucket]) == 0:
del buckets[bucket]
for bucket in buckets:
# handle runts in "general" buckets by randomly duplicating items
for bucket in [b for b in buckets if b[0] == DEFAULT_BATCH_ID]:
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
runt_bucket = buckets[bucket][-truncate_count:]
@@ -132,6 +149,7 @@ class DataLoaderMultiAspect():
return items
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
"""
Picks a random subset of all images
@@ -174,4 +192,6 @@ class DataLoaderMultiAspect():
self.ratings_summed: list[float] = []
for item in self.prepared_train_data:
self.rating_overall_sum += item.caption.rating()
self.ratings_summed.append(self.rating_overall_sum)
self.ratings_summed.append(self.rating_overall_sum)

View File

@@ -1,10 +1,7 @@
import os
import logging
import yaml
import json
from functools import total_ordering
from attrs import define, field, Factory
from attrs import define, field
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *
from typing import Iterable
@@ -50,6 +47,7 @@ class ImageConfig:
rating: float = None
max_caption_length: int = None
tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
batch_id: str = None
# Options
multiply: float = None
@@ -70,6 +68,7 @@ class ImageConfig:
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
flip_p=overlay(other.flip_p, self.flip_p),
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
batch_id=overlay(other.batch_id, self.batch_id)
)
@classmethod
@@ -84,6 +83,7 @@ class ImageConfig:
cond_dropout=data.get("cond_dropout"),
flip_p=data.get("flip_p"),
shuffle_tags=data.get("shuffle_tags"),
batch_id=data.get("batch_id")
)
# Alternatively parse from dedicated `caption` attribute
@@ -168,6 +168,8 @@ class Dataset:
cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
if 'local.yml' in fileset:
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
if 'batch_id.txt' in fileset:
cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
result = ImageConfig.fold(cfgs)
if 'shuffle_tags.txt' in fileset:
@@ -262,6 +264,7 @@ class Dataset:
multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout,
shuffle_tags=config.shuffle_tags,
batch_id=config.batch_id
)
items.append(item)
except Exception as e:

View File

@@ -124,7 +124,7 @@ class ImageTrainItem:
flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images.
"""
def __init__(self,
def __init__(self,
image: PIL.Image,
caption: ImageCaption,
aspects: list[float],
@@ -133,6 +133,7 @@ class ImageTrainItem:
multiplier: float=1.0,
cond_dropout=None,
shuffle_tags=False,
batch_id: str=None
):
self.caption = caption
self.aspects = aspects
@ -143,6 +144,8 @@ class ImageTrainItem:
self.multiplier = multiplier
self.cond_dropout = cond_dropout
self.shuffle_tags = shuffle_tags
self.batch_id = batch_id or DEFAULT_BATCH_ID
self.target_wh = None
self.image_size = None
if image is None or len(image) == 0:
@@ -351,3 +354,6 @@ class ImageTrainItem:
image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
return image
DEFAULT_BATCH_ID = "default_batch"

View File

@@ -55,7 +55,7 @@ from data.data_loader import DataLoaderMultiAspect
from data.every_dream import EveryDreamBatch, build_torch_dataloader
from data.every_dream_validation import EveryDreamValidator
from data.image_train_item import ImageTrainItem
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter
from utils.isolate_rng import isolate_rng
@@ -297,19 +297,23 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
# at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
warn_bucket_dupe_ratio = 0.5
ar_buckets = set([tuple(i.target_wh) for i in items])
def make_bucket_key(item):
    """Return the aspect-bucket key for an item: (batch id, int width, int height)."""
    width, height = item.target_wh[0], item.target_wh[1]
    return (item.batch_id, int(width), int(height))
ar_buckets = set(make_bucket_key(i) for i in items)
for ar_bucket in ar_buckets:
count = len([i for i in items if tuple(i.target_wh) == ar_bucket])
count = len([i for i in items if make_bucket_key(i) == ar_bucket])
runt_size = batch_size - (count % batch_size)
bucket_dupe_ratio = runt_size / count
if bucket_dupe_ratio > warn_bucket_dupe_ratio:
aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket)
aspect_ratio_rational = aspects.get_rational_aspect_ratio((ar_bucket[1], ar_bucket[2]))
aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
batch_id_description = "" if ar_bucket[0] == DEFAULT_BATCH_ID else f" for batch id '{ar_bucket[0]}'"
effective_multiplier = round(1 + bucket_dupe_ratio, 1)
logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
f"of {effective_multiplier}, which may cause problems. Consider adding {runt_size} or "
f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.")
f"more images with aspect ratio {aspect_ratio_description}{batch_id_description}, or reducing your batch_size.")
def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]:
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")