re-read sample prompts every time they're generated

This commit is contained in:
Damian Stewart 2023-01-30 00:43:16 +01:00
parent bc273d0512
commit 3aa9139b4b
1 changed files with 10 additions and 6 deletions

View File

@ -289,6 +289,15 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
scaler.set_backoff_factor(1/factor) scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(100) scaler.set_growth_interval(100)
def read_sample_prompts(sample_prompts_file_path: str):
sample_prompts = []
with open(sample_prompts_file_path, "r") as f:
for line in f:
sample_prompts.append(line.strip())
return sample_prompts
def main(args): def main(args):
""" """
Main entry point Main entry point
@ -580,12 +589,6 @@ def main(args):
num_training_steps=args.lr_decay_steps, num_training_steps=args.lr_decay_steps,
) )
sample_prompts = []
with open(args.sample_prompts, "r") as f:
for line in f:
sample_prompts.append(line.strip())
if args.wandb is not None and args.wandb: if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, ) wandb.init(project=args.project_name, sync_tensorboard=True, )
@ -841,6 +844,7 @@ def main(args):
pipe = pipe.to(device) pipe = pipe.to(device)
with torch.no_grad(): with torch.no_grad():
sample_prompts = read_sample_prompts(args.sample_prompts)
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1: if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
__generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution) __generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
else: else: