Move image resolution into its own function
This commit is contained in:
parent
56f130c027
commit
f96d44ddb4
31
train.py
31
train.py
|
@ -296,8 +296,7 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
|||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
|
||||
|
||||
def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) -> None:
|
||||
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
|
||||
for item in items:
|
||||
if item.error is not None:
|
||||
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
|
||||
|
@ -315,6 +314,21 @@ def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) ->
|
|||
message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
|
||||
undersized_images_file.write(message)
|
||||
|
||||
def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
|
||||
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
|
||||
logging.info(" Preloading images...")
|
||||
|
||||
resolved_items = resolver.resolve(args.data_root, args)
|
||||
report_image_train_item_problems(log_folder, resolved_items)
|
||||
image_paths = set(map(lambda item: item.pathname, resolved_items))
|
||||
|
||||
# Remove erroneous items
|
||||
image_train_items = [item for item in resolved_items if item.error is None]
|
||||
|
||||
print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
|
||||
|
||||
return image_train_items
|
||||
|
||||
def write_batch_schedule(log_folder, train_batch, epoch):
|
||||
if args.write_schedule:
|
||||
with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
|
||||
|
@ -350,8 +364,6 @@ def main(args):
|
|||
if not os.path.exists(log_folder):
|
||||
os.makedirs(log_folder)
|
||||
|
||||
args.log_folder = log_folder
|
||||
|
||||
@torch.no_grad()
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
|
||||
"""
|
||||
|
@ -586,15 +598,7 @@ def main(args):
|
|||
|
||||
log_optimizer(optimizer, betas, epsilon)
|
||||
|
||||
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
|
||||
logging.info(" Preloading images...")
|
||||
|
||||
image_train_items = resolver.resolve(args.data_root, args)
|
||||
report_image_train_item_problems(log_folder, image_train_items)
|
||||
image_paths = set(map(lambda item: item.pathname, image_train_items))
|
||||
# Remove erroneous items
|
||||
image_train_items = [item for item in image_train_items if item.error is None]
|
||||
print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
|
||||
image_train_items = resolve_image_train_items(args, log_folder)
|
||||
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items=image_train_items,
|
||||
|
@ -958,7 +962,6 @@ def main(args):
|
|||
logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
|
||||
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
||||
|
||||
|
||||
def update_old_args(t_args):
|
||||
"""
|
||||
Update old args to new args to deal with json config loading and missing args for compatibility
|
||||
|
|
Loading…
Reference in New Issue