fix sample filenames due to illegal characters
parent b0a23ebd73
commit 07dbf64ecf
@@ -53,7 +53,7 @@ Cosine also has a decay period to define how long it takes to get to zero LR as

 ## Gradient accumulation

-Gradient accumulation is sort of like a virtual batch size increase, averaging the learning over more than one step (batch) before applying it to the model update.
+Gradient accumulation is sort of like a virtual batch size increase, averaging the learning over more than one step (batch) before applying it to the model as an update to weights.

 Example:
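The paragraph above describes the idea; below is a minimal, runnable sketch of gradient accumulation in PyTorch. The toy model and the `accum_steps` value are illustrative only, not this repo's code (train.py's `_GRAD_ACCUM_STEPS = 1 # future use...` constant is not yet wired up):

```python
import torch

# Toy setup so the sketch runs end to end (illustrative, not this repo's code).
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(8)]

accum_steps = 4  # "virtual" batch size = accum_steps * real batch size

optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / accum_steps).backward()   # grads sum across batches; scaling makes it an average
    if (step + 1) % accum_steps == 0:
        optimizer.step()              # one weight update per accum_steps batches
        optimizer.zero_grad()
```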
train.py (20 lines changed)
@@ -53,6 +53,12 @@ _GRAD_ACCUM_STEPS = 1 # future use...

 _SIGTERM_EXIT_CODE = 130
 _VERY_LARGE_NUMBER = 1e9

+def clean_filename(filename):
+    """
+    removes all non-alphanumeric characters from a string so it is safe to use as a filename
+    """
+    return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
+
 def convert_to_hf(ckpt_path):
     hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
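To illustrate what the new helper does to a typical prompt (the prompt string below is made up):

```python
# Same helper as above, repeated so this snippet runs standalone.
def clean_filename(filename):
    return "".join([c for c in filename if c.isalpha() or c.isdigit() or c == ' ']).rstrip()

print(clean_filename('a "portrait" of <subject>, 4k?'))
# -> a portrait of subject 4k
```

Note that only letters, digits, and spaces survive, so punctuation such as commas and periods is dropped too, not just characters that are actually illegal in filenames.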
@@ -97,6 +103,9 @@ def setup_local_logger(args):
     return datetimestamp

 def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
+    """
+    logs the optimizer settings
+    """
     logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
     logging.info(f" betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}")

@@ -290,14 +299,16 @@ def main(args):
                 result.paste(image, (x_offset, 0))
                 x_offset += image.width

-            result.save(f"{log_folder}/samples/gs{gs:05}-{prompt[:100]}.png")
+            clean_prompt = clean_filename(prompt)
+
+            result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.png")

             tfimage = transforms.ToTensor()(result)
             if random_captions:
                 log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
                 i += 1
             else:
-                log_writer.add_image(tag=f"sample_{prompt[:150]}", img_tensor=tfimage, global_step=gs)
+                log_writer.add_image(tag=f"sample_{clean_prompt[:150]}", img_tensor=tfimage, global_step=gs)

             del result
             del tfimage
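With this change, sample filenames are built from the sanitized prompt. A quick check of the resulting name format (hypothetical values):

```python
gs = 250
clean_prompt = "a portrait of subject 4k"
# gs is zero-padded to five digits; the prompt is truncated to 100 characters.
print(f"samples/gs{gs:05}-{clean_prompt[:100]}.png")
# -> samples/gs00250-a portrait of subject 4k.png
```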
@@ -543,6 +554,10 @@ def main(args):
         for step, batch in enumerate(train_dataloader):
             step_start_time = time.time()

+            if args.text_encoder_steps > 0 and global_step > args.text_encoder_steps:
+                text_encoder.requires_grad_(False)
+                text_encoder.eval()
+
             with torch.no_grad():
                 with autocast():
                     pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
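For context, `requires_grad_(False)` plus `.eval()` is the usual PyTorch idiom for freezing a module mid-training. A self-contained sketch with a stand-in module (not the real text encoder):

```python
import torch

text_encoder = torch.nn.Linear(16, 16)  # stand-in for the real text encoder

text_encoder.requires_grad_(False)  # autograd stops tracking its parameters
text_encoder.eval()                 # dropout/norm layers switch to inference behavior

out = text_encoder(torch.randn(2, 16))
print(out.requires_grad)                                        # False
print(any(p.requires_grad for p in text_encoder.parameters()))  # False
```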
@@ -709,6 +724,7 @@ if __name__ == "__main__":
     argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
     argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
     argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
+    argparser.add_argument("--text_encoder_steps", type=int, default=0, help="disable text encoder training after N steps (def: disabled)")
     args = argparser.parse_args()

     main(args)
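A quick demonstration of the new flag's semantics, using a stand-in parser with just this one argument (the value 1500 is arbitrary):

```python
import argparse

argparser = argparse.ArgumentParser()
argparser.add_argument("--text_encoder_steps", type=int, default=0,
                       help="disable text encoder training after N steps (def: disabled)")

args = argparser.parse_args(["--text_encoder_steps", "1500"])
print(args.text_encoder_steps)  # 1500; the default of 0 leaves the freeze disabled
```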