diff --git a/doc/ATWEAKING.md b/doc/ATWEAKING.md
index f2c8523..a8d91da 100644
--- a/doc/ATWEAKING.md
+++ b/doc/ATWEAKING.md
@@ -53,7 +53,7 @@ Cosine also has a decay period to define how long it takes to get to zero LR as
 
 ## Gradient accumulation
 
-Gradient accumulation is sort of like a virtual batch size increase, averaging the learning over more than one step (batch) before applying it to the model update.
+Gradient accumulation is sort of like a virtual batch size increase: gradients are averaged over more than one step (batch) before being applied to the model as a weight update.
 
 Example:
 
diff --git a/train.py b/train.py
index 1a31da2..d66e49b 100644
--- a/train.py
+++ b/train.py
@@ -53,6 +53,12 @@ _GRAD_ACCUM_STEPS = 1 # future use...
 _SIGTERM_EXIT_CODE = 130
 _VERY_LARGE_NUMBER = 1e9
 
+def clean_filename(filename):
+    """
+    Strips characters that are not letters, digits, or spaces so the string is safe to use as a filename.
+    """
+    return "".join([c for c in filename if c.isalpha() or c.isdigit() or c == ' ']).rstrip()
+
 def convert_to_hf(ckpt_path):
     hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
 
@@ -97,6 +103,9 @@ def setup_local_logger(args):
     return datetimestamp
 
 def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
+    """
+    Logs the optimizer class and its key settings.
+    """
     logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
     logging.info(f" betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}")
 
@@ -290,14 +299,16 @@ def main(args):
                 result.paste(image, (x_offset, 0))
                 x_offset += image.width
 
-            result.save(f"{log_folder}/samples/gs{gs:05}-{prompt[:100]}.png")
+            clean_prompt = clean_filename(prompt)
+
+            result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.png")
 
             tfimage = transforms.ToTensor()(result)
             if random_captions:
                 log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
                 i += 1
             else:
-                log_writer.add_image(tag=f"sample_{prompt[:150]}", img_tensor=tfimage, global_step=gs)
+                log_writer.add_image(tag=f"sample_{clean_prompt[:150]}", img_tensor=tfimage, global_step=gs)
 
             del result
             del tfimage
@@ -543,6 +554,10 @@ def main(args):
         for step, batch in enumerate(train_dataloader):
             step_start_time = time.time()
 
+            if args.text_encoder_steps > 0 and global_step >= args.text_encoder_steps:
+                text_encoder.requires_grad_(False)
+                text_encoder.eval()
+
             with torch.no_grad():
                 with autocast():
                     pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
@@ -709,6 +724,7 @@ if __name__ == "__main__":
     argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
     argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
     argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
+    argparser.add_argument("--text_encoder_steps", type=int, default=0, help="disable text encoder training after N steps, 0 = never disable (def: 0)")
 
     args = argparser.parse_args()
     main(args)
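A note on the ATWEAKING.md gradient accumulation hunk: below is a minimal runnable sketch of the technique the doc describes, for reference. The toy model, data, and `grad_accum_steps` value are hypothetical stand-ins; this is not train.py's actual loop, where `_GRAD_ACCUM_STEPS` is still marked "future use".

```python
import torch

# Toy setup so the sketch runs end to end; the model, optimizer, and data
# here are hypothetical stand-ins, not the trainer's actual components.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
micro_batches = [torch.randn(2, 8) for _ in range(8)]  # 8 micro-batches of size 2

grad_accum_steps = 4  # virtual batch size = 2 * 4 = 8

optimizer.zero_grad()
for step, batch in enumerate(micro_batches):
    loss = model(batch).pow(2).mean()     # dummy loss
    (loss / grad_accum_steps).backward()  # scale so the summed grads average out
    # Apply the accumulated gradient once per grad_accum_steps micro-batches,
    # then clear it for the next accumulation window.
    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```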
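A note on the `--text_encoder_steps` hunk: the guard `args.text_encoder_steps > 0` keeps the default of 0 meaning "never freeze", and once `global_step` reaches the threshold, `requires_grad_(False)` stops gradient computation for the text encoder's parameters while `eval()` turns off train-time behaviors such as dropout, so the rest of the model (e.g. the unet) continues to train alone. A hypothetical invocation with the new flag (remaining flags elided): `python train.py --text_encoder_steps 1500`.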