Move image resolution into its own function

This commit is contained in:
Joel Holdbrooks 2023-01-29 18:20:40 -08:00
parent 56f130c027
commit f96d44ddb4
1 changed files with 17 additions and 14 deletions

View File

@ -296,8 +296,7 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(100)
def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) -> None:
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
for item in items:
if item.error is not None:
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
@ -315,6 +314,21 @@ def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) ->
message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
undersized_images_file.write(message)
def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
logging.info(" Preloading images...")
resolved_items = resolver.resolve(args.data_root, args)
report_image_train_item_problems(log_folder, resolved_items)
image_paths = set(map(lambda item: item.pathname, resolved_items))
# Remove erroneous items
image_train_items = [item for item in resolved_items if item.error is None]
print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
return image_train_items
def write_batch_schedule(log_folder, train_batch, epoch):
if args.write_schedule:
with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
@ -349,8 +363,6 @@ def main(args):
if not os.path.exists(log_folder):
os.makedirs(log_folder)
args.log_folder = log_folder
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
@ -586,15 +598,7 @@ def main(args):
log_optimizer(optimizer, betas, epsilon)
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
logging.info(" Preloading images...")
image_train_items = resolver.resolve(args.data_root, args)
report_image_train_item_problems(log_folder, image_train_items)
image_paths = set(map(lambda item: item.pathname, image_train_items))
# Remove erroneous items
image_train_items = [item for item in image_train_items if item.error is None]
print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
image_train_items = resolve_image_train_items(args, log_folder)
data_loader = DataLoaderMultiAspect(
image_train_items=image_train_items,
@ -958,7 +962,6 @@ def main(args):
logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
def update_old_args(t_args):
"""
Update old args to new args to deal with json config loading and missing args for compatibility