warn on chronically underfilled aspect ratio buckets

This commit is contained in:
Damian Stewart 2023-04-19 11:06:02 +02:00 committed by Victor Hall
parent 974d3fa53a
commit ce85ce30ae
3 changed files with 74 additions and 4 deletions

View File

@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
from typing import List, Tuple
@ -237,4 +238,40 @@ def __get_all_aspects():
def get_rational_aspect_ratio(bucket_wh: Tuple[int, int], max_denominator_value: int = 32) -> Tuple[int, int]:
    """
    Describe a bucket's aspect ratio as a ratio of two small integers.

    The ratio width/height is approximated with a Farey-sequence (mediant)
    search, e.g. a 512x768 bucket yields (2, 3) and a 768x512 bucket (3, 2).

    :param bucket_wh: (width, height) of the bucket.
    :param max_denominator_value: search limit for the Farey denominators.
        For a limit >= 1 neither term of the returned pair exceeds it.
        Defaults to 32, the value that was previously hard-coded.
    :return: (width_term, height_term) of the approximating ratio.
    :raises ZeroDivisionError: if the bucket height is 0.
    """

    def farey_aspect_ratio_pair(x: float, max_denominator_value: int) -> Tuple[int, int]:
        # The Farey search below only covers [0, 1]; ratios above 1 are handled
        # by approximating the reciprocal and swapping the two terms back.
        if x <= 1:
            return farey_aspect_ratio_pair_lt1(x, max_denominator_value)
        else:
            b, a = farey_aspect_ratio_pair_lt1(1/x, max_denominator_value)
            return a, b

    # adapted from https://www.johndcook.com/blog/2010/10/20/best-rational-approximation/
    def farey_aspect_ratio_pair_lt1(x: float, max_denominator_value: int) -> Tuple[int, int]:
        if x > 1:
            # x == 1 is accepted, so the message says "<=" to match the guard.
            raise ValueError("x must be <=1")
        # a/b is the lower bound and c/d the upper bound bracketing x.
        a, b = 0, 1
        c, d = 1, 1
        while b <= max_denominator_value and d <= max_denominator_value:
            mediant = float(a+c)/(b+d)
            if x == mediant:
                # Exact hit: use the mediant if it fits the limit, otherwise
                # fall back to the bound with the larger denominator.
                if b + d <= max_denominator_value:
                    return a+c, b+d
                elif d > b:
                    return c, d
                else:
                    return a, b
            elif x > mediant:
                # x lies above the mediant: raise the lower bound.
                a, b = a+c, b+d
            else:
                # x lies below the mediant: lower the upper bound.
                c, d = a+c, b+d

        # The bound updated last is the one that broke the limit; return the other.
        if b > max_denominator_value:
            return c, d
        else:
            return a, b

    return farey_aspect_ratio_pair(bucket_wh[0]/bucket_wh[1], max_denominator_value)

View File

@ -21,6 +21,9 @@ import math
import copy
import random
from colorama import Fore, Style
from data.image_train_item import ImageTrainItem
import PIL.Image
@ -117,7 +120,9 @@ class DataLoaderMultiAspect():
runt_bucket = buckets[bucket][-truncate_count:]
for item in runt_bucket:
item.runt_size = truncate_count
appended_dupes = 0
while len(runt_bucket) < batch_size:
appended_dupes += 1
current_bucket_size = len(buckets[bucket])
@ -174,4 +179,5 @@ class DataLoaderMultiAspect():
self.ratings_summed: list[float] = []
for item in self.prepared_train_data:
self.rating_overall_sum += item.caption.rating()

View File

@ -28,6 +28,7 @@ import random
import traceback
import shutil
import importlib
from collections import defaultdict
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
@ -305,7 +306,7 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem], batch_size) -> None:
for item in items:
if item.error is not None:
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
@ -323,12 +324,36 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
# warn on underfilled aspect ratio buckets

# Intuition: if there are too few images to fill a batch, duplicates will be appended.
# this is not a problem for large image counts but can seriously distort training if there
# are just a handful of images for a given aspect ratio.

# at a dupe ratio of 0.5, all images in this bucket have effective multiplier 1.5,
# at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
warn_bucket_dupe_ratio = 0.5

# Tally the images per bucket in a single pass instead of rescanning items for every bucket.
bucket_counts = {}
for item in items:
    bucket_key = tuple(item.target_wh)
    bucket_counts[bucket_key] = bucket_counts.get(bucket_key, 0) + 1

for ar_bucket, count in bucket_counts.items():
    # Number of duplicates needed to pad the last, partially filled batch.
    # The outer modulo makes this 0 (not batch_size) when count is an exact
    # multiple of batch_size, so fully filled buckets are never reported.
    runt_size = (batch_size - count % batch_size) % batch_size
    bucket_dupe_ratio = runt_size / count
    if bucket_dupe_ratio > warn_bucket_dupe_ratio:
        aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket)
        aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
        effective_multiplier = round(1 + bucket_dupe_ratio, 1)
        logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
                        f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
                        f"of {effective_multiplier}, which may cause problems. Consider adding up to {runt_size} "
                        f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.")
def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
logging.info(" Preloading images...")
resolved_items = resolver.resolve(args.data_root, args)
report_image_train_item_problems(log_folder, resolved_items)
image_paths = set(map(lambda item: item.pathname, resolved_items))
# Remove erroneous items
@ -616,6 +641,8 @@ def main(args):
# the validation dataset may need to steal some items from image_train_items
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
report_image_train_item_problems(log_folder, image_train_items, batch_size=args.batch_size)
data_loader = DataLoaderMultiAspect(