Merge pull request #38 from noprompt/push-dlma-into-main
Push DLMA into `main`, improvements to `data.resolve`
commit d99b3b1d9b
data/data_loader.py

@@ -15,15 +15,10 @@ limitations under the License.
 """
 import bisect
 import math
-import os
-import logging
 import copy
 
 import random
 from data.image_train_item import ImageTrainItem
-import data.aspects as aspects
-import data.resolver as resolver
-from colorama import Fore, Style
 import PIL
 
 PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@@ -32,24 +27,23 @@ class DataLoaderMultiAspect():
     """
     Data loader for multi-aspect-ratio training and bucketing
 
-    data_root: root folder of training data
+    image_train_items: list of `ImageTrainItem` objects
+    seed: random seed
     batch_size: number of images per batch
-    flip_p: probability of flipping image horizontally (i.e. 0-0.5)
     """
-    def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
-        self.data_root = data_root
-        self.debug_level = debug_level
-        self.flip_p = flip_p
-        self.log_folder = log_folder
+    def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
         self.seed = seed
         self.batch_size = batch_size
-        self.has_scanned = False
-        self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
-
-        logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
-        self.__prepare_train_data()
-        (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
+        # Prepare data
+        self.prepared_train_data = image_train_items
+        random.Random(self.seed).shuffle(self.prepared_train_data)
+        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
+        # Initialize ratings
+        self.rating_overall_sum: float = 0.0
+        self.ratings_summed: list[float] = []
+        for image in self.prepared_train_data:
+            self.rating_overall_sum += image.caption.rating()
+            self.ratings_summed.append(self.rating_overall_sum)
 
     def __pick_multiplied_set(self, randomizer):
         """
@@ -138,54 +132,6 @@ class DataLoaderMultiAspect():
 
         return image_caption_pairs
 
-    def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]:
-        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
-
-        rating_overall_sum: float = 0.0
-        ratings_summed: list[float] = []
-        for image in self.prepared_train_data:
-            rating_overall_sum += image.caption.rating()
-            ratings_summed.append(rating_overall_sum)
-
-        return rating_overall_sum, ratings_summed
-
-    def __prepare_train_data(self, flip_p=0.0) -> list[ImageTrainItem]:
-        """
-        Create ImageTrainItem objects with metadata for hydration later
-        """
-
-        if not self.has_scanned:
-            self.has_scanned = True
-
-            logging.info(" Preloading images...")
-
-            items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
-            image_paths = set(map(lambda item: item.pathname, items))
-
-            print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")
-
-            self.prepared_train_data = [item for item in items if item.error is None]
-            random.Random(self.seed).shuffle(self.prepared_train_data)
-            self.__report_errors(items)
-
-    def __report_errors(self, items: list[ImageTrainItem]):
-        for item in items:
-            if item.error is not None:
-                logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
-                logging.error(f" *** exception: {item.error}")
-
-        undersized_items = [item for item in items if item.is_undersized]
-
-        if len(undersized_items) > 0:
-            underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
-            logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
-            logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
-            with open(underized_log_path, "w") as undersized_images_file:
-                undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
-                for undersized_item in undersized_items:
-                    message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
-                    undersized_images_file.write(message)
-
     def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
         """
         Picks a random subset of all images
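With this change DataLoaderMultiAspect no longer scans the filesystem itself: it receives a list of already-resolved ImageTrainItem objects and only shuffles them, sorts them by caption rating, and precomputes the cumulative rating sums. A rough construction sketch based on the signatures in this diff (the paths, batch size, and variable names are illustrative, not part of the change):

import argparse

import data.aspects as aspects
import data.resolver as resolver
from data.data_loader import DataLoaderMultiAspect

# Illustrative stand-in for the parsed training args; train.py builds the real one.
args = argparse.Namespace(data_root="./my_training_data", resolution=512, flip_p=0.0, seed=555)
args.aspects = aspects.get_aspect_buckets(args.resolution)  # as done in setup_args() below

image_train_items = resolver.resolve(args.data_root, args)  # list[ImageTrainItem]
dlma = DataLoaderMultiAspect(image_train_items=image_train_items, seed=args.seed, batch_size=4)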
data/every_dream.py

@@ -16,108 +16,61 @@ limitations under the License.
 import logging
 import torch
 from torch.utils.data import Dataset
-from data.data_loader import DataLoaderMultiAspect as dlma
-import math
-import data.dl_singleton as dls
+from data.data_loader import DataLoaderMultiAspect
 from data.image_train_item import ImageTrainItem
 import random
 from torchvision import transforms
 from transformers import CLIPTokenizer
 import torch.nn.functional as F
-import numpy
 
 class EveryDreamBatch(Dataset):
     """
-    data_root: root path of all your training images, will be recursively searched for images
-    repeats: how many times to repeat each image in the dataset
-    flip_p: probability of flipping the image horizontally
+    data_loader: `DataLoaderMultiAspect` object
     debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
-    batch_size: how many images to return in a batch
     conditional_dropout: probability of dropping the caption for a given image
-    resolution: max resolution (relative to square)
-    jitter: number of pixels to jitter the crop by, only for non-square images
+    crop_jitter: number of pixels to jitter the crop by, only for non-square images
+    seed: random seed
     """
     def __init__(self,
-                 data_root,
-                 flip_p=0.0,
+                 data_loader: DataLoaderMultiAspect,
                  debug_level=0,
-                 batch_size=1,
                  conditional_dropout=0.02,
-                 resolution=512,
                  crop_jitter=20,
                  seed=555,
                  tokenizer=None,
-                 log_folder=None,
                  retain_contrast=False,
-                 write_schedule=False,
                  shuffle_tags=False,
                  rated_dataset=False,
                  rated_dataset_dropout_target=0.5
                  ):
-        self.data_root = data_root
-        self.batch_size = batch_size
+        self.data_loader = data_loader
+        self.batch_size = data_loader.batch_size
         self.debug_level = debug_level
         self.conditional_dropout = conditional_dropout
         self.crop_jitter = crop_jitter
         self.unloaded_to_idx = 0
         self.tokenizer = tokenizer
-        self.log_folder = log_folder
-        #print(f"tokenizer: {tokenizer}")
         self.max_token_length = self.tokenizer.model_max_length
         self.retain_contrast = retain_contrast
-        self.write_schedule = write_schedule
         self.shuffle_tags = shuffle_tags
         self.seed = seed
         self.rated_dataset = rated_dataset
         self.rated_dataset_dropout_target = rated_dataset_dropout_target
-
-        if seed == -1:
-            seed = random.randint(0, 99999)
-
-        if not dls.shared_dataloader:
-            logging.info(" * Creating new dataloader singleton")
-            dls.shared_dataloader = dlma(data_root=data_root,
-                                         seed=seed,
-                                         debug_level=debug_level,
-                                         batch_size=self.batch_size,
-                                         flip_p=flip_p,
-                                         resolution=resolution,
-                                         log_folder=self.log_folder,
-                                        )
-
-        self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images
+        # First epoch always trains on all images
+        self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0)
 
         num_images = len(self.image_train_items)
-        logging.info(f" ** Trainer Set: {num_images / batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
-
-        if self.write_schedule:
-            self.__write_batch_schedule(0)
-
-    def __write_batch_schedule(self, epoch_n):
-        with open(f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt", "w", encoding='utf-8') as f:
-            for i in range(len(self.image_train_items)):
-                try:
-                    f.write(f"step:{int(i / self.batch_size):05}, wh:{self.image_train_items[i].target_wh}, r:{self.image_train_items[i].runt_size}, path:{self.image_train_items[i].pathname}\n")
-                except Exception as e:
-                    logging.error(f" * Error writing to batch schedule for file path: {self.image_train_items[i].pathname}")
-
-    def get_runts():
-        return dls.shared_dataloader.runts
+        logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
 
     def shuffle(self, epoch_n: int, max_epochs: int):
         self.seed += 1
-        if dls.shared_dataloader:
-            if self.rated_dataset:
-                dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
-            else:
-                dropout_fraction = 1.0
-
-            self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction)
+        if self.rated_dataset:
+            dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
         else:
-            raise Exception("No dataloader singleton to shuffle")
+            dropout_fraction = 1.0
 
-        if self.write_schedule:
-            self.__write_batch_schedule(epoch_n + 1)
+        self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
 
     def __len__(self):
         return len(self.image_train_items)
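EveryDreamBatch now receives the DataLoaderMultiAspect instance directly and inherits its batch_size, which is what lets the data.dl_singleton module and the seed == -1 re-seeding path disappear. Each epoch, shuffle() simply asks the injected loader for a fresh set of buckets; when rated_dataset is enabled, the fraction passed to get_shuffled_image_buckets() shrinks linearly with the epoch number. A quick sketch of that schedule (the concrete numbers are illustrative):

# Epoch-wise dropout fraction, per the new shuffle() above.
# train.py sets rated_dataset_dropout_target = 1.0 - rated_dataset_target_dropout_percent / 100.
max_epochs = 10
rated_dataset_dropout_target = 0.5

for epoch_n in range(max_epochs):
    dropout_fraction = (max_epochs - (epoch_n * rated_dataset_dropout_target)) / max_epochs
    print(f"epoch {epoch_n}: keep fraction {dropout_fraction:.2f}")
# epoch 0 -> 1.00 (the first epoch always trains on all images), epoch 9 -> 0.55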
data/resolver.py

@@ -4,18 +4,21 @@ import os
 import random
 import typing
 import zipfile
+import argparse
 
-import PIL.Image as Image
 import tqdm
 from colorama import Fore, Style
 
 from data.image_train_item import ImageCaption, ImageTrainItem
 
 class DataResolver:
-    def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555):
-        self.seed = seed
-        self.aspects = aspects
-        self.flip_p = flip_p
+    def __init__(self, args: argparse.Namespace):
+        """
+        :param args: EveryDream configuration, an `argparse.Namespace` object.
+        """
+        self.aspects = args.aspects
+        self.flip_p = args.flip_p
+        self.seed = args.seed
 
     def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
         """

@@ -173,8 +176,11 @@ class DirectoryResolver(DataResolver):
         if os.path.isdir(current):
             yield from DirectoryResolver.recurse_data_root(current)
 
-def strategy(data_root: str):
+def strategy(data_root: str) -> typing.Type[DataResolver]:
+    """
+    Determine the strategy to use for resolving the data.
+    :param data_root: The root directory or JSON file to resolve.
+    """
     if os.path.isfile(data_root) and data_root.endswith('.json'):
         return JSONResolver
 
@@ -183,41 +189,37 @@ def strategy(data_root: str):
 
     raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")
 
-def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> list[ImageTrainItem]:
+def resolve_root(path: str, args: argparse.Namespace) -> list[ImageTrainItem]:
     """
-    :param data_root: Directory or JSON file.
-    :param aspects: The list of aspect ratios to use
-    :param flip_p: The probability of flipping the image
+    Resolve the training data from the root path.
+    :param path: The root path to resolve.
+    :param args: EveryDream configuration, an `argparse.Namespace` object.
     """
-    if os.path.isfile(path) and path.endswith('.json'):
-        return JSONResolver(aspects, flip_p, seed).image_train_items(path)
-
-    if os.path.isdir(path):
-        return DirectoryResolver(aspects, flip_p, seed).image_train_items(path)
-
-    raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
-
-def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]:
+    resolver = strategy(path)
+    return resolver(args).image_train_items(path)
+
+def resolve(value: typing.Union[dict, str], args: argparse.Namespace) -> list[ImageTrainItem]:
     """
     Resolve the training data from the value.
-    :param value: The value to resolve, either a dict or a string.
-    :param aspects: The list of aspect ratios to use
-    :param flip_p: The probability of flipping the image
+    :param value: The value to resolve, either a dict, an array, or a string.
+    :param args: EveryDream configuration, an `argparse.Namespace` object.
     """
     if isinstance(value, str):
-        return resolve_root(value, aspects, flip_p)
+        return resolve_root(value, args)
 
     if isinstance(value, dict):
         resolver = value.get('resolver', None)
         match resolver:
             case 'directory' | 'json':
                 path = value.get('path', None)
-                return resolve_root(path, aspects, flip_p, seed)
+                return resolve_root(path, args)
             case 'multi':
-                items = []
-                for resolver in value.get('resolvers', []):
-                    items += resolve(resolver, aspects, flip_p, seed)
-                return items
+                return resolve(value.get('resolvers', []), args)
             case _:
                 raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
 
+    if isinstance(value, list):
+        items = []
+        for item in value:
+            items += resolve(item, args)
+        return items
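data.resolver now takes the parsed EveryDream configuration everywhere instead of separate aspects/flip_p/seed parameters, resolve_root() dispatches through strategy(), and resolve() additionally accepts a list, which is how the 'multi' resolver case is now implemented. The test hunks that follow exercise exactly these shapes; a condensed sketch of the accepted inputs (paths and values are illustrative):

import argparse

import data.aspects as aspects
import data.resolver as resolver

args = argparse.Namespace(aspects=aspects.get_aspect_buckets(512), flip_p=0.0, seed=555)

items = resolver.resolve("./train_data", args)                                # directory root
items += resolver.resolve("./train_data/manifest.json", args)                 # JSON manifest
items += resolver.resolve({"resolver": "directory", "path": "./extra"}, args) # dict spec
items += resolver.resolve(["./train_data", "./extra/manifest.json"], args)    # list: results concatenated in order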
@@ -2,6 +2,7 @@ import json
 import glob
 import os
 import unittest
+import argparse
 
 import PIL.Image as Image
 

@@ -10,13 +11,18 @@ import data.resolver as resolver
 
 DATA_PATH = os.path.abspath('./test/data')
 JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json')
-ASPECTS = aspects.get_aspect_buckets(512)
 
 IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg')
 CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt')
 IMAGE_2_PATH = os.path.join(DATA_PATH, 'test2.jpg')
 IMAGE_3_PATH = os.path.join(DATA_PATH, 'test3.jpg')
 
+ARGS = argparse.Namespace(
+    aspects=aspects.get_aspect_buckets(512),
+    flip_p=0.5,
+    seed=42,
+)
+
 class TestResolve(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -51,7 +57,7 @@ class TestResolve(unittest.TestCase):
             os.remove(file)
 
     def test_directory_resolve_with_str(self):
-        items = resolver.resolve(DATA_PATH, ASPECTS)
+        items = resolver.resolve(DATA_PATH, ARGS)
         image_paths = [item.pathname for item in items]
         image_captions = [item.caption for item in items]
         captions = [caption.get_caption() for caption in image_captions]

@@ -69,7 +75,7 @@ class TestResolve(unittest.TestCase):
             'path': DATA_PATH,
         }
 
-        items = resolver.resolve(data_root_spec, ASPECTS)
+        items = resolver.resolve(data_root_spec, ARGS)
         image_paths = [item.pathname for item in items]
         image_captions = [item.caption for item in items]
         captions = [caption.get_caption() for caption in image_captions]

@@ -82,7 +88,7 @@ class TestResolve(unittest.TestCase):
         self.assertEqual(len(undersized_images), 1)
 
     def test_json_resolve_with_str(self):
-        items = resolver.resolve(JSON_ROOT_PATH, ASPECTS)
+        items = resolver.resolve(JSON_ROOT_PATH, ARGS)
         image_paths = [item.pathname for item in items]
         image_captions = [item.caption for item in items]
         captions = [caption.get_caption() for caption in image_captions]

@@ -100,7 +106,7 @@ class TestResolve(unittest.TestCase):
             'path': JSON_ROOT_PATH,
         }
 
-        items = resolver.resolve(data_root_spec, ASPECTS)
+        items = resolver.resolve(data_root_spec, ARGS)
         image_paths = [item.pathname for item in items]
         image_captions = [item.caption for item in items]
         captions = [caption.get_caption() for caption in image_captions]
@@ -111,3 +117,21 @@ class TestResolve(unittest.TestCase):
 
         undersized_images = list(filter(lambda i: i.is_undersized, items))
         self.assertEqual(len(undersized_images), 1)
+
+    def test_resolve_with_list(self):
+        data_root_spec = [
+            DATA_PATH,
+            JSON_ROOT_PATH,
+        ]
+
+        items = resolver.resolve(data_root_spec, ARGS)
+        image_paths = [item.pathname for item in items]
+        image_captions = [item.caption for item in items]
+        captions = [caption.get_caption() for caption in image_captions]
+
+        self.assertEqual(len(items), 6)
+        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2)
+        self.assertEqual(captions, ['caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'])
+
+        undersized_images = list(filter(lambda i: i.is_undersized, items))
+        self.assertEqual(len(undersized_images), 2)
train.py

@@ -15,6 +15,7 @@ limitations under the License.
 """
 
 import os
+import pprint
 import sys
 import math
 import signal

@@ -48,11 +49,15 @@ from accelerate.utils import set_seed
 
 import wandb
 from torch.utils.tensorboard import SummaryWriter
+from data.data_loader import DataLoaderMultiAspect
 
 from data.every_dream import EveryDreamBatch
+from data.image_train_item import ImageTrainItem
 from utils.huggingface_downloader import try_download_model_from_hf
 from utils.convert_diff_to_ckpt import convert as converter
 from utils.gpu import GPU
+import data.aspects as aspects
+import data.resolver as resolver
 
 _SIGTERM_EXIT_CODE = 130
 _VERY_LARGE_NUMBER = 1e9
@@ -265,6 +270,8 @@ def setup_args(args):
 
         logging.info(logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}"))
 
+    args.aspects = aspects.get_aspect_buckets(args.resolution)
+
     return args
 
 def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
@@ -289,6 +296,49 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
     scaler.set_backoff_factor(1/factor)
     scaler.set_growth_interval(100)
 
+def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
+    for item in items:
+        if item.error is not None:
+            logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
+            logging.error(f" *** exception: {item.error}")
+
+    undersized_items = [item for item in items if item.is_undersized]
+
+    if len(undersized_items) > 0:
+        underized_log_path = os.path.join(log_folder, "undersized_images.txt")
+        logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
+        logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
+        with open(underized_log_path, "w") as undersized_images_file:
+            undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
+            for undersized_item in undersized_items:
+                message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
+                undersized_images_file.write(message)
+
+def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
+    logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
+    logging.info(" Preloading images...")
+
+    resolved_items = resolver.resolve(args.data_root, args)
+    report_image_train_item_problems(log_folder, resolved_items)
+    image_paths = set(map(lambda item: item.pathname, resolved_items))
+
+    # Remove erroneous items
+    image_train_items = [item for item in resolved_items if item.error is None]
+
+    print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
+
+    return image_train_items
+
+def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int):
+    if args.write_schedule:
+        with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
+            for i in range(len(train_batch.image_train_items)):
+                try:
+                    item = train_batch.image_train_items[i]
+                    f.write(f"step:{int(i / train_batch.batch_size):05}, wh:{item.target_wh}, r:{item.runt_size}, path:{item.pathname}\n")
+                except Exception as e:
+                    logging.error(f" * Error writing to batch schedule for file path: {item.pathname}")
+
 
 def read_sample_prompts(sample_prompts_file_path: str):
     sample_prompts = []
@@ -557,17 +607,20 @@ def main(args):
 
     log_optimizer(optimizer, betas, epsilon)
 
-    train_batch = EveryDreamBatch(
-        data_root=args.data_root,
-        flip_p=args.flip_p,
-        debug_level=1,
+    image_train_items = resolve_image_train_items(args, log_folder)
+
+    data_loader = DataLoaderMultiAspect(
+        image_train_items=image_train_items,
+        seed=seed,
         batch_size=args.batch_size,
+    )
+
+    train_batch = EveryDreamBatch(
+        data_loader=data_loader,
+        debug_level=1,
         conditional_dropout=args.cond_dropout,
-        resolution=args.resolution,
         tokenizer=tokenizer,
         seed = seed,
-        log_folder=log_folder,
-        write_schedule=args.write_schedule,
         shuffle_tags=args.shuffle_tags,
         rated_dataset=args.rated_dataset,
         rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
@@ -592,10 +645,11 @@ def main(args):
     if args.wandb is not None and args.wandb:
         wandb.init(project=args.project_name, sync_tensorboard=True, )
 
-    log_writer = SummaryWriter(log_dir=log_folder,
-                               flush_secs=5,
-                               comment="EveryDream2FineTunes",
-                              )
+    log_writer = SummaryWriter(
+        log_dir=log_folder,
+        flush_secs=5,
+        comment="EveryDream2FineTunes",
+    )
 
     def log_args(log_writer, args):
         arglog = "args:\n"
@@ -732,6 +786,8 @@ def main(args):
     # # discard the grads, just want to pin memory
     # optimizer.zero_grad(set_to_none=True)
 
+    write_batch_schedule(args, log_folder, train_batch, 0)
+
     for epoch in range(args.max_epochs):
         loss_epoch = []
         epoch_start_time = time.time()
@@ -883,6 +939,7 @@ def main(args):
         epoch_pbar.update(1)
         if epoch < args.max_epochs - 1:
             train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
+            write_batch_schedule(args, log_folder, train_batch, epoch + 1)
 
         loss_local = sum(loss_epoch) / len(loss_epoch)
         log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
@@ -909,7 +966,6 @@ def main(args):
     logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
     logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
 
-
 def update_old_args(t_args):
     """
     Update old args to new args to deal with json config loading and missing args for compatibility
@@ -947,7 +1003,6 @@
             t_args = argparse.Namespace()
             t_args.__dict__.update(json.load(f))
             update_old_args(t_args) # update args to support older configs
-            print(f" args: \n{t_args.__dict__}")
             args = argparser.parse_args(namespace=t_args)
     else:
         print("No config file specified, using command line args")
@@ -996,4 +1051,6 @@
 
     args, _ = argparser.parse_known_args()
 
+    print(f" Args:")
+    pprint.pprint(args.__dict__)
     main(args)