warn on chronically underfilled aspect ratio buckets
This commit is contained in:
parent
974d3fa53a
commit
ce85ce30ae
|
@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
"""
|
||||
Notes:
|
||||
|
@ -237,4 +238,40 @@ def __get_all_aspects():
|
|||
ASPECTS_1152,
|
||||
ASPECTS_1280,
|
||||
ASPECTS_1536,
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def get_rational_aspect_ratio(bucket_wh: Tuple[int, int]) -> Tuple[int, int]:
    """Return a small-integer (numerator, denominator) aspect ratio for a bucket.

    Approximates width/height by a rational number whose terms are bounded
    by 32, using Farey-sequence mediants, e.g. (832, 1216) -> (13, 19).

    Note: the original annotations used ``Tuple[int]``, which per PEP 484
    means a one-element tuple; a (width, height) pair is ``Tuple[int, int]``.

    :param bucket_wh: (width, height) of the aspect-ratio bucket, in pixels.
    :return: (num, den) such that num/den approximates width/height.
    """
    def farey_aspect_ratio_pair(x: float, max_denominator_value: int):
        # For x > 1, approximate 1/x (which is < 1) and swap the result so
        # the denominator bound applies to the larger term of the ratio.
        if x <= 1:
            return farey_aspect_ratio_pair_lt1(x, max_denominator_value)
        else:
            b, a = farey_aspect_ratio_pair_lt1(1/x, max_denominator_value)
            return a, b

    # adapted from https://www.johndcook.com/blog/2010/10/20/best-rational-approximation/
    def farey_aspect_ratio_pair_lt1(x: float, max_denominator_value: int):
        """Rational approximation a/b of x (<= 1) with b <= max_denominator_value."""
        if x > 1:
            raise ValueError("x must be <1")
        # Farey neighbours bracketing x: a/b from below, c/d from above.
        a, b = 0, 1
        c, d = 1, 1
        while b <= max_denominator_value and d <= max_denominator_value:
            mediant = (a + c) / (b + d)
            if x == mediant:
                # Exact hit: take the mediant if its denominator fits the
                # bound, otherwise the tighter bracketing fraction.
                if b + d <= max_denominator_value:
                    return a + c, b + d
                elif d > b:
                    return c, d
                else:
                    return a, b
            elif x > mediant:
                a, b = a + c, b + d
            else:
                c, d = a + c, b + d

        # Loop exits when one side's denominator exceeds the bound;
        # return whichever side still satisfies it.
        if b > max_denominator_value:
            return c, d
        else:
            return a, b

    return farey_aspect_ratio_pair(bucket_wh[0] / bucket_wh[1], 32)
|
||||
|
|
|
@ -21,6 +21,9 @@ import math
|
|||
import copy
|
||||
|
||||
import random
|
||||
|
||||
from colorama import Fore, Style
|
||||
|
||||
from data.image_train_item import ImageTrainItem
|
||||
import PIL.Image
|
||||
|
||||
|
@ -117,7 +120,9 @@ class DataLoaderMultiAspect():
|
|||
runt_bucket = buckets[bucket][-truncate_count:]
|
||||
for item in runt_bucket:
|
||||
item.runt_size = truncate_count
|
||||
appended_dupes = 0
|
||||
while len(runt_bucket) < batch_size:
|
||||
appended_dupes += 1
|
||||
runt_bucket.append(random.choice(runt_bucket))
|
||||
|
||||
current_bucket_size = len(buckets[bucket])
|
||||
|
@ -174,4 +179,5 @@ class DataLoaderMultiAspect():
|
|||
self.ratings_summed: list[float] = []
|
||||
for item in self.prepared_train_data:
|
||||
self.rating_overall_sum += item.caption.rating()
|
||||
self.ratings_summed.append(self.rating_overall_sum)
|
||||
self.ratings_summed.append(self.rating_overall_sum)
|
||||
|
||||
|
|
31
train.py
31
train.py
|
@ -28,6 +28,7 @@ import random
|
|||
import traceback
|
||||
import shutil
|
||||
import importlib
|
||||
from collections import defaultdict
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
@ -305,7 +306,7 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
|||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
|
||||
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
|
||||
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem], batch_size) -> None:
|
||||
for item in items:
|
||||
if item.error is not None:
|
||||
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
|
||||
|
@ -323,12 +324,36 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
|
|||
message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
|
||||
undersized_images_file.write(message)
|
||||
|
||||
|
||||
# warn on underfilled aspect ratio buckets
|
||||
|
||||
# Intuition: if there are too few images to fill a batch, duplicates will be appended.
|
||||
# this is not a problem for large image counts but can seriously distort training if there
|
||||
# are just a handful of images for a given aspect ratio.
|
||||
|
||||
# at a dupe ratio of 0.5, all images in this bucket have effective multiplier 1.5,
|
||||
# at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
|
||||
warn_bucket_dupe_ratio = 0.5
|
||||
|
||||
ar_buckets = set([tuple(i.target_wh) for i in items])
|
||||
for ar_bucket in ar_buckets:
|
||||
count = len([i for i in items if tuple(i.target_wh) == ar_bucket])
|
||||
runt_size = batch_size - (count % batch_size)
|
||||
bucket_dupe_ratio = runt_size / count
|
||||
if bucket_dupe_ratio > warn_bucket_dupe_ratio:
|
||||
aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket)
|
||||
aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
|
||||
effective_multiplier = round(1 + bucket_dupe_ratio, 1)
|
||||
logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
|
||||
f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
|
||||
f"of {effective_multiplier}, which may cause problems. Consider adding up to {runt_size} "
|
||||
f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.")
|
||||
|
||||
def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
|
||||
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
|
||||
logging.info(" Preloading images...")
|
||||
|
||||
resolved_items = resolver.resolve(args.data_root, args)
|
||||
report_image_train_item_problems(log_folder, resolved_items)
|
||||
image_paths = set(map(lambda item: item.pathname, resolved_items))
|
||||
|
||||
# Remove erroneous items
|
||||
|
@ -616,6 +641,8 @@ def main(args):
|
|||
# the validation dataset may need to steal some items from image_train_items
|
||||
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||
|
||||
report_image_train_item_problems(log_folder, image_train_items, batch_size=args.batch_size)
|
||||
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items=image_train_items,
|
||||
seed=seed,
|
||||
|
|
Loading…
Reference in New Issue