Add a `batch_id.txt` file to a subfolder, or a `batch_id` key to its local yaml, to force the images in that folder to be processed in the same batch

This commit is contained in:
Damian Stewart 2023-06-05 01:01:59 +02:00
parent 7e09b6dc29
commit 53d0686086
5 changed files with 51 additions and 18 deletions

View File

@ -248,7 +248,7 @@ def __get_all_aspects():
] ]
def get_rational_aspect_ratio(bucket_wh: Tuple[int]) -> Tuple[int]: def get_rational_aspect_ratio(bucket_wh: Tuple[int, int]) -> Tuple[int]:
def farey_aspect_ratio_pair(x: float, max_denominator_value: int): def farey_aspect_ratio_pair(x: float, max_denominator_value: int):
if x <= 1: if x <= 1:
return farey_aspect_ratio_pair_lt1(x, max_denominator_value) return farey_aspect_ratio_pair_lt1(x, max_denominator_value)

View File

@ -21,7 +21,10 @@ import math
import copy import copy
import random import random
from data.image_train_item import ImageTrainItem from itertools import groupby
from typing import Tuple, List
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
import PIL.Image import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@ -103,15 +106,29 @@ class DataLoaderMultiAspect():
buckets = {} buckets = {}
batch_size = self.batch_size batch_size = self.batch_size
def append_to_bucket(item, batch_id):
bucket_key = (batch_id, item.target_wh[0], item.target_wh[1])
if bucket_key not in buckets:
buckets[bucket_key] = []
buckets[bucket_key].append(item)
for image_caption_pair in picked_images: for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0 image_caption_pair.runt_size = 0
target_wh = image_caption_pair.target_wh batch_id = image_caption_pair.batch_id
append_to_bucket(image_caption_pair, batch_id)
if (target_wh[0],target_wh[1]) not in buckets: # shunt any runts from "named" buckets into the appropriate "general" buckets
buckets[(target_wh[0],target_wh[1])] = [] for bucket in [b for b in buckets if b[0] != DEFAULT_BATCH_ID]:
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair) truncate_count = len(buckets[bucket]) % batch_size
for runt in buckets[bucket][-truncate_count:]:
append_to_bucket(runt, DEFAULT_BATCH_ID)
del buckets[bucket][-truncate_count:]
if len(buckets[bucket]) == 0:
del buckets[bucket]
for bucket in buckets: # handle runts in "general" buckets by randomly duplicating items
for bucket in [b for b in buckets if b[0] == DEFAULT_BATCH_ID]:
truncate_count = len(buckets[bucket]) % batch_size truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0: if truncate_count > 0:
runt_bucket = buckets[bucket][-truncate_count:] runt_bucket = buckets[bucket][-truncate_count:]
@ -132,6 +149,7 @@ class DataLoaderMultiAspect():
return items return items
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]: def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
""" """
Picks a random subset of all images Picks a random subset of all images
@ -175,3 +193,5 @@ class DataLoaderMultiAspect():
for item in self.prepared_train_data: for item in self.prepared_train_data:
self.rating_overall_sum += item.caption.rating() self.rating_overall_sum += item.caption.rating()
self.ratings_summed.append(self.rating_overall_sum) self.ratings_summed.append(self.rating_overall_sum)

View File

@ -1,10 +1,7 @@
import os
import logging
import yaml import yaml
import json import json
from functools import total_ordering from attrs import define, field
from attrs import define, field, Factory
from data.image_train_item import ImageCaption, ImageTrainItem from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import * from utils.fs_helpers import *
from typing import Iterable from typing import Iterable
@ -50,6 +47,7 @@ class ImageConfig:
rating: float = None rating: float = None
max_caption_length: int = None max_caption_length: int = None
tags: dict[Tag, None] = field(factory=dict, converter=safe_set) tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
batch_id: str = None
# Options # Options
multiply: float = None multiply: float = None
@ -70,6 +68,7 @@ class ImageConfig:
cond_dropout=overlay(other.cond_dropout, self.cond_dropout), cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
flip_p=overlay(other.flip_p, self.flip_p), flip_p=overlay(other.flip_p, self.flip_p),
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags), shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
batch_id=overlay(other.batch_id, self.batch_id)
) )
@classmethod @classmethod
@ -84,6 +83,7 @@ class ImageConfig:
cond_dropout=data.get("cond_dropout"), cond_dropout=data.get("cond_dropout"),
flip_p=data.get("flip_p"), flip_p=data.get("flip_p"),
shuffle_tags=data.get("shuffle_tags"), shuffle_tags=data.get("shuffle_tags"),
batch_id=data.get("batch_id")
) )
# Alternatively parse from dedicated `caption` attribute # Alternatively parse from dedicated `caption` attribute
@ -168,6 +168,8 @@ class Dataset:
cfgs.append(ImageConfig.from_file(fileset['local.yaml'])) cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
if 'local.yml' in fileset: if 'local.yml' in fileset:
cfgs.append(ImageConfig.from_file(fileset['local.yml'])) cfgs.append(ImageConfig.from_file(fileset['local.yml']))
if 'batch_id.txt' in fileset:
cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
result = ImageConfig.fold(cfgs) result = ImageConfig.fold(cfgs)
if 'shuffle_tags.txt' in fileset: if 'shuffle_tags.txt' in fileset:
@ -262,6 +264,7 @@ class Dataset:
multiplier=config.multiply or 1.0, multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout, cond_dropout=config.cond_dropout,
shuffle_tags=config.shuffle_tags, shuffle_tags=config.shuffle_tags,
batch_id=config.batch_id
) )
items.append(item) items.append(item)
except Exception as e: except Exception as e:

View File

@ -133,6 +133,7 @@ class ImageTrainItem:
multiplier: float=1.0, multiplier: float=1.0,
cond_dropout=None, cond_dropout=None,
shuffle_tags=False, shuffle_tags=False,
batch_id: str=None
): ):
self.caption = caption self.caption = caption
self.aspects = aspects self.aspects = aspects
@ -143,6 +144,8 @@ class ImageTrainItem:
self.multiplier = multiplier self.multiplier = multiplier
self.cond_dropout = cond_dropout self.cond_dropout = cond_dropout
self.shuffle_tags = shuffle_tags self.shuffle_tags = shuffle_tags
self.batch_id = batch_id or DEFAULT_BATCH_ID
self.target_wh = None
self.image_size = None self.image_size = None
if image is None or len(image) == 0: if image is None or len(image) == 0:
@ -351,3 +354,6 @@ class ImageTrainItem:
image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy)) image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
return image return image
DEFAULT_BATCH_ID = "default_batch"

View File

@ -55,7 +55,7 @@ from data.data_loader import DataLoaderMultiAspect
from data.every_dream import EveryDreamBatch, build_torch_dataloader from data.every_dream import EveryDreamBatch, build_torch_dataloader
from data.every_dream_validation import EveryDreamValidator from data.every_dream_validation import EveryDreamValidator
from data.image_train_item import ImageTrainItem from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
from utils.huggingface_downloader import try_download_model_from_hf from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter from utils.convert_diff_to_ckpt import convert as converter
from utils.isolate_rng import isolate_rng from utils.isolate_rng import isolate_rng
@ -297,19 +297,23 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
# at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0 # at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
warn_bucket_dupe_ratio = 0.5 warn_bucket_dupe_ratio = 0.5
ar_buckets = set([tuple(i.target_wh) for i in items]) def make_bucket_key(item):
return (item.batch_id, int(item.target_wh[0]), int(item.target_wh[1]))
ar_buckets = set(make_bucket_key(i) for i in items)
for ar_bucket in ar_buckets: for ar_bucket in ar_buckets:
count = len([i for i in items if tuple(i.target_wh) == ar_bucket]) count = len([i for i in items if make_bucket_key(i) == ar_bucket])
runt_size = batch_size - (count % batch_size) runt_size = batch_size - (count % batch_size)
bucket_dupe_ratio = runt_size / count bucket_dupe_ratio = runt_size / count
if bucket_dupe_ratio > warn_bucket_dupe_ratio: if bucket_dupe_ratio > warn_bucket_dupe_ratio:
aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket) aspect_ratio_rational = aspects.get_rational_aspect_ratio((ar_bucket[1], ar_bucket[2]))
aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}" aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
batch_id_description = "" if ar_bucket[0] == DEFAULT_BATCH_ID else f" for batch id '{ar_bucket[0]}'"
effective_multiplier = round(1 + bucket_dupe_ratio, 1) effective_multiplier = round(1 + bucket_dupe_ratio, 1)
logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} " logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier " f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
f"of {effective_multiplier}, which may cause problems. Consider adding {runt_size} or " f"of {effective_multiplier}, which may cause problems. Consider adding {runt_size} or "
f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.") f"more images with aspect ratio {aspect_ratio_description}{batch_id_description}, or reducing your batch_size.")
def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]: def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]:
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}") logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")