re-read sample prompts every time they're generated
This commit is contained in:
parent
bc273d0512
commit
3aa9139b4b
16
train.py
16
train.py
|
@ -289,6 +289,15 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
||||||
scaler.set_backoff_factor(1/factor)
|
scaler.set_backoff_factor(1/factor)
|
||||||
scaler.set_growth_interval(100)
|
scaler.set_growth_interval(100)
|
||||||
|
|
||||||
|
|
||||||
|
def read_sample_prompts(sample_prompts_file_path: str):
|
||||||
|
sample_prompts = []
|
||||||
|
with open(sample_prompts_file_path, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
sample_prompts.append(line.strip())
|
||||||
|
return sample_prompts
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
"""
|
"""
|
||||||
Main entry point
|
Main entry point
|
||||||
|
@ -580,12 +589,6 @@ def main(args):
|
||||||
num_training_steps=args.lr_decay_steps,
|
num_training_steps=args.lr_decay_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
sample_prompts = []
|
|
||||||
with open(args.sample_prompts, "r") as f:
|
|
||||||
for line in f:
|
|
||||||
sample_prompts.append(line.strip())
|
|
||||||
|
|
||||||
|
|
||||||
if args.wandb is not None and args.wandb:
|
if args.wandb is not None and args.wandb:
|
||||||
wandb.init(project=args.project_name, sync_tensorboard=True, )
|
wandb.init(project=args.project_name, sync_tensorboard=True, )
|
||||||
|
|
||||||
|
@ -841,6 +844,7 @@ def main(args):
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
sample_prompts = read_sample_prompts(args.sample_prompts)
|
||||||
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
|
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
|
||||||
__generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
|
__generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue