amp mode work

This commit is contained in:
Victor Hall 2023-01-16 15:48:06 -05:00
parent 6ea55a1057
commit 879f5bf33d
2 changed files with 41 additions and 54 deletions

View File

@ -230,7 +230,7 @@ def setup_args(args):
# find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir)
if args.ed1_mode and args.mixed_precision == "fp32" and not args.disable_xformers:
if args.ed1_mode and not args.disable_xformers:
args.disable_xformers = True
logging.info(" ED1 mode: Overiding disable_xformers to True")
@ -272,9 +272,6 @@ def setup_args(args):
if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
os.makedirs(args.save_ckpt_dir)
if args.mixed_precision != "fp32" and (args.clip_grad_norm is None or args.clip_grad_norm <= 0):
args.clip_grad_norm = 1.0
if args.rated_dataset:
args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
@ -475,22 +472,13 @@ def main(args):
default_lr = 2e-6
curr_lr = args.lr if args.lr is not None else default_lr
d_type = torch.float32
if args.mixed_precision == "fp16":
d_type = torch.float16
logging.info(" * Using fp16 *")
args.amp = True
elif args.mixed_precision == "bf16":
d_type = torch.bfloat16
logging.info(" * Using bf16 *")
args.amp = True
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
unet = unet.to(device, dtype=torch.float32)
if args.disable_textenc_training and args.amp:
text_encoder = text_encoder.to(device, dtype=torch.float16)
else:
logging.info(" * Using FP32 *")
vae = vae.to(device, dtype=torch.float16 if (args.amp and d_type == torch.float32) else d_type)
unet = unet.to(device, dtype=d_type)
text_encoder = text_encoder.to(device, dtype=d_type)
text_encoder = text_encoder.to(device, dtype=torch.float32)
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
@ -504,7 +492,7 @@ def main(args):
betas = (0.9, 0.999)
epsilon = 1e-8
if args.amp or args.mix_precision == "fp16":
if args.amp:
epsilon = 1e-8
weight_decay = 0.01
@ -666,17 +654,18 @@ def main(args):
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
if args.amp or d_type != torch.float32:
#scaler = torch.cuda.amp.GradScaler()
scaler = torch.cuda.amp.GradScaler(
enabled=False,
enabled=args.amp,
#enabled=True,
init_scale=2048.0,
init_scale=2**17.5,
growth_factor=1.5,
backoff_factor=0.707,
backoff_factor=1.0/1.5,
growth_interval=50,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
@ -741,7 +730,7 @@ def main(args):
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
del noise, latents, cuda_caption
with autocast(enabled=args.amp or d_type != torch.float32):
with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
del timesteps, encoder_hidden_states, noisy_latents
@ -750,10 +739,10 @@ def main(args):
del target, model_pred
if args.amp:
#if args.amp:
scaler.scale(loss).backward()
else:
loss.backward()
#else:
# loss.backward()
if args.clip_grad_norm is not None:
if not args.disable_unet_training:
@ -773,11 +762,11 @@ def main(args):
param.grad *= grad_scale
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
if args.amp and d_type == torch.float32:
# if args.amp:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
# else:
# optimizer.step()
optimizer.zero_grad(set_to_none=True)
lr_scheduler.step()
@ -840,6 +829,13 @@ def main(args):
del batch
global_step += 1
if global_step == 500:
scaler.set_growth_factor(1.35)
scaler.set_backoff_factor(1/1.35)
if global_step == 1000:
scaler.set_growth_factor(1.2)
scaler.set_backoff_factor(1/1.2)
# end of step
steps_pbar.close()
@ -893,9 +889,6 @@ def update_old_args(t_args):
if not hasattr(t_args, "disable_unet_training"):
print(f" Config json is missing 'disable_unet_training' flag")
t_args.__dict__["disable_unet_training"] = False
if not hasattr(t_args, "mixed_precision"):
print(f" Config json is missing 'mixed_precision' flag")
t_args.__dict__["mixed_precision"] = "fp32"
if not hasattr(t_args, "rated_dataset"):
print(f" Config json is missing 'rated_dataset' flag")
t_args.__dict__["rated_dataset"] = False
@ -920,7 +913,6 @@ if __name__ == "__main__":
update_old_args(t_args) # update args to support older configs
print(t_args.__dict__)
args = argparser.parse_args(namespace=t_args)
print(f"mixed_precision: {args.mixed_precision}")
else:
print("No config file specified, using command line args")
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
@ -947,7 +939,6 @@ if __name__ == "__main__":
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--mixed_precision", type=str, default='fp32', help="precision for the model training", choices=supported_precisions)
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)

View File

@ -25,13 +25,9 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
with open(unet_cfg_path, "r") as f:
unet_cfg = json.load(f)
if force_sd1attn:
if low_vram:
unet_cfg["attention_head_dim"] = [5, 8, 8, 8]
else:
unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
else:
else: # SD 2 attn for xformers
unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")