fix sample filenames due to illegal characters

This commit is contained in:
Victor Hall 2022-12-19 14:43:10 -05:00
parent b0a23ebd73
commit 07dbf64ecf
2 changed files with 19 additions and 3 deletions

View File

@ -53,7 +53,7 @@ Cosine also has a decay period to define how long it takes to get to zero LR as
## Gradient accumulation ## Gradient accumulation
Gradient accumulation is sort of like a virtual batch size increase, averaging the learning over more than one step (batch) before applying it to the model update. Gradient accumulation is sort of like a virtual batch size increase, averaging the learning over more than one step (batch) before applying it to the model as an update to weights.
Example: Example:

View File

@ -53,6 +53,12 @@ _GRAD_ACCUM_STEPS = 1 # future use...
_SIGTERM_EXIT_CODE = 130 _SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9 _VERY_LARGE_NUMBER = 1e9
def clean_filename(filename):
"""
removes all non-alphanumeric characters from a string so it is safe to use as a filename
"""
return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
def convert_to_hf(ckpt_path): def convert_to_hf(ckpt_path):
hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path)) hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
@ -97,6 +103,9 @@ def setup_local_logger(args):
return datetimestamp return datetimestamp
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon): def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
"""
logs the optimizer settings
"""
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
logging.info(f" betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}") logging.info(f" betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}")
@ -290,14 +299,16 @@ def main(args):
result.paste(image, (x_offset, 0)) result.paste(image, (x_offset, 0))
x_offset += image.width x_offset += image.width
result.save(f"{log_folder}/samples/gs{gs:05}-{prompt[:100]}.png") clean_prompt = clean_filename(prompt)
result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.png")
tfimage = transforms.ToTensor()(result) tfimage = transforms.ToTensor()(result)
if random_captions: if random_captions:
log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs) log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
i += 1 i += 1
else: else:
log_writer.add_image(tag=f"sample_{prompt[:150]}", img_tensor=tfimage, global_step=gs) log_writer.add_image(tag=f"sample_{clean_prompt[:150]}", img_tensor=tfimage, global_step=gs)
del result del result
del tfimage del tfimage
@ -543,6 +554,10 @@ def main(args):
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
step_start_time = time.time() step_start_time = time.time()
if global_step > 0 and global_step > args.text_encoder_steps == 0:
text_encoder.requires_grad_(False)
text_encoder.eval()
with torch.no_grad(): with torch.no_grad():
with autocast(): with autocast():
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device) pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
@ -709,6 +724,7 @@ if __name__ == "__main__":
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)") argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)") argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)") argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
argparser.add_argument("--text_encoder_steps", type=int, default=0, help="disable text encoder training after N steps (def: disabled)")
args = argparser.parse_args() args = argparser.parse_args()
main(args) main(args)