add a batch_id.txt file to a subfolder or a batch_id key to its local.yaml to force the images in that folder to be processed in the same batch
parent 7e09b6dc29
commit 53d0686086
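The feature in brief: a subfolder opts in either with a standalone batch_id.txt file (whose text becomes the batch id) or with a batch_id key in its local.yaml; every image under that folder then lands in the same (batch_id, width, height) bucket, so a batch drawn from that bucket contains only that folder's images. A hypothetical layout (folder and id names are illustrative):

    training_data/
        my_character/
            batch_id.txt      # file content, e.g. "my_character", becomes the batch id
            01.jpg
            02.jpg
        scenery/
            local.yaml        # alternatively carries a key such as: batch_id: scenery
            hills.jpg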
data/aspects.py
@@ -248,7 +248,7 @@ def __get_all_aspects():
     ]
 
 
-def get_rational_aspect_ratio(bucket_wh: Tuple[int]) -> Tuple[int]:
+def get_rational_aspect_ratio(bucket_wh: Tuple[int, int]) -> Tuple[int]:
     def farey_aspect_ratio_pair(x: float, max_denominator_value: int):
         if x <= 1:
             return farey_aspect_ratio_pair_lt1(x, max_denominator_value)
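For orientation: get_rational_aspect_ratio reduces a bucket's (width, height) to a small rational aspect ratio via a Farey-sequence search; the change above only corrects the type annotation to match the two-element tuple the function actually receives. Presumed behavior, with illustrative values:

    # per the call site in train.py further below; the result is presumably (3, 2)
    num, den = get_rational_aspect_ratio((768, 512))   # 768/512 = 1.5
    print(f"{num}:{den}")                              # "3:2"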
data/data_loader.py
@@ -21,7 +21,10 @@ import math
 import copy
 
 import random
-from data.image_train_item import ImageTrainItem
+from itertools import groupby
+from typing import Tuple, List
+
+from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
 import PIL.Image
 
 PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@@ -103,15 +106,29 @@ class DataLoaderMultiAspect():
 
         buckets = {}
         batch_size = self.batch_size
 
+        def append_to_bucket(item, batch_id):
+            bucket_key = (batch_id, item.target_wh[0], item.target_wh[1])
+            if bucket_key not in buckets:
+                buckets[bucket_key] = []
+            buckets[bucket_key].append(item)
+
         for image_caption_pair in picked_images:
             image_caption_pair.runt_size = 0
-            target_wh = image_caption_pair.target_wh
+            batch_id = image_caption_pair.batch_id
+            append_to_bucket(image_caption_pair, batch_id)
 
-            if (target_wh[0],target_wh[1]) not in buckets:
-                buckets[(target_wh[0],target_wh[1])] = []
-            buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
+        # shunt any runts from "named" buckets into the appropriate "general" buckets
+        for bucket in [b for b in buckets if b[0] != DEFAULT_BATCH_ID]:
+            truncate_count = len(buckets[bucket]) % batch_size
+            for runt in buckets[bucket][-truncate_count:]:
+                append_to_bucket(runt, DEFAULT_BATCH_ID)
+            del buckets[bucket][-truncate_count:]
+            if len(buckets[bucket]) == 0:
+                del buckets[bucket]
 
-        for bucket in buckets:
+        # handle runts in "general" buckets by randomly duplicating items
+        for bucket in [b for b in buckets if b[0] == DEFAULT_BATCH_ID]:
             truncate_count = len(buckets[bucket]) % batch_size
             if truncate_count > 0:
                 runt_bucket = buckets[bucket][-truncate_count:]
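This bucketing change is the heart of the commit: bucket keys grow from (width, height) to (batch_id, width, height), and leftover images ("runts") in a named bucket are shunted to the general bucket of the same resolution rather than being duplicated. A minimal standalone sketch of the same scheme (illustrative names; items are assumed to carry batch_id and target_wh attributes):

    from collections import defaultdict

    DEFAULT_BATCH_ID = "default_batch"

    def bucketize(items, batch_size):
        # key each item by (batch_id, width, height) so images sharing a
        # batch_id can only ever be batched with each other
        buckets = defaultdict(list)
        for item in items:
            buckets[(item.batch_id, *item.target_wh)].append(item)
        # shunt runts from "named" buckets into the "general" bucket
        # with the same resolution
        for key in [k for k in buckets if k[0] != DEFAULT_BATCH_ID]:
            runt_count = len(buckets[key]) % batch_size
            if runt_count > 0:
                for runt in buckets[key][-runt_count:]:
                    buckets[(DEFAULT_BATCH_ID, key[1], key[2])].append(runt)
                del buckets[key][-runt_count:]
            if not buckets[key]:
                del buckets[key]
        return dict(buckets)

The sketch guards the slice with runt_count > 0, since a Python slice [-0:] spans the whole list; any runts left in the general buckets are then handled by the random-duplication pass shown in the hunk above.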
@@ -132,6 +149,7 @@ class DataLoaderMultiAspect():
 
         return items
 
+
     def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
         """
         Picks a random subset of all images
@@ -175,3 +193,5 @@ class DataLoaderMultiAspect():
         for item in self.prepared_train_data:
             self.rating_overall_sum += item.caption.rating()
             self.ratings_summed.append(self.rating_overall_sum)
+
+
data/dataset.py
@@ -1,10 +1,7 @@
-import os
-import logging
 import yaml
 import json
 
-from functools import total_ordering
-from attrs import define, field
+from attrs import define, field, Factory
 from data.image_train_item import ImageCaption, ImageTrainItem
 from utils.fs_helpers import *
 from typing import Iterable
@@ -50,6 +47,7 @@ class ImageConfig:
     rating: float = None
     max_caption_length: int = None
     tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
+    batch_id: str = None
 
     # Options
     multiply: float = None
@@ -70,6 +68,7 @@ class ImageConfig:
             cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
             flip_p=overlay(other.flip_p, self.flip_p),
             shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
+            batch_id=overlay(other.batch_id, self.batch_id)
         )
 
     @classmethod
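overlay is a small helper defined elsewhere in dataset.py and is not part of this diff; the line above assumes it prefers the overriding value when one is set, so a subfolder's batch_id wins over its parent's. A hypothetical reconstruction of that semantics:

    def overlay(top, bottom):
        # prefer the more specific (overriding) value when it is set
        return top if top is not None else bottom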
@@ -84,6 +83,7 @@ class ImageConfig:
             cond_dropout=data.get("cond_dropout"),
             flip_p=data.get("flip_p"),
             shuffle_tags=data.get("shuffle_tags"),
+            batch_id=data.get("batch_id")
         )
 
         # Alternatively parse from dedicated `caption` attribute
@@ -168,6 +168,8 @@ class Dataset:
             cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
         if 'local.yml' in fileset:
             cfgs.append(ImageConfig.from_file(fileset['local.yml']))
+        if 'batch_id.txt' in fileset:
+            cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
 
         result = ImageConfig.fold(cfgs)
         if 'shuffle_tags.txt' in fileset:
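read_text comes from utils.fs_helpers, star-imported at the top of the file, and its body is not part of this diff. A plausible shape, assumed here only for illustration (whether the real helper strips trailing whitespace is not visible in the diff):

    def read_text(path: str) -> str:
        # read the whole file as UTF-8 text
        with open(path, encoding="utf-8") as f:
            return f.read()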
@@ -262,6 +264,7 @@ class Dataset:
                     multiplier=config.multiply or 1.0,
                     cond_dropout=config.cond_dropout,
                     shuffle_tags=config.shuffle_tags,
+                    batch_id=config.batch_id
                 )
                 items.append(item)
             except Exception as e:
data/image_train_item.py
@@ -133,6 +133,7 @@ class ImageTrainItem:
                  multiplier: float=1.0,
                  cond_dropout=None,
                  shuffle_tags=False,
+                 batch_id: str=None
                  ):
         self.caption = caption
         self.aspects = aspects
@@ -143,6 +144,8 @@ class ImageTrainItem:
         self.multiplier = multiplier
         self.cond_dropout = cond_dropout
         self.shuffle_tags = shuffle_tags
+        self.batch_id = batch_id or DEFAULT_BATCH_ID
+        self.target_wh = None
 
         self.image_size = None
         if image is None or len(image) == 0:
@@ -351,3 +354,6 @@ class ImageTrainItem:
         image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
 
         return image
+
+
+DEFAULT_BATCH_ID = "default_batch"
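DEFAULT_BATCH_ID lives at module scope below the class, yet __init__ references it; that works because Python resolves names in a function body when the function runs, not when the class is defined. A self-contained illustration:

    class C:
        def __init__(self):
            self.x = LATER  # looked up at call time, not at class-definition time

    LATER = "default_batch"
    assert C().x == "default_batch"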
train.py
@@ -55,7 +55,7 @@ from data.data_loader import DataLoaderMultiAspect
 
 from data.every_dream import EveryDreamBatch, build_torch_dataloader
 from data.every_dream_validation import EveryDreamValidator
-from data.image_train_item import ImageTrainItem
+from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
 from utils.huggingface_downloader import try_download_model_from_hf
 from utils.convert_diff_to_ckpt import convert as converter
 from utils.isolate_rng import isolate_rng
@@ -297,19 +297,23 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
     # at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
     warn_bucket_dupe_ratio = 0.5
 
-    ar_buckets = set([tuple(i.target_wh) for i in items])
+    def make_bucket_key(item):
+        return (item.batch_id, int(item.target_wh[0]), int(item.target_wh[1]))
+
+    ar_buckets = set(make_bucket_key(i) for i in items)
     for ar_bucket in ar_buckets:
-        count = len([i for i in items if tuple(i.target_wh) == ar_bucket])
+        count = len([i for i in items if make_bucket_key(i) == ar_bucket])
         runt_size = batch_size - (count % batch_size)
         bucket_dupe_ratio = runt_size / count
         if bucket_dupe_ratio > warn_bucket_dupe_ratio:
-            aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket)
+            aspect_ratio_rational = aspects.get_rational_aspect_ratio((ar_bucket[1], ar_bucket[2]))
             aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
+            batch_id_description = "" if ar_bucket[0] == DEFAULT_BATCH_ID else f" for batch id '{ar_bucket[0]}'"
             effective_multiplier = round(1 + bucket_dupe_ratio, 1)
             logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
                             f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
                             f"of {effective_multiplier}, which may cause problems. Consider adding {runt_size} or "
-                            f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.")
+                            f"more images with aspect ratio {aspect_ratio_description}{batch_id_description}, or reducing your batch_size.")
 
 def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]:
     logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
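To make the warning threshold concrete, with hypothetical numbers:

    batch_size = 8
    count = 10                                     # images in one (batch_id, w, h) bucket
    runt_size = batch_size - (count % batch_size)  # 8 - (10 % 8) = 6 duplicates needed
    bucket_dupe_ratio = runt_size / count          # 6 / 10 = 0.6 > 0.5, so the warning fires
    effective_multiplier = round(1 + bucket_dupe_ratio, 1)  # 1.6: each image is seen ~1.6x per epoch

Because batch_id is now part of the bucket key, a small named bucket warns separately from the general bucket at the same resolution, hence the "for batch id" suffix added to the message.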