Merge pull request #38 from noprompt/push-dlma-into-main

Push DLMA into `main`, improvements to `data.resolve`
Victor Hall 2023-02-05 08:20:10 -05:00 committed by GitHub
commit d99b3b1d9b
5 changed files with 161 additions and 179 deletions
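
In outline: `DataLoaderMultiAspect` no longer scans disk, logs errors, or writes schedules itself; the trainer resolves `ImageTrainItem`s up front and passes them in, and `EveryDreamBatch` receives the loader directly instead of reaching through the `dl_singleton` module. A minimal sketch of the new wiring, using only call sites that appear in the diffs below (`args`, `seed`, `log_folder`, and `tokenizer` are assumed to come from the trainer's existing setup):

    # Sketch only -- mirrors the call sites in this commit, not a runnable script.
    import data.aspects as aspects
    from data.data_loader import DataLoaderMultiAspect
    from data.every_dream import EveryDreamBatch

    args.aspects = aspects.get_aspect_buckets(args.resolution)       # now set in setup_args()
    image_train_items = resolve_image_train_items(args, log_folder)  # wraps resolver.resolve()

    data_loader = DataLoaderMultiAspect(
        image_train_items=image_train_items,
        seed=seed,
        batch_size=args.batch_size,
    )
    train_batch = EveryDreamBatch(
        data_loader=data_loader,
        tokenizer=tokenizer,  # assumed: the CLIPTokenizer built earlier in main()
        seed=seed,
    )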

data/data_loader.py

@@ -15,15 +15,10 @@ limitations under the License.
 """
 import bisect
 import math
-import os
-import logging
 import copy
 import random
 from data.image_train_item import ImageTrainItem
-import data.aspects as aspects
-import data.resolver as resolver
-from colorama import Fore, Style
 import PIL

 PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@@ -32,24 +27,23 @@ class DataLoaderMultiAspect():
     """
     Data loader for multi-aspect-ratio training and bucketing

-    data_root: root folder of training data
-    seed: random seed
+    image_train_items: list of `ImageTrainItem` objects
     batch_size: number of images per batch
-    flip_p: probability of flipping image horizontally (i.e. 0-0.5)
     """
-    def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
-        self.data_root = data_root
-        self.debug_level = debug_level
-        self.flip_p = flip_p
-        self.log_folder = log_folder
+    def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
         self.seed = seed
         self.batch_size = batch_size
-        self.has_scanned = False
-        self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
-        logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
-        self.__prepare_train_data()
-        (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
+
+        # Prepare data
+        self.prepared_train_data = image_train_items
+        random.Random(self.seed).shuffle(self.prepared_train_data)
+        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
+
+        # Initialize ratings
+        self.rating_overall_sum: float = 0.0
+        self.ratings_summed: list[float] = []
+        for image in self.prepared_train_data:
+            self.rating_overall_sum += image.caption.rating()
+            self.ratings_summed.append(self.rating_overall_sum)

     def __pick_multiplied_set(self, randomizer):
         """
@@ -138,54 +132,6 @@ class DataLoaderMultiAspect():
         return image_caption_pairs

-    def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]:
-        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
-
-        rating_overall_sum: float = 0.0
-        ratings_summed: list[float] = []
-
-        for image in self.prepared_train_data:
-            rating_overall_sum += image.caption.rating()
-            ratings_summed.append(rating_overall_sum)
-
-        return rating_overall_sum, ratings_summed
-
-    def __prepare_train_data(self, flip_p=0.0) -> list[ImageTrainItem]:
-        """
-        Create ImageTrainItem objects with metadata for hydration later
-        """
-        if not self.has_scanned:
-            self.has_scanned = True
-
-            logging.info(" Preloading images...")
-
-            items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
-            image_paths = set(map(lambda item: item.pathname, items))
-            print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")
-
-            self.prepared_train_data = [item for item in items if item.error is None]
-            random.Random(self.seed).shuffle(self.prepared_train_data)
-            self.__report_errors(items)
-
-    def __report_errors(self, items: list[ImageTrainItem]):
-        for item in items:
-            if item.error is not None:
-                logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
-                logging.error(f" *** exception: {item.error}")
-
-        undersized_items = [item for item in items if item.is_undersized]
-
-        if len(undersized_items) > 0:
-            underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
-            logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
-            logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
-            with open(underized_log_path, "w") as undersized_images_file:
-                undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
-                for undersized_item in undersized_items:
-                    message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
-                    undersized_images_file.write(message)
-
     def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
         """
         Picks a random subset of all images
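
Reviewer note on the `__init__` change above: sorting `prepared_train_data` by caption rating and keeping a running total turns `ratings_summed` into a cumulative-distribution table, so a rating-weighted draw is just a uniform sample plus a binary search (hence the retained `import bisect`). A sketch of the lookup this enables, assuming non-negative ratings; the real selection logic lives in `__pick_random_subset`, which is unchanged and elided from this hunk:

    import bisect
    import random

    def pick_rating_weighted(prepared_train_data, ratings_summed,
                             rating_overall_sum, picker: random.Random):
        # Each image owns a slice of [0, rating_overall_sum] proportional to its
        # rating, so a uniform draw lands on higher-rated images more often.
        point = picker.uniform(0.0, rating_overall_sum)
        index = bisect.bisect_left(ratings_summed, point)
        return prepared_train_data[index]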

data/every_dream.py

@@ -16,108 +16,61 @@ limitations under the License.
 import logging
 import torch
 from torch.utils.data import Dataset
-from data.data_loader import DataLoaderMultiAspect as dlma
-import math
-import data.dl_singleton as dls
+from data.data_loader import DataLoaderMultiAspect
 from data.image_train_item import ImageTrainItem
 import random
 from torchvision import transforms
 from transformers import CLIPTokenizer
 import torch.nn.functional as F
-import numpy

 class EveryDreamBatch(Dataset):
     """
-    data_root: root path of all your training images, will be recursively searched for images
-    repeats: how many times to repeat each image in the dataset
-    flip_p: probability of flipping the image horizontally
+    data_loader: `DataLoaderMultiAspect` object
     debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
-    batch_size: how many images to return in a batch
     conditional_dropout: probability of dropping the caption for a given image
-    resolution: max resolution (relative to square)
-    jitter: number of pixels to jitter the crop by, only for non-square images
+    crop_jitter: number of pixels to jitter the crop by, only for non-square images
+    seed: random seed
     """
     def __init__(self,
-                 data_root,
-                 flip_p=0.0,
+                 data_loader: DataLoaderMultiAspect,
                  debug_level=0,
-                 batch_size=1,
                  conditional_dropout=0.02,
-                 resolution=512,
                  crop_jitter=20,
                  seed=555,
                  tokenizer=None,
-                 log_folder=None,
                  retain_contrast=False,
-                 write_schedule=False,
                  shuffle_tags=False,
                  rated_dataset=False,
                  rated_dataset_dropout_target=0.5
                  ):
-        self.data_root = data_root
-        self.batch_size = batch_size
+        self.data_loader = data_loader
+        self.batch_size = data_loader.batch_size
         self.debug_level = debug_level
         self.conditional_dropout = conditional_dropout
         self.crop_jitter = crop_jitter
         self.unloaded_to_idx = 0
         self.tokenizer = tokenizer
-        self.log_folder = log_folder
-        #print(f"tokenizer: {tokenizer}")
         self.max_token_length = self.tokenizer.model_max_length
         self.retain_contrast = retain_contrast
-        self.write_schedule = write_schedule
         self.shuffle_tags = shuffle_tags
         self.seed = seed
         self.rated_dataset = rated_dataset
         self.rated_dataset_dropout_target = rated_dataset_dropout_target

-        # First epoch always trains on all images
-        if seed == -1:
-            seed = random.randint(0, 99999)
-
-        if not dls.shared_dataloader:
-            logging.info(" * Creating new dataloader singleton")
-            dls.shared_dataloader = dlma(data_root=data_root,
-                                         seed=seed,
-                                         debug_level=debug_level,
-                                         batch_size=self.batch_size,
-                                         flip_p=flip_p,
-                                         resolution=resolution,
-                                         log_folder=self.log_folder,
-                                         )
-
-        self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(1.0)
+        # First epoch always trains on all images
+        self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0)

         num_images = len(self.image_train_items)
-        logging.info(f" ** Trainer Set: {num_images / batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
+        logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")

-        if self.write_schedule:
-            self.__write_batch_schedule(0)
-
-    def __write_batch_schedule(self, epoch_n):
-        with open(f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt", "w", encoding='utf-8') as f:
-            for i in range(len(self.image_train_items)):
-                try:
-                    f.write(f"step:{int(i / self.batch_size):05}, wh:{self.image_train_items[i].target_wh}, r:{self.image_train_items[i].runt_size}, path:{self.image_train_items[i].pathname}\n")
-                except Exception as e:
-                    logging.error(f" * Error writing to batch schedule for file path: {self.image_train_items[i].pathname}")
-
-    def get_runts():
-        return dls.shared_dataloader.runts
-
     def shuffle(self, epoch_n: int, max_epochs: int):
         self.seed += 1
-        if dls.shared_dataloader:
-            if self.rated_dataset:
-                dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
-            else:
-                dropout_fraction = 1.0
-            self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction)
-        else:
-            raise Exception("No dataloader singleton to shuffle")
-
-        if self.write_schedule:
-            self.__write_batch_schedule(epoch_n + 1)
+        if self.rated_dataset:
+            dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
+        else:
+            dropout_fraction = 1.0
+        self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)

     def __len__(self):
         return len(self.image_train_items)
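
The `shuffle` hook keeps its linear dropout schedule: with `rated_dataset` enabled, the kept fraction starts at 1.0 and shrinks by `rated_dataset_dropout_target / max_epochs` each epoch. A quick check of the arithmetic with assumed values:

    # Mirrors the dropout_fraction computation in EveryDreamBatch.shuffle
    max_epochs, dropout_target = 10, 0.5  # assumed values for illustration
    for epoch_n in range(max_epochs):
        dropout_fraction = (max_epochs - (epoch_n * dropout_target)) / max_epochs
        print(f"epoch {epoch_n}: train on {dropout_fraction:.0%} of images")
    # epoch 0 -> 100%, epoch 9 -> 55%; which images survive the cut is decided
    # by DataLoaderMultiAspect's rating-sorted cumulative table.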

data/resolver.py

@@ -4,18 +4,21 @@ import os
 import random
 import typing
 import zipfile
-import PIL.Image as Image
+import argparse
 import tqdm
 from colorama import Fore, Style
 from data.image_train_item import ImageCaption, ImageTrainItem

 class DataResolver:
-    def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555):
-        self.seed = seed
-        self.aspects = aspects
-        self.flip_p = flip_p
+    def __init__(self, args: argparse.Namespace):
+        """
+        :param args: EveryDream configuration, an `argparse.Namespace` object.
+        """
+        self.aspects = args.aspects
+        self.flip_p = args.flip_p
+        self.seed = args.seed

     def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
         """
@@ -173,8 +176,11 @@ class DirectoryResolver(DataResolver):
             if os.path.isdir(current):
                 yield from DirectoryResolver.recurse_data_root(current)

-def strategy(data_root: str):
+def strategy(data_root: str) -> typing.Type[DataResolver]:
+    """
+    Determine the strategy to use for resolving the data.
+
+    :param data_root: The root directory or JSON file to resolve.
+    """
     if os.path.isfile(data_root) and data_root.endswith('.json'):
         return JSONResolver
@@ -183,41 +189,37 @@ def strategy(data_root: str):
     raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")

-def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> list[ImageTrainItem]:
+def resolve_root(path: str, args: argparse.Namespace) -> list[ImageTrainItem]:
     """
-    :param data_root: Directory or JSON file.
-    :param aspects: The list of aspect ratios to use
-    :param flip_p: The probability of flipping the image
+    Resolve the training data from the root path.
+
+    :param path: The root path to resolve.
+    :param args: EveryDream configuration, an `argparse.Namespace` object.
     """
-    if os.path.isfile(path) and path.endswith('.json'):
-        return JSONResolver(aspects, flip_p, seed).image_train_items(path)
-
-    if os.path.isdir(path):
-        return DirectoryResolver(aspects, flip_p, seed).image_train_items(path)
-
-    raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
+    resolver = strategy(path)
+    return resolver(args).image_train_items(path)

-def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]:
+def resolve(value: typing.Union[dict, list, str], args: argparse.Namespace) -> list[ImageTrainItem]:
     """
     Resolve the training data from the value.
-    :param value: The value to resolve, either a dict or a string.
-    :param aspects: The list of aspect ratios to use
-    :param flip_p: The probability of flipping the image
+
+    :param value: The value to resolve: a dict, a list, or a string.
+    :param args: EveryDream configuration, an `argparse.Namespace` object.
     """
     if isinstance(value, str):
-        return resolve_root(value, aspects, flip_p)
+        return resolve_root(value, args)

     if isinstance(value, dict):
         resolver = value.get('resolver', None)
         match resolver:
             case 'directory' | 'json':
                 path = value.get('path', None)
-                return resolve_root(path, aspects, flip_p, seed)
+                return resolve_root(path, args)
             case 'multi':
-                items = []
-                for resolver in value.get('resolvers', []):
-                    items += resolve(resolver, aspects, flip_p, seed)
-                return items
+                return resolve(value.get('resolvers', []), args)
             case _:
                 raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
+
+    if isinstance(value, list):
+        items = []
+        for item in value:
+            items += resolve(item, args)
+        return items
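
With the `args`-based signature, every accepted shape of `value` flows through one Namespace; `aspects`, `flip_p`, and `seed` are the fields the base `DataResolver` reads (and the ones the tests below stub out). A usage sketch, with hypothetical paths:

    import argparse
    import data.aspects as aspects
    import data.resolver as resolver

    # Hypothetical stand-in for the trainer's full args Namespace
    args = argparse.Namespace(
        aspects=aspects.get_aspect_buckets(512),
        flip_p=0.0,
        seed=555,
    )

    items = resolver.resolve('/data/my_dataset', args)     # directory walk
    items = resolver.resolve('/data/manifest.json', args)  # JSON manifest
    items = resolver.resolve({'resolver': 'directory', 'path': '/data/my_dataset'}, args)
    items = resolver.resolve(['/data/my_dataset', '/data/manifest.json'], args)  # concatenated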

test/test_data_resolver.py

@@ -2,6 +2,7 @@ import json
 import glob
 import os
 import unittest
+import argparse

 import PIL.Image as Image

@@ -10,13 +11,18 @@ import data.resolver as resolver
 DATA_PATH = os.path.abspath('./test/data')
 JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json')
-ASPECTS = aspects.get_aspect_buckets(512)

 IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg')
 CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt')
 IMAGE_2_PATH = os.path.join(DATA_PATH, 'test2.jpg')
 IMAGE_3_PATH = os.path.join(DATA_PATH, 'test3.jpg')

+ARGS = argparse.Namespace(
+    aspects=aspects.get_aspect_buckets(512),
+    flip_p=0.5,
+    seed=42,
+)

 class TestResolve(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -51,7 +57,7 @@ class TestResolve(unittest.TestCase):
             os.remove(file)

     def test_directory_resolve_with_str(self):
-        items = resolver.resolve(DATA_PATH, ASPECTS)
+        items = resolver.resolve(DATA_PATH, ARGS)
         image_paths = [item.pathname for item in items]
         image_captions = [item.caption for item in items]
         captions = [caption.get_caption() for caption in image_captions]
@@ -69,7 +75,7 @@ class TestResolve(unittest.TestCase):
             'path': DATA_PATH,
         }

-        items = resolver.resolve(data_root_spec, ASPECTS)
+        items = resolver.resolve(data_root_spec, ARGS)
         image_paths = [item.pathname for item in items]
         image_captions = [item.caption for item in items]
         captions = [caption.get_caption() for caption in image_captions]
@@ -82,7 +88,7 @@ class TestResolve(unittest.TestCase):
         self.assertEqual(len(undersized_images), 1)

     def test_json_resolve_with_str(self):
-        items = resolver.resolve(JSON_ROOT_PATH, ASPECTS)
+        items = resolver.resolve(JSON_ROOT_PATH, ARGS)
         image_paths = [item.pathname for item in items]
         image_captions = [item.caption for item in items]
         captions = [caption.get_caption() for caption in image_captions]
@@ -100,7 +106,7 @@ class TestResolve(unittest.TestCase):
             'path': JSON_ROOT_PATH,
         }

-        items = resolver.resolve(data_root_spec, ASPECTS)
+        items = resolver.resolve(data_root_spec, ARGS)
         image_paths = [item.pathname for item in items]
         image_captions = [item.caption for item in items]
         captions = [caption.get_caption() for caption in image_captions]

@@ -110,4 +116,22 @@ class TestResolve(unittest.TestCase):
         self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])
         undersized_images = list(filter(lambda i: i.is_undersized, items))
         self.assertEqual(len(undersized_images), 1)
+
+    def test_resolve_with_list(self):
+        data_root_spec = [
+            DATA_PATH,
+            JSON_ROOT_PATH,
+        ]
+
+        items = resolver.resolve(data_root_spec, ARGS)
+        image_paths = [item.pathname for item in items]
+        image_captions = [item.caption for item in items]
+        captions = [caption.get_caption() for caption in image_captions]
+
+        self.assertEqual(len(items), 6)
+        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2)
+        self.assertEqual(captions, ['caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'])
+        undersized_images = list(filter(lambda i: i.is_undersized, items))
+        self.assertEqual(len(undersized_images), 2)
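
Assuming these tests live in a `test/` package next to the fixture data referenced by `DATA_PATH`, they should run from the repo root with `python -m unittest discover -s test`.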

train.py

@@ -15,6 +15,7 @@ limitations under the License.
 """
 import os
+import pprint
 import sys
 import math
 import signal

@@ -48,11 +49,15 @@ from accelerate.utils import set_seed
 import wandb
 from torch.utils.tensorboard import SummaryWriter

+from data.data_loader import DataLoaderMultiAspect
 from data.every_dream import EveryDreamBatch
+from data.image_train_item import ImageTrainItem
 from utils.huggingface_downloader import try_download_model_from_hf
 from utils.convert_diff_to_ckpt import convert as converter
 from utils.gpu import GPU
+import data.aspects as aspects
+import data.resolver as resolver

 _SIGTERM_EXIT_CODE = 130
 _VERY_LARGE_NUMBER = 1e9
@@ -265,6 +270,8 @@ def setup_args(args):
         logging.info(logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}"))

+    args.aspects = aspects.get_aspect_buckets(args.resolution)
+
     return args

 def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
@@ -288,6 +295,49 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
     scaler.set_growth_factor(factor)
     scaler.set_backoff_factor(1/factor)
     scaler.set_growth_interval(100)

+def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
+    for item in items:
+        if item.error is not None:
+            logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
+            logging.error(f" *** exception: {item.error}")
+
+    undersized_items = [item for item in items if item.is_undersized]
+
+    if len(undersized_items) > 0:
+        undersized_log_path = os.path.join(log_folder, "undersized_images.txt")
+        logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
+        logging.warning(f"{Fore.LIGHTRED_EX} ** Check {undersized_log_path} for more information.{Style.RESET_ALL}")
+        with open(undersized_log_path, "w") as undersized_images_file:
+            undersized_images_file.write(" The following images are smaller than the target size, consider removing or sourcing a larger copy:\n")
+            for undersized_item in undersized_items:
+                message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
+                undersized_images_file.write(message)
+
+def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
+    logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
+    logging.info(" Preloading images...")
+
+    resolved_items = resolver.resolve(args.data_root, args)
+    report_image_train_item_problems(log_folder, resolved_items)
+    image_paths = set(map(lambda item: item.pathname, resolved_items))
+
+    # Remove erroneous items
+    image_train_items = [item for item in resolved_items if item.error is None]
+    print(f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
+
+    return image_train_items
+
+def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int):
+    if args.write_schedule:
+        with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
+            for i in range(len(train_batch.image_train_items)):
+                try:
+                    item = train_batch.image_train_items[i]
+                    f.write(f"step:{int(i / train_batch.batch_size):05}, wh:{item.target_wh}, r:{item.runt_size}, path:{item.pathname}\n")
+                except Exception as e:
+                    logging.error(f" * Error writing to batch schedule for file path: {item.pathname}")

 def read_sample_prompts(sample_prompts_file_path: str):
@@ -556,23 +606,26 @@ def main(args):
     )

     log_optimizer(optimizer, betas, epsilon)

+    image_train_items = resolve_image_train_items(args, log_folder)
+
+    data_loader = DataLoaderMultiAspect(
+        image_train_items=image_train_items,
+        seed=seed,
+        batch_size=args.batch_size,
+    )
+
     train_batch = EveryDreamBatch(
-        data_root=args.data_root,
-        flip_p=args.flip_p,
+        data_loader=data_loader,
         debug_level=1,
-        batch_size=args.batch_size,
         conditional_dropout=args.cond_dropout,
-        resolution=args.resolution,
         tokenizer=tokenizer,
         seed = seed,
-        log_folder=log_folder,
-        write_schedule=args.write_schedule,
         shuffle_tags=args.shuffle_tags,
         rated_dataset=args.rated_dataset,
         rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
     )

     torch.cuda.benchmark = False

     epoch_len = math.ceil(len(train_batch) / args.batch_size)
@@ -592,10 +645,11 @@ def main(args):
     if args.wandb is not None and args.wandb:
         wandb.init(project=args.project_name, sync_tensorboard=True, )

-    log_writer = SummaryWriter(log_dir=log_folder,
-                               flush_secs=5,
-                               comment="EveryDream2FineTunes",
-                              )
+    log_writer = SummaryWriter(
+        log_dir=log_folder,
+        flush_secs=5,
+        comment="EveryDream2FineTunes",
+    )

     def log_args(log_writer, args):
         arglog = "args:\n"
@@ -732,6 +786,8 @@ def main(args):
     # # discard the grads, just want to pin memory
     # optimizer.zero_grad(set_to_none=True)

+    write_batch_schedule(args, log_folder, train_batch, 0)
+
     for epoch in range(args.max_epochs):
         loss_epoch = []
         epoch_start_time = time.time()
@@ -883,6 +939,7 @@ def main(args):
             epoch_pbar.update(1)
             if epoch < args.max_epochs - 1:
                 train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
+                write_batch_schedule(args, log_folder, train_batch, epoch + 1)

             loss_local = sum(loss_epoch) / len(loss_epoch)
             log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
@@ -909,7 +966,6 @@ def main(args):
     logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
     logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
-

 def update_old_args(t_args):
     """
     Update old args to new args to deal with json config loading and missing args for compatibility
@@ -947,7 +1003,6 @@ if __name__ == "__main__":
             t_args = argparse.Namespace()
             t_args.__dict__.update(json.load(f))
             update_old_args(t_args) # update args to support older configs
-            print(f" args: \n{t_args.__dict__}")
             args = argparser.parse_args(namespace=t_args)
     else:
         print("No config file specified, using command line args")
@@ -996,4 +1051,6 @@ if __name__ == "__main__":
         args, _ = argparser.parse_known_args()

+    print(f" Args:")
+    pprint.pprint(args.__dict__)
+
     main(args)