re-read sample prompts every time they're generated
This commit is contained in:
parent
bc273d0512
commit
3aa9139b4b
16
train.py
16
train.py
|
@ -289,6 +289,15 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
|||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
|
||||
|
||||
def read_sample_prompts(sample_prompts_file_path: str):
|
||||
sample_prompts = []
|
||||
with open(sample_prompts_file_path, "r") as f:
|
||||
for line in f:
|
||||
sample_prompts.append(line.strip())
|
||||
return sample_prompts
|
||||
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
Main entry point
|
||||
|
@ -580,12 +589,6 @@ def main(args):
|
|||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
|
||||
sample_prompts = []
|
||||
with open(args.sample_prompts, "r") as f:
|
||||
for line in f:
|
||||
sample_prompts.append(line.strip())
|
||||
|
||||
|
||||
if args.wandb is not None and args.wandb:
|
||||
wandb.init(project=args.project_name, sync_tensorboard=True, )
|
||||
|
||||
|
@ -841,6 +844,7 @@ def main(args):
|
|||
pipe = pipe.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
sample_prompts = read_sample_prompts(args.sample_prompts)
|
||||
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
|
||||
__generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue