diff --git a/train.py b/train.py
index e11ab99..48a1aaf 100644
--- a/train.py
+++ b/train.py
@@ -21,6 +21,7 @@ import math
 import signal
 import argparse
 import logging
+import threading
 import time
 import gc
 import random
@@ -348,6 +349,29 @@ def read_sample_prompts(sample_prompts_file_path: str):
 
     return sample_prompts
 
+
+def collate_fn(batch):
+    """
+    Collates batches
+    """
+    images = [example["image"] for example in batch]
+    captions = [example["caption"] for example in batch]
+    tokens = [example["tokens"] for example in batch]
+    runt_size = batch[0]["runt_size"]
+
+    images = torch.stack(images)
+    images = images.to(memory_format=torch.contiguous_format).float()
+
+    ret = {
+        "tokens": torch.stack(tuple(tokens)),
+        "image": images,
+        "captions": captions,
+        "runt_size": runt_size,
+    }
+    del batch
+    return ret
+
+
 def main(args):
     """
     Main entry point
@@ -676,19 +700,24 @@ def main(args):
         """
         handles sigterm
         """
-        global interrupted
-        if not interrupted:
-            interrupted=True
-            global global_step
-            #TODO: save model on ctrl-c
-            interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
-            print()
-            logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
-            logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
-            logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
-            time.sleep(2) # give opportunity to ctrl-C again to cancel save
-            __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
-            exit(_SIGTERM_EXIT_CODE)
+        is_main_thread = (torch.utils.data.get_worker_info() == None)
+        if is_main_thread:
+            global interrupted
+            if not interrupted:
+                interrupted=True
+                global global_step
+                #TODO: save model on ctrl-c
+                interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
+                print()
+                logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
+                logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
+                logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
+                time.sleep(2) # give opportunity to ctrl-C again to cancel save
+                __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+                exit(_SIGTERM_EXIT_CODE)
+        else:
+            # non-main threads (i.e. dataloader workers) should exit cleanly
+            exit(0)
 
     signal.signal(signal.SIGINT, sigterm_handler)
 
@@ -701,33 +730,13 @@ def main(args):
 
     logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
 
-    def collate_fn(batch):
-        """
-        Collates batches
-        """
-        images = [example["image"] for example in batch]
-        captions = [example["caption"] for example in batch]
-        tokens = [example["tokens"] for example in batch]
-        runt_size = batch[0]["runt_size"]
-
-        images = torch.stack(images)
-        images = images.to(memory_format=torch.contiguous_format).float()
-
-        ret = {
-            "tokens": torch.stack(tuple(tokens)),
-            "image": images,
-            "captions": captions,
-            "runt_size": runt_size,
-        }
-        del batch
-        return ret
-
     train_dataloader = torch.utils.data.DataLoader(
         train_batch,
         batch_size=args.batch_size,
         shuffle=False,
-        num_workers=0,
-        collate_fn=collate_fn
+        num_workers=4,
+        collate_fn=collate_fn,
+        pin_memory=True
     )
 
     unet.train() if not args.disable_unet_training else unet.eval()
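
A note on the collate_fn move, which is my reading of the diff rather than something it states: with num_workers > 0 the DataLoader hands the collate function to worker processes, and under the spawn start method (the default on Windows and macOS) that requires the function to be picklable, which a closure defined inside main() is not. Hoisting it to module scope avoids that failure. Below is a minimal, self-contained sketch of the resulting pattern; ToyDataset, its tensor shapes, and the batch sizes are illustrative placeholders, not the trainer's real classes or settings.

```python
# Sketch only: ToyDataset, its shapes, and the sizes below are placeholders,
# not the trainer's real dataset or configuration.
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Stand-in dataset that yields dicts shaped like the trainer's samples."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            "image": torch.rand(3, 64, 64),
            "caption": f"caption {idx}",
            "tokens": torch.zeros(77, dtype=torch.long),
            "runt_size": 0,
        }


def collate_fn(batch):
    """Module-level (hence picklable) collate, mirroring the one in the diff."""
    images = torch.stack([example["image"] for example in batch])
    images = images.to(memory_format=torch.contiguous_format).float()
    return {
        "tokens": torch.stack([example["tokens"] for example in batch]),
        "image": images,
        "captions": [example["caption"] for example in batch],
        "runt_size": batch[0]["runt_size"],
    }


if __name__ == "__main__":  # guard is required when workers use the spawn start method
    loader = DataLoader(
        ToyDataset(),
        batch_size=4,
        shuffle=False,
        num_workers=4,      # worker processes receive collate_fn by pickling under spawn
        pin_memory=True,    # page-locked host memory speeds up later copies to the GPU
        collate_fn=collate_fn,
    )
    for batch in loader:
        print(batch["image"].shape, batch["tokens"].shape)
```

pin_memory=True only pays off when batches are subsequently moved to a CUDA device; otherwise it just allocates page-locked host memory with no benefit.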
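The signal-handler change relies on torch.utils.data.get_worker_info() returning None in the main process and a WorkerInfo object inside a DataLoader worker process. With the fork start method (the Linux default), workers inherit the SIGINT handler installed before iteration starts, so without such a check every worker would also try to run the checkpoint-saving path on Ctrl-C. A minimal sketch of that guard follows; ToyDataset and the printed messages are placeholders, and the real handler saves an interrupted checkpoint instead of printing.

```python
# Sketch only: ToyDataset and the messages are placeholders for illustration.
import signal
import sys

import torch
from torch.utils.data import DataLoader, Dataset, get_worker_info


class ToyDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.full((4,), float(idx))


def sigint_handler(signum, frame):
    # get_worker_info() is None outside a DataLoader worker process, so it
    # doubles as a cheap "am I the main training process?" check.
    if get_worker_info() is None:
        print("main process: save an interrupted checkpoint here, then exit")
        sys.exit(130)
    sys.exit(0)  # worker process: exit quietly instead of racing to save


if __name__ == "__main__":
    signal.signal(signal.SIGINT, sigint_handler)
    loader = DataLoader(ToyDataset(), batch_size=8, num_workers=2)
    for _ in loader:
        pass  # pressing Ctrl-C during iteration exercises both branches
```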