From 3aa9139b4bb82bc2f31140016db9fb964c7e04ba Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Mon, 30 Jan 2023 00:43:16 +0100
Subject: [PATCH] re-read sample prompts every time they're generated

---
 train.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/train.py b/train.py
index 1ad92d6..a31f9a0 100644
--- a/train.py
+++ b/train.py
@@ -289,6 +289,15 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
         scaler.set_backoff_factor(1/factor)
         scaler.set_growth_interval(100)
 
+
+def read_sample_prompts(sample_prompts_file_path: str):
+    sample_prompts = []
+    with open(sample_prompts_file_path, "r") as f:
+        for line in f:
+            sample_prompts.append(line.strip())
+    return sample_prompts
+
+
 def main(args):
     """
     Main entry point
@@ -580,12 +589,6 @@ def main(args):
         num_training_steps=args.lr_decay_steps,
     )
 
-    sample_prompts = []
-    with open(args.sample_prompts, "r") as f:
-        for line in f:
-            sample_prompts.append(line.strip())
-
-
     if args.wandb is not None and args.wandb:
         wandb.init(project=args.project_name, sync_tensorboard=True, )
 
@@ -841,6 +844,7 @@ def main(args):
                 pipe = pipe.to(device)
 
                 with torch.no_grad():
+                    sample_prompts = read_sample_prompts(args.sample_prompts)
                     if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
                         __generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
                     else:
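
For context, a minimal standalone sketch of the behavior this patch enables: the sample prompts file is re-read at every sampling step instead of once at startup, so edits to it take effect without restarting training. The helper below mirrors the read_sample_prompts() added in the patch; the temporary file and prompt strings are only illustrative and are not part of the repository.

# Illustrative sketch only; mirrors read_sample_prompts() from the patch.
import tempfile

def read_sample_prompts(sample_prompts_file_path: str):
    # One prompt per line, surrounding whitespace stripped.
    sample_prompts = []
    with open(sample_prompts_file_path, "r") as f:
        for line in f:
            sample_prompts.append(line.strip())
    return sample_prompts

if __name__ == "__main__":
    # Hypothetical prompts file with one prompt per line.
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("a photo of a cat\na watercolor landscape\n")
        path = f.name

    # Because the file is read at sampling time rather than once at startup,
    # editing it mid-run changes the prompts used for the next sample batch.
    print(read_sample_prompts(path))  # ['a photo of a cat', 'a watercolor landscape']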