warn on chronically underfilled aspect ratio buckets

Damian Stewart 2023-04-19 11:06:02 +02:00 committed by Victor Hall
parent 974d3fa53a
commit ce85ce30ae
3 changed files with 74 additions and 4 deletions

View File

@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List, Tuple
"""
Notes:
@ -237,4 +238,40 @@ def __get_all_aspects():
        ASPECTS_1152,
        ASPECTS_1280,
        ASPECTS_1536,
    ]
def get_rational_aspect_ratio(bucket_wh: Tuple[int, int]) -> Tuple[int, int]:

    def farey_aspect_ratio_pair(x: float, max_denominator_value: int):
        if x <= 1:
            return farey_aspect_ratio_pair_lt1(x, max_denominator_value)
        else:
            # for ratios > 1, approximate the reciprocal and swap the pair back
            b, a = farey_aspect_ratio_pair_lt1(1/x, max_denominator_value)
            return a, b

    # best rational approximation a/b to x with b <= max_denominator_value,
    # found by walking the Farey sequence mediants
    # adapted from https://www.johndcook.com/blog/2010/10/20/best-rational-approximation/
    def farey_aspect_ratio_pair_lt1(x: float, max_denominator_value: int):
        if x > 1:
            raise ValueError("x must be <=1")
        a, b = 0, 1
        c, d = 1, 1
        while (b <= max_denominator_value and d <= max_denominator_value):
            mediant = float(a+c)/(b+d)
            if x == mediant:
                if b + d <= max_denominator_value:
                    return a+c, b+d
                elif d > b:
                    return c, d
                else:
                    return a, b
            elif x > mediant:
                a, b = a+c, b+d
            else:
                c, d = a+c, b+d

        if (b > max_denominator_value):
            return c, d
        else:
            return a, b

    return farey_aspect_ratio_pair(bucket_wh[0]/bucket_wh[1], 32)
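Note (illustrative sketch, not part of the commit): roughly what the new helper returns for a few common bucket sizes, assuming aspects.py lives at data/aspects.py so the function can be imported as below.

from data.aspects import get_rational_aspect_ratio

print(get_rational_aspect_ratio((768, 512)))   # (3, 2)
print(get_rational_aspect_ratio((512, 768)))   # (2, 3)
print(get_rational_aspect_ratio((704, 576)))   # (11, 9)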

View File

@ -21,6 +21,9 @@ import math
import copy
import random
from colorama import Fore, Style
from data.image_train_item import ImageTrainItem
import PIL.Image
@ -117,7 +120,9 @@ class DataLoaderMultiAspect():
                runt_bucket = buckets[bucket][-truncate_count:]
                for item in runt_bucket:
                    item.runt_size = truncate_count

                appended_dupes = 0
                while len(runt_bucket) < batch_size:
                    appended_dupes += 1
                    runt_bucket.append(random.choice(runt_bucket))

                current_bucket_size = len(buckets[bucket])
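For context, a minimal sketch (not from this commit) of how the padding above duplicates images in a runt bucket, using a hypothetical bucket of three filenames and a batch_size of 4:

import random

batch_size = 4
runt_bucket = ["a.jpg", "b.jpg", "c.jpg"]  # hypothetical runt bucket with only 3 items
appended_dupes = 0
while len(runt_bucket) < batch_size:
    appended_dupes += 1
    runt_bucket.append(random.choice(runt_bucket))
# len(runt_bucket) == 4 and appended_dupes == 1: one of the three images now appears twice,
# so every image in this bucket is effectively oversampled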
@ -174,4 +179,5 @@ class DataLoaderMultiAspect():
        self.ratings_summed: list[float] = []
        for item in self.prepared_train_data:
            self.rating_overall_sum += item.caption.rating()
            self.ratings_summed.append(self.rating_overall_sum)

View File

@ -28,6 +28,7 @@ import random
import traceback
import shutil
import importlib
from collections import defaultdict
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
@ -305,7 +306,7 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
        scaler.set_backoff_factor(1/factor)
        scaler.set_growth_interval(100)
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem], batch_size) -> None:
    for item in items:
        if item.error is not None:
            logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
@ -323,12 +324,36 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
                message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
                undersized_images_file.write(message)
    # Warn on underfilled aspect ratio buckets.
    # Intuition: if a bucket holds too few images to fill a batch, duplicates are appended.
    # This is not a problem for large image counts, but it can seriously distort training
    # when there are only a handful of images for a given aspect ratio.
    # At a dupe ratio of 0.5, every image in the bucket has an effective multiplier of 1.5;
    # at a dupe ratio of 1.0, every image in the bucket has an effective multiplier of 2.0.
    warn_bucket_dupe_ratio = 0.5

    ar_buckets = set(tuple(i.target_wh) for i in items)
    for ar_bucket in ar_buckets:
        count = len([i for i in items if tuple(i.target_wh) == ar_bucket])
        # duplicates needed to pad the final (runt) batch; 0 if the bucket divides evenly
        runt_size = (batch_size - (count % batch_size)) % batch_size
        bucket_dupe_ratio = runt_size / count
        if bucket_dupe_ratio > warn_bucket_dupe_ratio:
            aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket)
            aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
            effective_multiplier = round(1 + bucket_dupe_ratio, 1)
            logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
                            f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
                            f"of {effective_multiplier}, which may cause problems. Consider adding up to {runt_size} "
                            f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.")
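A worked example of the arithmetic behind the warning (hypothetical numbers, using the runt_size formula above):

batch_size = 8

count = 6                                                      # bucket with 6 images
runt_size = (batch_size - (count % batch_size)) % batch_size   # 2 duplicates appended
bucket_dupe_ratio = runt_size / count                          # ~0.33, below the 0.5 threshold: no warning
effective_multiplier = round(1 + bucket_dupe_ratio, 1)         # 1.3

count = 3                                                      # bucket with only 3 images
runt_size = (batch_size - (count % batch_size)) % batch_size   # 5 duplicates appended
bucket_dupe_ratio = runt_size / count                          # ~1.67, above 0.5: the warning fires
effective_multiplier = round(1 + bucket_dupe_ratio, 1)         # 2.7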
def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
    logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
    logging.info(" Preloading images...")

    resolved_items = resolver.resolve(args.data_root, args)
    report_image_train_item_problems(log_folder, resolved_items)
    image_paths = set(map(lambda item: item.pathname, resolved_items))

    # Remove erroneous items
@ -616,6 +641,8 @@ def main(args):
    # the validation dataset may need to steal some items from image_train_items
    image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)

    report_image_train_item_problems(log_folder, image_train_items, batch_size=args.batch_size)

    data_loader = DataLoaderMultiAspect(
        image_train_items=image_train_items,
        seed=seed,