Merge pull request #37 from damian0815/feat_reload_sample_prompts_every_generation

re-read sample prompts txt file every generation
This commit is contained in:
Victor Hall 2023-01-31 11:07:08 -05:00 committed by GitHub
commit 246b57c3c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 6 deletions

View File

@ -289,6 +289,15 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(100)
def read_sample_prompts(sample_prompts_file_path: str):
sample_prompts = []
with open(sample_prompts_file_path, "r") as f:
for line in f:
sample_prompts.append(line.strip())
return sample_prompts
def main(args):
"""
Main entry point
@ -580,12 +589,6 @@ def main(args):
num_training_steps=args.lr_decay_steps,
)
sample_prompts = []
with open(args.sample_prompts, "r") as f:
for line in f:
sample_prompts.append(line.strip())
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, )
@ -841,6 +844,7 @@ def main(args):
pipe = pipe.to(device)
with torch.no_grad():
sample_prompts = read_sample_prompts(args.sample_prompts)
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
__generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
else: