From 50a71e63b61a014868ea9914df5f625a2e5fa263 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 30 Jan 2023 11:26:11 +0100 Subject: [PATCH 1/5] background load images for a 40% performance improvement --- train.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/train.py b/train.py index e11ab99..018d851 100644 --- a/train.py +++ b/train.py @@ -21,6 +21,7 @@ import math import signal import argparse import logging +import threading import time import gc import random @@ -676,18 +677,19 @@ def main(args): """ handles sigterm """ - global interrupted - if not interrupted: - interrupted=True - global global_step - #TODO: save model on ctrl-c - interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}") - print() - logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}") - logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}") - logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}") - time.sleep(2) # give opportunity to ctrl-C again to cancel save - __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) + if threading.current_thread().__class__.__name__ == '_MainThread': + global interrupted + if not interrupted: + interrupted=True + global global_step + #TODO: save model on ctrl-c + interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}") + print() + logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}") + logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}") + logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}") + time.sleep(2) # give opportunity to ctrl-C again to cancel save + __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) exit(_SIGTERM_EXIT_CODE) signal.signal(signal.SIGINT, sigterm_handler) @@ -726,8 +728,9 @@ def main(args): train_batch, batch_size=args.batch_size, shuffle=False, - num_workers=0, - collate_fn=collate_fn + num_workers=1, + collate_fn=collate_fn, + pin_memory=True ) unet.train() if not args.disable_unet_training else unet.eval() From 86a500409847bd54bd3a8498545e768017106343 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 30 Jan 2023 11:32:39 +0100 Subject: [PATCH 2/5] better main-thread detection --- train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 018d851..0bbcf39 100644 --- a/train.py +++ b/train.py @@ -677,7 +677,8 @@ def main(args): """ handles sigterm """ - if threading.current_thread().__class__.__name__ == '_MainThread': + is_main_thread = (torch.utils.data.get_worker_info() == None) + if is_main_thread: global interrupted if not interrupted: interrupted=True From 4b5654452c50d2311ba77d542b2f58e1282665da Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 30 Jan 2023 14:13:01 +0100 Subject: [PATCH 3/5] more workers --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 0bbcf39..5a0e20e 100644 --- a/train.py +++ b/train.py @@ -729,7 +729,7 @@ def main(args): train_batch, batch_size=args.batch_size, shuffle=False, - num_workers=1, + num_workers=4, collate_fn=collate_fn, pin_memory=True ) From c270dbf6a89a72eb9047e71083180ce461e36095 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 30 Jan 2023 14:19:18 +0100 Subject: [PATCH 4/5] ensure dataloader workers exit cleanly on ctrl-c --- train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 5a0e20e..6d3c157 100644 --- a/train.py +++ b/train.py @@ -691,7 +691,10 @@ def main(args): logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}") time.sleep(2) # give opportunity to ctrl-C again to cancel save __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) - exit(_SIGTERM_EXIT_CODE) + exit(_SIGTERM_EXIT_CODE) + else: + # non-main threads (i.e. dataloader workers) should exit cleanly + exit(0) signal.signal(signal.SIGINT, sigterm_handler) From 21a64c38f2000932dbaa0db8c498dea591fa9737 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Tue, 31 Jan 2023 13:46:58 +0100 Subject: [PATCH 5/5] move collate_fn to top level to possibly fix windows issue --- train.py | 44 +++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/train.py b/train.py index 6d3c157..48a1aaf 100644 --- a/train.py +++ b/train.py @@ -349,6 +349,29 @@ def read_sample_prompts(sample_prompts_file_path: str): return sample_prompts + +def collate_fn(batch): + """ + Collates batches + """ + images = [example["image"] for example in batch] + captions = [example["caption"] for example in batch] + tokens = [example["tokens"] for example in batch] + runt_size = batch[0]["runt_size"] + + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + + ret = { + "tokens": torch.stack(tuple(tokens)), + "image": images, + "captions": captions, + "runt_size": runt_size, + } + del batch + return ret + + def main(args): """ Main entry point @@ -707,27 +730,6 @@ def main(args): logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs") - def collate_fn(batch): - """ - Collates batches - """ - images = [example["image"] for example in batch] - captions = [example["caption"] for example in batch] - tokens = [example["tokens"] for example in batch] - runt_size = batch[0]["runt_size"] - - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - - ret = { - "tokens": torch.stack(tuple(tokens)), - "image": images, - "captions": captions, - "runt_size": runt_size, - } - del batch - return ret - train_dataloader = torch.utils.data.DataLoader( train_batch, batch_size=args.batch_size,