Add a batch_id.txt file to a subfolder, or a batch_id key to its local.yaml, to force all images in that folder to be processed in the same batch
This commit is contained in:
parent
7e09b6dc29
commit
53d0686086
|
@ -248,7 +248,7 @@ def __get_all_aspects():
|
|||
]
|
||||
|
||||
|
||||
def get_rational_aspect_ratio(bucket_wh: Tuple[int]) -> Tuple[int]:
|
||||
def get_rational_aspect_ratio(bucket_wh: Tuple[int, int]) -> Tuple[int]:
|
||||
def farey_aspect_ratio_pair(x: float, max_denominator_value: int):
|
||||
if x <= 1:
|
||||
return farey_aspect_ratio_pair_lt1(x, max_denominator_value)
|
||||
|
|
|
@ -21,7 +21,10 @@ import math
|
|||
import copy
|
||||
|
||||
import random
|
||||
from data.image_train_item import ImageTrainItem
|
||||
from itertools import groupby
|
||||
from typing import Tuple, List
|
||||
|
||||
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
|
||||
import PIL.Image
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
||||
|
@ -103,15 +106,29 @@ class DataLoaderMultiAspect():
|
|||
|
||||
buckets = {}
|
||||
batch_size = self.batch_size
|
||||
|
||||
def append_to_bucket(item, batch_id):
    """File ``item`` under the bucket for ``batch_id`` and the item's target size.

    Buckets are keyed by ``(batch_id, width, height)``, with the size read from
    ``item.target_wh``. A bucket is created the first time its key is seen.
    """
    width, height = item.target_wh[0], item.target_wh[1]
    buckets.setdefault((batch_id, width, height), []).append(item)
|
||||
|
||||
for image_caption_pair in picked_images:
|
||||
image_caption_pair.runt_size = 0
|
||||
target_wh = image_caption_pair.target_wh
|
||||
batch_id = image_caption_pair.batch_id
|
||||
append_to_bucket(image_caption_pair, batch_id)
|
||||
|
||||
if (target_wh[0],target_wh[1]) not in buckets:
|
||||
buckets[(target_wh[0],target_wh[1])] = []
|
||||
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
|
||||
# Shunt any runts from "named" buckets into the appropriate "general" buckets.
# A runt is the tail of a bucket that does not fill a whole batch. Named
# buckets must only ever emit full batches, so their leftovers are handed to
# the default-batch bucket with the same target size.
# Iterate over a snapshot of the keys: append_to_bucket() may add new general
# buckets and emptied named buckets are deleted while looping.
for bucket in [b for b in buckets if b[0] != DEFAULT_BATCH_ID]:
    truncate_count = len(buckets[bucket]) % batch_size
    if truncate_count == 0:
        # The bucket already divides evenly into batches. This guard is
        # required: a slice of [-0:] is the ENTIRE list, so without it every
        # item would be moved to the general bucket and the named bucket lost.
        continue
    for runt in buckets[bucket][-truncate_count:]:
        append_to_bucket(runt, DEFAULT_BATCH_ID)
    del buckets[bucket][-truncate_count:]
    # A named bucket smaller than one batch consists only of runts.
    if len(buckets[bucket]) == 0:
        del buckets[bucket]
|
||||
|
||||
for bucket in buckets:
|
||||
# handle runts in "general" buckets by randomly duplicating items
|
||||
for bucket in [b for b in buckets if b[0] == DEFAULT_BATCH_ID]:
|
||||
truncate_count = len(buckets[bucket]) % batch_size
|
||||
if truncate_count > 0:
|
||||
runt_bucket = buckets[bucket][-truncate_count:]
|
||||
|
@ -132,6 +149,7 @@ class DataLoaderMultiAspect():
|
|||
|
||||
return items
|
||||
|
||||
|
||||
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
|
||||
"""
|
||||
Picks a random subset of all images
|
||||
|
@ -174,4 +192,6 @@ class DataLoaderMultiAspect():
|
|||
self.ratings_summed: list[float] = []
|
||||
for item in self.prepared_train_data:
|
||||
self.rating_overall_sum += item.caption.rating()
|
||||
self.ratings_summed.append(self.rating_overall_sum)
|
||||
self.ratings_summed.append(self.rating_overall_sum)
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
import os
|
||||
import logging
|
||||
import yaml
|
||||
import json
|
||||
|
||||
from functools import total_ordering
|
||||
from attrs import define, field, Factory
|
||||
from attrs import define, field
|
||||
from data.image_train_item import ImageCaption, ImageTrainItem
|
||||
from utils.fs_helpers import *
|
||||
from typing import Iterable
|
||||
|
@ -50,6 +47,7 @@ class ImageConfig:
|
|||
rating: float = None
|
||||
max_caption_length: int = None
|
||||
tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
|
||||
batch_id: str = None
|
||||
|
||||
# Options
|
||||
multiply: float = None
|
||||
|
@ -70,6 +68,7 @@ class ImageConfig:
|
|||
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
||||
flip_p=overlay(other.flip_p, self.flip_p),
|
||||
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
|
||||
batch_id=overlay(other.batch_id, self.batch_id)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -84,6 +83,7 @@ class ImageConfig:
|
|||
cond_dropout=data.get("cond_dropout"),
|
||||
flip_p=data.get("flip_p"),
|
||||
shuffle_tags=data.get("shuffle_tags"),
|
||||
batch_id=data.get("batch_id")
|
||||
)
|
||||
|
||||
# Alternatively parse from dedicated `caption` attribute
|
||||
|
@ -168,6 +168,8 @@ class Dataset:
|
|||
cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
|
||||
if 'local.yml' in fileset:
|
||||
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
|
||||
if 'batch_id.txt' in fileset:
|
||||
cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
|
||||
|
||||
result = ImageConfig.fold(cfgs)
|
||||
if 'shuffle_tags.txt' in fileset:
|
||||
|
@ -262,6 +264,7 @@ class Dataset:
|
|||
multiplier=config.multiply or 1.0,
|
||||
cond_dropout=config.cond_dropout,
|
||||
shuffle_tags=config.shuffle_tags,
|
||||
batch_id=config.batch_id
|
||||
)
|
||||
items.append(item)
|
||||
except Exception as e:
|
||||
|
|
|
@ -124,7 +124,7 @@ class ImageTrainItem:
|
|||
flip_p: probability of flipping image (0.0 to 1.0)
|
||||
rating: the relative rating of the images. The rating is measured in comparison to the other images.
|
||||
"""
|
||||
def __init__(self,
|
||||
def __init__(self,
|
||||
image: PIL.Image,
|
||||
caption: ImageCaption,
|
||||
aspects: list[float],
|
||||
|
@ -133,6 +133,7 @@ class ImageTrainItem:
|
|||
multiplier: float=1.0,
|
||||
cond_dropout=None,
|
||||
shuffle_tags=False,
|
||||
batch_id: str=None
|
||||
):
|
||||
self.caption = caption
|
||||
self.aspects = aspects
|
||||
|
@ -143,6 +144,8 @@ class ImageTrainItem:
|
|||
self.multiplier = multiplier
|
||||
self.cond_dropout = cond_dropout
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.batch_id = batch_id or DEFAULT_BATCH_ID
|
||||
self.target_wh = None
|
||||
|
||||
self.image_size = None
|
||||
if image is None or len(image) == 0:
|
||||
|
@ -351,3 +354,6 @@ class ImageTrainItem:
|
|||
image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
DEFAULT_BATCH_ID = "default_batch"
|
||||
|
|
14
train.py
14
train.py
|
@ -55,7 +55,7 @@ from data.data_loader import DataLoaderMultiAspect
|
|||
|
||||
from data.every_dream import EveryDreamBatch, build_torch_dataloader
|
||||
from data.every_dream_validation import EveryDreamValidator
|
||||
from data.image_train_item import ImageTrainItem
|
||||
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
|
||||
from utils.huggingface_downloader import try_download_model_from_hf
|
||||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
from utils.isolate_rng import isolate_rng
|
||||
|
@ -297,19 +297,23 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
|
|||
# at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
|
||||
warn_bucket_dupe_ratio = 0.5
|
||||
|
||||
ar_buckets = set([tuple(i.target_wh) for i in items])
|
||||
def make_bucket_key(item):
    """Return the ``(batch_id, width, height)`` key of the bucket ``item`` falls in.

    Width and height come from ``item.target_wh`` and are coerced to ``int``.
    """
    batch = item.batch_id
    width = int(item.target_wh[0])
    height = int(item.target_wh[1])
    return (batch, width, height)
|
||||
|
||||
ar_buckets = set(make_bucket_key(i) for i in items)
|
||||
for ar_bucket in ar_buckets:
|
||||
count = len([i for i in items if tuple(i.target_wh) == ar_bucket])
|
||||
count = len([i for i in items if make_bucket_key(i) == ar_bucket])
|
||||
runt_size = batch_size - (count % batch_size)
|
||||
bucket_dupe_ratio = runt_size / count
|
||||
if bucket_dupe_ratio > warn_bucket_dupe_ratio:
|
||||
aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket)
|
||||
aspect_ratio_rational = aspects.get_rational_aspect_ratio((ar_bucket[1], ar_bucket[2]))
|
||||
aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
|
||||
batch_id_description = "" if ar_bucket[0] == DEFAULT_BATCH_ID else f" for batch id '{ar_bucket[0]}'"
|
||||
effective_multiplier = round(1 + bucket_dupe_ratio, 1)
|
||||
logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
|
||||
f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
|
||||
f"of {effective_multiplier}, which may cause problems. Consider adding {runt_size} or "
|
||||
f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.")
|
||||
f"more images with aspect ratio {aspect_ratio_description}{batch_id_description}, or reducing your batch_size.")
|
||||
|
||||
def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]:
|
||||
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
|
||||
|
|
Loading…
Reference in New Issue