diff --git a/Train_Colab.ipynb b/Train_Colab.ipynb index c54448e..553cba1 100644 --- a/Train_Colab.ipynb +++ b/Train_Colab.ipynb @@ -101,7 +101,7 @@ "!pip install -q protobuf==3.20.1\n", "!pip install -q wandb==0.13.6\n", "!pip install -q pyre-extensions==0.0.23\n", - "!pip install -q xformers==0.0.17.dev435\n", + "!pip install -q xformers==0.0.16\n", "!pip install -q pytorch-lightning==1.6.5\n", "!pip install -q OmegaConf==2.2.3\n", "!pip install -q numpy==1.23.5\n", diff --git a/data/data_loader.py b/data/data_loader.py index 5fe4ba6..d97e38e 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -14,13 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. """ import bisect -from functools import reduce +import logging +import os.path +from collections import defaultdict import math -import copy import random -from data.image_train_item import ImageTrainItem, ImageCaption -import PIL +from data.image_train_item import ImageTrainItem +import PIL.Image PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default @@ -38,43 +39,40 @@ class DataLoaderMultiAspect(): self.prepared_train_data = image_train_items random.Random(self.seed).shuffle(self.prepared_train_data) self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) + expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data])) + if expected_epoch_size != len(self.prepared_train_data): + logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.") + else: + logging.info(f" * DLMA initialized with {len(image_train_items)} images.") + self.rating_overall_sum: float = 0.0 self.ratings_summed: list[float] = [] self.__update_rating_sums() - def __pick_multiplied_set(self, randomizer): + + def __pick_multiplied_set(self, randomizer: random.Random): """ Deals with multiply.txt whole and fractional numbers """ - #print(f"Picking multiplied set from {len(self.prepared_train_data)}") - data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property - epoch_size = len(self.prepared_train_data) picked_images = [] - - # add by whole number part first and decrement multiplier in copy - for iti in data_copy: - #print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}") - while iti.multiplier >= 1.0: + fractional_images_per_directory = defaultdict(list[ImageTrainItem]) + for iti in self.prepared_train_data: + multiplier = iti.multiplier + while multiplier >= 1: picked_images.append(iti) - #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, , datalen: {len(picked_images)}") - iti.multiplier -= 1.0 + multiplier -= 1 + # fractional remainders must be dealt with separately + if multiplier > 0: + directory = os.path.dirname(iti.pathname) + fractional_images_per_directory[directory].append(iti) - remaining = epoch_size - len(picked_images) + # resolve fractional parts per-directory + for _, fractional_items in fractional_images_per_directory.items(): + randomizer.shuffle(fractional_items) + multiplier = fractional_items[0].multiplier % 1.0 + count_to_take = math.ceil(multiplier * len(fractional_items)) + picked_images.extend(fractional_items[:count_to_take]) - assert remaining >= 0, "Something went wrong with the multiplier calculation" - - # add by remaining fractional numbers by random chance - while remaining > 0: - for iti in data_copy: - if randomizer.uniform(0.0, 1.0) < iti.multiplier: - #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}") - picked_images.append(iti) - remaining -= 1 - iti.multiplier = 0.0 - if remaining <= 0: - break - - del data_copy return picked_images def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]: @@ -110,20 +108,19 @@ class DataLoaderMultiAspect(): buckets[(target_wh[0],target_wh[1])] = [] buckets[(target_wh[0],target_wh[1])].append(image_caption_pair) - if len(buckets) > 1: - for bucket in buckets: - truncate_count = len(buckets[bucket]) % batch_size - if truncate_count > 0: - runt_bucket = buckets[bucket][-truncate_count:] - for item in runt_bucket: - item.runt_size = truncate_count - while len(runt_bucket) < batch_size: - runt_bucket.append(random.choice(runt_bucket)) + for bucket in buckets: + truncate_count = len(buckets[bucket]) % batch_size + if truncate_count > 0: + runt_bucket = buckets[bucket][-truncate_count:] + for item in runt_bucket: + item.runt_size = truncate_count + while len(runt_bucket) < batch_size: + runt_bucket.append(random.choice(runt_bucket)) - current_bucket_size = len(buckets[bucket]) + current_bucket_size = len(buckets[bucket]) - buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count] - buckets[bucket].extend(runt_bucket) + buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count] + buckets[bucket].extend(runt_bucket) # flatten the buckets items: list[ImageTrainItem] = [] diff --git a/data/every_dream.py b/data/every_dream.py index 30f10aa..06cdacc 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -65,12 +65,6 @@ class EveryDreamBatch(Dataset): num_images = len(self.image_train_items) logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}") - def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]: - items = self.data_loader.get_random_split(split_proportion, remove_from_dataset) - self.__update_image_train_items(1.0) - return items - - def shuffle(self, epoch_n: int, max_epochs: int): self.seed += 1 diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index 1302cc8..4f2c083 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -1,7 +1,9 @@ +import copy import json +import logging import math import random -from typing import Callable, Any, Optional +from typing import Callable, Any, Optional, Generator from argparse import Namespace import torch @@ -29,22 +31,28 @@ def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch remaining_items = list(items_copy[split_item_count:]) return split_items, remaining_items +def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]: + for i in items: + yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1) class EveryDreamValidator: def __init__(self, val_config_path: Optional[str], default_batch_size: int, + resolution: int, log_writer: SummaryWriter): self.val_dataloader = None self.train_overlapping_dataloader = None self.log_writer = log_writer + self.resolution = resolution self.config = { 'batch_size': default_batch_size, 'every_n_epochs': 1, 'seed': 555, + 'validate_training': True, 'val_split_mode': 'automatic', 'val_split_proportion': 0.15, @@ -120,21 +128,24 @@ class EveryDreamValidator: def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\ -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]: - val_split_mode = self.config['val_split_mode'] + val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None val_split_proportion = self.config['val_split_proportion'] remaining_train_items = image_train_items - if val_split_mode == 'none': + if val_split_mode is None or val_split_mode == 'none': return None, image_train_items elif val_split_mode == 'automatic': val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size) + val_items = list(disable_multiplier_and_flip(val_items)) + logging.info(f" * Removed {len(val_items)} images from the training set to use for validation") elif val_split_mode == 'manual': args = Namespace( - aspects=aspects.get_aspect_buckets(512), + aspects=aspects.get_aspect_buckets(self.resolution), flip_p=0.0, seed=self.seed, ) val_data_root = self.config['val_data_root'] val_items = resolver.resolve_root(val_data_root, args) + logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}") else: raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'") val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val') @@ -149,6 +160,7 @@ class EveryDreamValidator: stabilize_split_proportion = self.config['stabilize_split_proportion'] stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size) + stabilize_items = list(disable_multiplier_and_flip(stabilize_items)) stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer, name='stabilize-train') stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size) diff --git a/data/image_train_item.py b/data/image_train_item.py index de72374..8e88612 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -263,7 +263,7 @@ class ImageTrainItem: self.multiplier = multiplier self.image_size = None - if image is None: + if image is None or len(image) == 0: self.image = [] else: self.image = image diff --git a/data/resolver.py b/data/resolver.py index fbcd076..e66a3b8 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -128,7 +128,7 @@ class DirectoryResolver(DataResolver): with open(multiply_txt_path, 'r') as f: val = float(f.read().strip()) multipliers[current_dir] = val - logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}") + logging.info(f" - multiply.txt in '{current_dir}' set to {val}") except Exception as e: logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}") multipliers[current_dir] = 1.0 @@ -137,16 +137,8 @@ class DirectoryResolver(DataResolver): caption = ImageCaption.resolve(pathname) item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir]) - - cur_file_multiplier = multipliers[current_dir] + items.append(item) - while cur_file_multiplier >= 1.0: - items.append(item) - cur_file_multiplier -= 1 - - if cur_file_multiplier > 0: - if randomizer.random() < cur_file_multiplier: - items.append(item) return items @staticmethod diff --git a/train.py b/train.py index dcd5ed4..c63c471 100644 --- a/train.py +++ b/train.py @@ -57,7 +57,8 @@ from data.every_dream_validation import EveryDreamValidator from data.image_train_item import ImageTrainItem from utils.huggingface_downloader import try_download_model_from_hf from utils.convert_diff_to_ckpt import convert as converter -from utils.gpu import GPU +if torch.cuda.is_available(): + from utils.gpu import GPU import data.aspects as aspects import data.resolver as resolver @@ -159,20 +160,21 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs): """ updates the vram usage for the epoch """ - gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory() - log_writer.add_scalar("performance/vram", gpu_used_mem, global_step) - epoch_mem_color = Style.RESET_ALL - if gpu_used_mem > 0.93 * gpu_total_mem: - epoch_mem_color = Fore.LIGHTRED_EX - elif gpu_used_mem > 0.85 * gpu_total_mem: - epoch_mem_color = Fore.LIGHTYELLOW_EX - elif gpu_used_mem > 0.7 * gpu_total_mem: - epoch_mem_color = Fore.LIGHTGREEN_EX - elif gpu_used_mem < 0.5 * gpu_total_mem: - epoch_mem_color = Fore.LIGHTBLUE_EX + if gpu is not None: + gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory() + log_writer.add_scalar("performance/vram", gpu_used_mem, global_step) + epoch_mem_color = Style.RESET_ALL + if gpu_used_mem > 0.93 * gpu_total_mem: + epoch_mem_color = Fore.LIGHTRED_EX + elif gpu_used_mem > 0.85 * gpu_total_mem: + epoch_mem_color = Fore.LIGHTYELLOW_EX + elif gpu_used_mem > 0.7 * gpu_total_mem: + epoch_mem_color = Fore.LIGHTGREEN_EX + elif gpu_used_mem < 0.5 * gpu_total_mem: + epoch_mem_color = Fore.LIGHTBLUE_EX - if logs is not None: - epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}") + if logs is not None: + epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}") def set_args_12gb(args): @@ -326,8 +328,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list # Remove erroneous items image_train_items = [item for item in resolved_items if item.error is None] - - print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files") + print (f" * Found {len(image_paths)} files in '{args.data_root}'") return image_train_items @@ -372,6 +373,7 @@ def main(args): else: logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.") device = 'cpu' + gpu = None log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}") @@ -548,6 +550,7 @@ def main(args): except Exception as e: traceback.print_exc() logging.error(" * Failed to load checkpoint *") + raise if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -620,9 +623,13 @@ def main(args): image_train_items = resolve_image_train_items(args, log_folder) - #validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size) + validator = EveryDreamValidator(args.validation_config, + default_batch_size=args.batch_size, + resolution=args.resolution, + log_writer=log_writer, + ) # the validation dataset may need to steal some items from image_train_items - #image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer) + image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer) data_loader = DataLoaderMultiAspect( image_train_items=image_train_items, @@ -710,8 +717,9 @@ def main(args): if not os.path.exists(f"{log_folder}/samples/"): os.makedirs(f"{log_folder}/samples/") - gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory() - logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB") + if gpu is not None: + gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory() + logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB") logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes") logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs") @@ -940,7 +948,7 @@ def main(args): log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step) # validate - #validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target) + validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target) gc.collect() # end of epoch @@ -1021,6 +1029,7 @@ if __name__ == "__main__": argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets") argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!") argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY") + argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Uses defaults if omitted.") argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)") argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs") argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")