Merge pull request #39 from damian0815/feat_async_image_load
Enable multi-threaded image loading for a massive perf boost
This commit is contained in:
commit
3b085bdd28
81
train.py
81
train.py
|
@ -21,6 +21,7 @@ import math
|
|||
import signal
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import gc
|
||||
import random
|
||||
|
@ -348,6 +349,29 @@ def read_sample_prompts(sample_prompts_file_path: str):
|
|||
return sample_prompts
|
||||
|
||||
|
||||
|
||||
def collate_fn(batch):
    """
    Collate a list of per-sample dicts into a single batch dict.

    Each element of ``batch`` is read through the keys ``"image"``,
    ``"caption"``, ``"tokens"`` and ``"runt_size"``.

    Args:
        batch: non-empty sequence of example dicts. ``"image"`` and
            ``"tokens"`` values must be tensors that ``torch.stack`` can
            combine (i.e. identical shapes within the batch).

    Returns:
        dict with:
            ``"tokens"``    - the per-example token tensors stacked on a new
                              leading dimension,
            ``"image"``     - the per-example images stacked on a new leading
                              dimension, in contiguous memory format, as float,
            ``"captions"``  - list of the per-example captions, in batch order,
            ``"runt_size"`` - the ``"runt_size"`` value of the FIRST example.

    Raises:
        IndexError: if ``batch`` is empty.
    """
    images = [example["image"] for example in batch]
    captions = [example["caption"] for example in batch]
    tokens = [example["tokens"] for example in batch]
    # Only the first example is consulted for runt_size.
    # NOTE(review): presumably every example in a batch carries the same
    # value - confirm against the dataset that produces these dicts.
    runt_size = batch[0]["runt_size"]

    images = torch.stack(images)
    images = images.to(memory_format=torch.contiguous_format).float()

    return {
        "tokens": torch.stack(tokens),
        "image": images,
        "captions": captions,
        "runt_size": runt_size,
    }
|
||||
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
Main entry point
|
||||
|
@ -676,19 +700,24 @@ def main(args):
|
|||
"""
|
||||
handles sigterm
|
||||
"""
|
||||
global interrupted
|
||||
if not interrupted:
|
||||
interrupted=True
|
||||
global global_step
|
||||
#TODO: save model on ctrl-c
|
||||
interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
|
||||
print()
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
|
||||
exit(_SIGTERM_EXIT_CODE)
|
||||
is_main_thread = (torch.utils.data.get_worker_info() == None)
|
||||
if is_main_thread:
|
||||
global interrupted
|
||||
if not interrupted:
|
||||
interrupted=True
|
||||
global global_step
|
||||
#TODO: save model on ctrl-c
|
||||
interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
|
||||
print()
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
|
||||
exit(_SIGTERM_EXIT_CODE)
|
||||
else:
|
||||
# non-main threads (i.e. dataloader workers) should exit cleanly
|
||||
exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, sigterm_handler)
|
||||
|
||||
|
@ -701,33 +730,13 @@ def main(args):
|
|||
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""
|
||||
Collates batches
|
||||
"""
|
||||
images = [example["image"] for example in batch]
|
||||
captions = [example["caption"] for example in batch]
|
||||
tokens = [example["tokens"] for example in batch]
|
||||
runt_size = batch[0]["runt_size"]
|
||||
|
||||
images = torch.stack(images)
|
||||
images = images.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
ret = {
|
||||
"tokens": torch.stack(tuple(tokens)),
|
||||
"image": images,
|
||||
"captions": captions,
|
||||
"runt_size": runt_size,
|
||||
}
|
||||
del batch
|
||||
return ret
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_batch,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
collate_fn=collate_fn
|
||||
num_workers=4,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
unet.train() if not args.disable_unet_training else unet.eval()
|
||||
|
|
Loading…
Reference in New Issue