This commit is contained in:
Victor Hall 2023-02-08 11:07:24 -05:00
commit f275f20412
7 changed files with 89 additions and 85 deletions

View File

@ -101,7 +101,7 @@
"!pip install -q protobuf==3.20.1\n", "!pip install -q protobuf==3.20.1\n",
"!pip install -q wandb==0.13.6\n", "!pip install -q wandb==0.13.6\n",
"!pip install -q pyre-extensions==0.0.23\n", "!pip install -q pyre-extensions==0.0.23\n",
"!pip install -q xformers==0.0.17.dev435\n", "!pip install -q xformers==0.0.16\n",
"!pip install -q pytorch-lightning==1.6.5\n", "!pip install -q pytorch-lightning==1.6.5\n",
"!pip install -q OmegaConf==2.2.3\n", "!pip install -q OmegaConf==2.2.3\n",
"!pip install -q numpy==1.23.5\n", "!pip install -q numpy==1.23.5\n",

View File

@ -14,13 +14,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import bisect import bisect
from functools import reduce import logging
import os.path
from collections import defaultdict
import math import math
import copy
import random import random
from data.image_train_item import ImageTrainItem, ImageCaption from data.image_train_item import ImageTrainItem
import PIL import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@ -38,43 +39,40 @@ class DataLoaderMultiAspect():
self.prepared_train_data = image_train_items self.prepared_train_data = image_train_items
random.Random(self.seed).shuffle(self.prepared_train_data) random.Random(self.seed).shuffle(self.prepared_train_data)
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
if expected_epoch_size != len(self.prepared_train_data):
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.")
else:
logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
self.rating_overall_sum: float = 0.0 self.rating_overall_sum: float = 0.0
self.ratings_summed: list[float] = [] self.ratings_summed: list[float] = []
self.__update_rating_sums() self.__update_rating_sums()
def __pick_multiplied_set(self, randomizer):
def __pick_multiplied_set(self, randomizer: random.Random):
""" """
Deals with multiply.txt whole and fractional numbers Deals with multiply.txt whole and fractional numbers
""" """
#print(f"Picking multiplied set from {len(self.prepared_train_data)}")
data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property
epoch_size = len(self.prepared_train_data)
picked_images = [] picked_images = []
fractional_images_per_directory = defaultdict(list[ImageTrainItem])
# add by whole number part first and decrement multiplier in copy for iti in self.prepared_train_data:
for iti in data_copy: multiplier = iti.multiplier
#print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}") while multiplier >= 1:
while iti.multiplier >= 1.0:
picked_images.append(iti) picked_images.append(iti)
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, , datalen: {len(picked_images)}") multiplier -= 1
iti.multiplier -= 1.0 # fractional remainders must be dealt with separately
if multiplier > 0:
directory = os.path.dirname(iti.pathname)
fractional_images_per_directory[directory].append(iti)
remaining = epoch_size - len(picked_images) # resolve fractional parts per-directory
for _, fractional_items in fractional_images_per_directory.items():
randomizer.shuffle(fractional_items)
multiplier = fractional_items[0].multiplier % 1.0
count_to_take = math.ceil(multiplier * len(fractional_items))
picked_images.extend(fractional_items[:count_to_take])
assert remaining >= 0, "Something went wrong with the multiplier calculation"
# add by remaining fractional numbers by random chance
while remaining > 0:
for iti in data_copy:
if randomizer.uniform(0.0, 1.0) < iti.multiplier:
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}")
picked_images.append(iti)
remaining -= 1
iti.multiplier = 0.0
if remaining <= 0:
break
del data_copy
return picked_images return picked_images
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]: def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
@ -110,20 +108,19 @@ class DataLoaderMultiAspect():
buckets[(target_wh[0],target_wh[1])] = [] buckets[(target_wh[0],target_wh[1])] = []
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair) buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
if len(buckets) > 1: for bucket in buckets:
for bucket in buckets: truncate_count = len(buckets[bucket]) % batch_size
truncate_count = len(buckets[bucket]) % batch_size if truncate_count > 0:
if truncate_count > 0: runt_bucket = buckets[bucket][-truncate_count:]
runt_bucket = buckets[bucket][-truncate_count:] for item in runt_bucket:
for item in runt_bucket: item.runt_size = truncate_count
item.runt_size = truncate_count while len(runt_bucket) < batch_size:
while len(runt_bucket) < batch_size: runt_bucket.append(random.choice(runt_bucket))
runt_bucket.append(random.choice(runt_bucket))
current_bucket_size = len(buckets[bucket]) current_bucket_size = len(buckets[bucket])
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count] buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
buckets[bucket].extend(runt_bucket) buckets[bucket].extend(runt_bucket)
# flatten the buckets # flatten the buckets
items: list[ImageTrainItem] = [] items: list[ImageTrainItem] = []

View File

@ -65,12 +65,6 @@ class EveryDreamBatch(Dataset):
num_images = len(self.image_train_items) num_images = len(self.image_train_items)
logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}") logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
self.__update_image_train_items(1.0)
return items
def shuffle(self, epoch_n: int, max_epochs: int): def shuffle(self, epoch_n: int, max_epochs: int):
self.seed += 1 self.seed += 1

View File

@ -1,7 +1,9 @@
import copy
import json import json
import logging
import math import math
import random import random
from typing import Callable, Any, Optional from typing import Callable, Any, Optional, Generator
from argparse import Namespace from argparse import Namespace
import torch import torch
@ -29,22 +31,28 @@ def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch
remaining_items = list(items_copy[split_item_count:]) remaining_items = list(items_copy[split_item_count:])
return split_items, remaining_items return split_items, remaining_items
def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
for i in items:
yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1)
class EveryDreamValidator: class EveryDreamValidator:
def __init__(self, def __init__(self,
val_config_path: Optional[str], val_config_path: Optional[str],
default_batch_size: int, default_batch_size: int,
resolution: int,
log_writer: SummaryWriter): log_writer: SummaryWriter):
self.val_dataloader = None self.val_dataloader = None
self.train_overlapping_dataloader = None self.train_overlapping_dataloader = None
self.log_writer = log_writer self.log_writer = log_writer
self.resolution = resolution
self.config = { self.config = {
'batch_size': default_batch_size, 'batch_size': default_batch_size,
'every_n_epochs': 1, 'every_n_epochs': 1,
'seed': 555, 'seed': 555,
'validate_training': True,
'val_split_mode': 'automatic', 'val_split_mode': 'automatic',
'val_split_proportion': 0.15, 'val_split_proportion': 0.15,
@ -120,21 +128,24 @@ class EveryDreamValidator:
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\ def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]: -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
val_split_mode = self.config['val_split_mode'] val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
val_split_proportion = self.config['val_split_proportion'] val_split_proportion = self.config['val_split_proportion']
remaining_train_items = image_train_items remaining_train_items = image_train_items
if val_split_mode == 'none': if val_split_mode is None or val_split_mode == 'none':
return None, image_train_items return None, image_train_items
elif val_split_mode == 'automatic': elif val_split_mode == 'automatic':
val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size) val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
val_items = list(disable_multiplier_and_flip(val_items))
logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
elif val_split_mode == 'manual': elif val_split_mode == 'manual':
args = Namespace( args = Namespace(
aspects=aspects.get_aspect_buckets(512), aspects=aspects.get_aspect_buckets(self.resolution),
flip_p=0.0, flip_p=0.0,
seed=self.seed, seed=self.seed,
) )
val_data_root = self.config['val_data_root'] val_data_root = self.config['val_data_root']
val_items = resolver.resolve_root(val_data_root, args) val_items = resolver.resolve_root(val_data_root, args)
logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
else: else:
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'") raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val') val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
@ -149,6 +160,7 @@ class EveryDreamValidator:
stabilize_split_proportion = self.config['stabilize_split_proportion'] stabilize_split_proportion = self.config['stabilize_split_proportion']
stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size) stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer, stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
name='stabilize-train') name='stabilize-train')
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size) stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)

View File

@ -263,7 +263,7 @@ class ImageTrainItem:
self.multiplier = multiplier self.multiplier = multiplier
self.image_size = None self.image_size = None
if image is None: if image is None or len(image) == 0:
self.image = [] self.image = []
else: else:
self.image = image self.image = image

View File

@ -128,7 +128,7 @@ class DirectoryResolver(DataResolver):
with open(multiply_txt_path, 'r') as f: with open(multiply_txt_path, 'r') as f:
val = float(f.read().strip()) val = float(f.read().strip())
multipliers[current_dir] = val multipliers[current_dir] = val
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}") logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
except Exception as e: except Exception as e:
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}") logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
multipliers[current_dir] = 1.0 multipliers[current_dir] = 1.0
@ -137,16 +137,8 @@ class DirectoryResolver(DataResolver):
caption = ImageCaption.resolve(pathname) caption = ImageCaption.resolve(pathname)
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir]) item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
items.append(item)
cur_file_multiplier = multipliers[current_dir]
while cur_file_multiplier >= 1.0:
items.append(item)
cur_file_multiplier -= 1
if cur_file_multiplier > 0:
if randomizer.random() < cur_file_multiplier:
items.append(item)
return items return items
@staticmethod @staticmethod

View File

@ -57,7 +57,8 @@ from data.every_dream_validation import EveryDreamValidator
from data.image_train_item import ImageTrainItem from data.image_train_item import ImageTrainItem
from utils.huggingface_downloader import try_download_model_from_hf from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter from utils.convert_diff_to_ckpt import convert as converter
from utils.gpu import GPU if torch.cuda.is_available():
from utils.gpu import GPU
import data.aspects as aspects import data.aspects as aspects
import data.resolver as resolver import data.resolver as resolver
@ -159,20 +160,21 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
""" """
updates the vram usage for the epoch updates the vram usage for the epoch
""" """
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory() if gpu is not None:
log_writer.add_scalar("performance/vram", gpu_used_mem, global_step) gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
epoch_mem_color = Style.RESET_ALL log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
if gpu_used_mem > 0.93 * gpu_total_mem: epoch_mem_color = Style.RESET_ALL
epoch_mem_color = Fore.LIGHTRED_EX if gpu_used_mem > 0.93 * gpu_total_mem:
elif gpu_used_mem > 0.85 * gpu_total_mem: epoch_mem_color = Fore.LIGHTRED_EX
epoch_mem_color = Fore.LIGHTYELLOW_EX elif gpu_used_mem > 0.85 * gpu_total_mem:
elif gpu_used_mem > 0.7 * gpu_total_mem: epoch_mem_color = Fore.LIGHTYELLOW_EX
epoch_mem_color = Fore.LIGHTGREEN_EX elif gpu_used_mem > 0.7 * gpu_total_mem:
elif gpu_used_mem < 0.5 * gpu_total_mem: epoch_mem_color = Fore.LIGHTGREEN_EX
epoch_mem_color = Fore.LIGHTBLUE_EX elif gpu_used_mem < 0.5 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTBLUE_EX
if logs is not None: if logs is not None:
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}") epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
def set_args_12gb(args): def set_args_12gb(args):
@ -326,8 +328,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list
# Remove erroneous items # Remove erroneous items
image_train_items = [item for item in resolved_items if item.error is None] image_train_items = [item for item in resolved_items if item.error is None]
print (f" * Found {len(image_paths)} files in '{args.data_root}'")
print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
return image_train_items return image_train_items
@ -372,6 +373,7 @@ def main(args):
else: else:
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.") logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
device = 'cpu' device = 'cpu'
gpu = None
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}") log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
@ -548,6 +550,7 @@ def main(args):
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
logging.error(" * Failed to load checkpoint *") logging.error(" * Failed to load checkpoint *")
raise
if args.gradient_checkpointing: if args.gradient_checkpointing:
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
@ -620,9 +623,13 @@ def main(args):
image_train_items = resolve_image_train_items(args, log_folder) image_train_items = resolve_image_train_items(args, log_folder)
#validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size) validator = EveryDreamValidator(args.validation_config,
default_batch_size=args.batch_size,
resolution=args.resolution,
log_writer=log_writer,
)
# the validation dataset may need to steal some items from image_train_items # the validation dataset may need to steal some items from image_train_items
#image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer) image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
data_loader = DataLoaderMultiAspect( data_loader = DataLoaderMultiAspect(
image_train_items=image_train_items, image_train_items=image_train_items,
@ -710,8 +717,9 @@ def main(args):
if not os.path.exists(f"{log_folder}/samples/"): if not os.path.exists(f"{log_folder}/samples/"):
os.makedirs(f"{log_folder}/samples/") os.makedirs(f"{log_folder}/samples/")
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory() if gpu is not None:
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB") gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes") logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs") logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
@ -940,7 +948,7 @@ def main(args):
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step) log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
# validate # validate
#validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target) validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
gc.collect() gc.collect()
# end of epoch # end of epoch
@ -1021,6 +1029,7 @@ if __name__ == "__main__":
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets") argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!") argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY") argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Uses defaults if omitted.")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)") argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs") argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)") argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")