diff --git a/train.py b/train.py
index db4b908..af0d1ee 100644
--- a/train.py
+++ b/train.py
@@ -296,8 +296,7 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
         scaler.set_backoff_factor(1/factor)
         scaler.set_growth_interval(100)
-
-def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) -> None:
+def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
     for item in items:
         if item.error is not None:
             logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
@@ -315,6 +314,21 @@ def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) ->
             message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
             undersized_images_file.write(message)
 
+def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
+    logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
+    logging.info(" Preloading images...")
+
+    resolved_items = resolver.resolve(args.data_root, args)
+    report_image_train_item_problems(log_folder, resolved_items)
+    image_paths = set(map(lambda item: item.pathname, resolved_items))
+
+    # Remove erroneous items
+    image_train_items = [item for item in resolved_items if item.error is None]
+
+    print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
+
+    return image_train_items
+
 def write_batch_schedule(log_folder, train_batch, epoch):
     if args.write_schedule:
         with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
@@ -349,8 +363,6 @@ def main(args):
     if not os.path.exists(log_folder):
         os.makedirs(log_folder)
-
-    args.log_folder = log_folder
 
     @torch.no_grad()
     def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
@@ -586,15 +598,7 @@ def main(args):
     log_optimizer(optimizer, betas, epsilon)
 
-    logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
-    logging.info(" Preloading images...")
-
-    image_train_items = resolver.resolve(args.data_root, args)
-    report_image_train_item_problems(log_folder, image_train_items)
-    image_paths = set(map(lambda item: item.pathname, image_train_items))
-    # Remove erroneous items
-    image_train_items = [item for item in image_train_items if item.error is None]
-    print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
+    image_train_items = resolve_image_train_items(args, log_folder)
 
     data_loader = DataLoaderMultiAspect(
         image_train_items=image_train_items,
@@ -958,7 +962,6 @@ def main(args):
     logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
     logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
 
-
 def update_old_args(t_args):
     """
     Update old args to new args to deal with json config loading and missing args for compatibility