fix sample filenames due to illegal characters

This commit is contained in:
Victor Hall 2022-12-19 14:43:10 -05:00
parent b0a23ebd73
commit 07dbf64ecf
2 changed files with 19 additions and 3 deletions

View File

@ -53,7 +53,7 @@ Cosine also has a decay period to define how long it takes to get to zero LR as
## Gradient accumulation
Gradient accumulation is sort of like a virtual batch size increase, averaging the learning over more than one step (batch) before applying it to the model update.
Gradient accumulation acts like a virtual batch size increase: the gradients from more than one step (batch) are averaged before being applied to the model as a single update to the weights.
Example:

View File

@ -53,6 +53,12 @@ _GRAD_ACCUM_STEPS = 1 # future use...
_SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9
def clean_filename(filename):
    """
    Reduce a string to letters, digits and spaces so it can be used as a filename.

    Every character that is not alphabetic (str.isalpha), a digit (str.isdigit)
    or a plain space is dropped; trailing whitespace is then trimmed from the result.
    Leading spaces are kept, and the result may be empty.
    """
    def _is_allowed(ch):
        return ch == ' ' or ch.isalpha() or ch.isdigit()

    kept_chars = filter(_is_allowed, filename)
    return "".join(kept_chars).rstrip()
def convert_to_hf(ckpt_path):
hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
@ -97,6 +103,9 @@ def setup_local_logger(args):
return datetimestamp
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
    """
    Write the optimizer's class name, followed by the supplied betas and epsilon,
    to the log at INFO level (two separate log records).
    """
    optimizer_name = optimizer.__class__.__name__
    logging.info(f"{Fore.CYAN} * Optimizer: {optimizer_name} *{Style.RESET_ALL}")
    hyperparam_line = f" betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}"
    logging.info(hyperparam_line)
@ -290,14 +299,16 @@ def main(args):
result.paste(image, (x_offset, 0))
x_offset += image.width
result.save(f"{log_folder}/samples/gs{gs:05}-{prompt[:100]}.png")
clean_prompt = clean_filename(prompt)
result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.png")
tfimage = transforms.ToTensor()(result)
if random_captions:
log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
i += 1
else:
log_writer.add_image(tag=f"sample_{prompt[:150]}", img_tensor=tfimage, global_step=gs)
log_writer.add_image(tag=f"sample_{clean_prompt[:150]}", img_tensor=tfimage, global_step=gs)
del result
del tfimage
@ -543,6 +554,10 @@ def main(args):
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
# Stop training the text encoder once global_step has passed --text_encoder_steps.
# A value of 0 (the default) means the feature is disabled, so the encoder keeps training.
# NOTE: the earlier condition `global_step > args.text_encoder_steps == 0` was a chained
# comparison, i.e. `global_step > N and N == 0`; it froze the encoder only when the option
# was left at 0 and never when a positive step count was requested.
if args.text_encoder_steps > 0 and global_step > args.text_encoder_steps:
    text_encoder.requires_grad_(False)
    text_encoder.eval()
with torch.no_grad():
with autocast():
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
@ -709,6 +724,7 @@ if __name__ == "__main__":
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
argparser.add_argument("--text_encoder_steps", type=int, default=0, help="disable text encoder training after N steps (def: disabled)")
args = argparser.parse_args()
main(args)