fix sample filenames due to illegal characters
This commit is contained in:
parent
b0a23ebd73
commit
07dbf64ecf
|
@ -53,7 +53,7 @@ Cosine also has a decay period to define how long it takes to get to zero LR as
|
|||
|
||||
## Gradient accumulation
|
||||
|
||||
Gradient accumulation is sort of like a virtual batch size increase, averaging the learning over more than one step (batch) before applying it to the model update.
|
||||
Gradient accumulation is sort of like a virtual batch size increase, averaging the learning over more than one step (batch) before applying it to the model as an update to weights.
|
||||
|
||||
Example:
|
||||
|
||||
|
|
20
train.py
20
train.py
|
@ -53,6 +53,12 @@ _GRAD_ACCUM_STEPS = 1 # future use...
|
|||
_SIGTERM_EXIT_CODE = 130
|
||||
_VERY_LARGE_NUMBER = 1e9
|
||||
|
||||
def clean_filename(filename):
|
||||
"""
|
||||
removes all non-alphanumeric characters from a string so it is safe to use as a filename
|
||||
"""
|
||||
return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
|
||||
|
||||
def convert_to_hf(ckpt_path):
|
||||
hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
|
||||
|
||||
|
@ -97,6 +103,9 @@ def setup_local_logger(args):
|
|||
return datetimestamp
|
||||
|
||||
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
|
||||
"""
|
||||
logs the optimizer settings
|
||||
"""
|
||||
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
|
||||
logging.info(f" betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}")
|
||||
|
||||
|
@ -290,14 +299,16 @@ def main(args):
|
|||
result.paste(image, (x_offset, 0))
|
||||
x_offset += image.width
|
||||
|
||||
result.save(f"{log_folder}/samples/gs{gs:05}-{prompt[:100]}.png")
|
||||
clean_prompt = clean_filename(prompt)
|
||||
|
||||
result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.png")
|
||||
|
||||
tfimage = transforms.ToTensor()(result)
|
||||
if random_captions:
|
||||
log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
|
||||
i += 1
|
||||
else:
|
||||
log_writer.add_image(tag=f"sample_{prompt[:150]}", img_tensor=tfimage, global_step=gs)
|
||||
log_writer.add_image(tag=f"sample_{clean_prompt[:150]}", img_tensor=tfimage, global_step=gs)
|
||||
|
||||
del result
|
||||
del tfimage
|
||||
|
@ -543,6 +554,10 @@ def main(args):
|
|||
for step, batch in enumerate(train_dataloader):
|
||||
step_start_time = time.time()
|
||||
|
||||
if global_step > 0 and global_step > args.text_encoder_steps == 0:
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
with autocast():
|
||||
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
|
||||
|
@ -709,6 +724,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
|
||||
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
|
||||
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
|
||||
argparser.add_argument("--text_encoder_steps", type=int, default=0, help="disable text encoder training after N steps (def: disabled)")
|
||||
args = argparser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
Loading…
Reference in New Issue