diff --git a/train.py b/train.py index 1ad92d6..a31f9a0 100644 --- a/train.py +++ b/train.py @@ -289,6 +289,15 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step): scaler.set_backoff_factor(1/factor) scaler.set_growth_interval(100) + +def read_sample_prompts(sample_prompts_file_path: str): + sample_prompts = [] + with open(sample_prompts_file_path, "r") as f: + for line in f: + sample_prompts.append(line.strip()) + return sample_prompts + + def main(args): """ Main entry point @@ -580,12 +589,6 @@ def main(args): num_training_steps=args.lr_decay_steps, ) - sample_prompts = [] - with open(args.sample_prompts, "r") as f: - for line in f: - sample_prompts.append(line.strip()) - - if args.wandb is not None and args.wandb: wandb.init(project=args.project_name, sync_tensorboard=True, ) @@ -841,6 +844,7 @@ def main(args): pipe = pipe.to(device) with torch.no_grad(): + sample_prompts = read_sample_prompts(args.sample_prompts) if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1: __generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution) else: