amp mode work
commit 879f5bf33d
parent 6ea55a1057

train.py | 65 lines changed
train.py

@@ -230,7 +230,7 @@ def setup_args(args):
    # find the last checkpoint in the logdir
    args.resume_ckpt = find_last_checkpoint(args.logdir)

    if args.ed1_mode and args.mixed_precision == "fp32" and not args.disable_xformers:
    if args.ed1_mode and not args.disable_xformers:
        args.disable_xformers = True
        logging.info(" ED1 mode: Overiding disable_xformers to True")

@@ -272,9 +272,6 @@ def setup_args(args):
    if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
        os.makedirs(args.save_ckpt_dir)

    if args.mixed_precision != "fp32" and (args.clip_grad_norm is None or args.clip_grad_norm <= 0):
        args.clip_grad_norm = 1.0

    if args.rated_dataset:
        args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)

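One practical note on combining clip_grad_norm with AMP: when clipping is used together with a GradScaler, the gradients have to be unscaled before the norm is measured, otherwise the threshold is compared against loss-scaled values. A minimal sketch of that pattern; model, optimizer and scaler are illustrative stand-ins, not names from train.py:

import torch

def clip_gradients_amp(scaler: torch.cuda.amp.GradScaler,
                       optimizer: torch.optim.Optimizer,
                       model: torch.nn.Module,
                       max_norm: float = 1.0) -> None:
    # Undo the loss scaling on the grads owned by this optimizer...
    scaler.unscale_(optimizer)
    # ...so the clip threshold applies to true gradient magnitudes.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # A following scaler.step(optimizer) notices the grads are already unscaled.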
@@ -475,22 +472,13 @@ def main(args):
    default_lr = 2e-6
    curr_lr = args.lr if args.lr is not None else default_lr

    d_type = torch.float32
    if args.mixed_precision == "fp16":
        d_type = torch.float16
        logging.info(" * Using fp16 *")
        args.amp = True
    elif args.mixed_precision == "bf16":
        d_type = torch.bfloat16
        logging.info(" * Using bf16 *")
        args.amp = True

    vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
    unet = unet.to(device, dtype=torch.float32)
    if args.disable_textenc_training and args.amp:
        text_encoder = text_encoder.to(device, dtype=torch.float16)
    else:
        logging.info(" * Using FP32 *")


    vae = vae.to(device, dtype=torch.float16 if (args.amp and d_type == torch.float32) else d_type)
    unet = unet.to(device, dtype=d_type)
    text_encoder = text_encoder.to(device, dtype=d_type)
    text_encoder = text_encoder.to(device, dtype=torch.float32)

    if args.disable_textenc_training:
        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")

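Both variants of the model placement above follow the usual autocast recipe: modules that only run inference (the VAE, or the text encoder when its training is disabled) can be cast to fp16 outright, while anything that will receive gradients keeps fp32 weights and lets autocast handle per-op precision. A minimal sketch of that split, using stand-in modules rather than the real vae/unet/text_encoder:

import torch

device = torch.device("cuda")
use_amp = True                              # stands in for args.amp

frozen_vae = torch.nn.Conv2d(3, 4, 1)       # inference-only: fp16 is safe and saves VRAM
trainable_unet = torch.nn.Conv2d(4, 4, 1)   # trainable: keep fp32 master weights

frozen_vae = frozen_vae.to(device, dtype=torch.float16 if use_amp else torch.float32)
trainable_unet = trainable_unet.to(device, dtype=torch.float32)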
@@ -504,7 +492,7 @@ def main(args):

    betas = (0.9, 0.999)
    epsilon = 1e-8
    if args.amp or args.mix_precision == "fp16":
    if args.amp:
        epsilon = 1e-8

    weight_decay = 0.01

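The epsilon choice interacts with precision: 1e-8 is below the smallest fp16 subnormal (~6e-8), which is why fully-fp16 optimizers often raise it, while AMP with fp32 master weights tolerates the default. For reference, a sketch of how these hyperparameters feed an AdamW-style constructor; whether train.py uses torch.optim.AdamW or a bitsandbytes variant is not visible in this hunk, and the Linear model is a stand-in:

import torch

model = torch.nn.Linear(8, 8)   # stand-in for the unet parameters

betas = (0.9, 0.999)
epsilon = 1e-8                  # fine with fp32 master weights; pure-fp16 setups often raise it
weight_decay = 0.01

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-6, betas=betas,
                              eps=epsilon, weight_decay=weight_decay)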
@@ -666,17 +654,18 @@ def main(args):
    logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
    logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")

    if args.amp or d_type != torch.float32:

    #scaler = torch.cuda.amp.GradScaler()
    scaler = torch.cuda.amp.GradScaler(
        enabled=False,
        enabled=args.amp,
        #enabled=True,
        init_scale=2048.0,
        init_scale=2**17.5,
        growth_factor=1.5,
        backoff_factor=0.707,
        backoff_factor=1.0/1.5,
        growth_interval=50,
    )
    logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
    logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")


    epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
    epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")

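All of the GradScaler arguments tuned here exist on torch.cuda.amp.GradScaler: init_scale is the starting loss scale, growth_factor multiplies it after growth_interval consecutive overflow-free steps, and backoff_factor shrinks it whenever a step produces inf/NaN gradients. PyTorch's defaults are init_scale=65536, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, so the values in this hunk give a gentler schedule. A minimal sketch using the values from the hunk (the enabled flag and the print are illustrative):

import torch

scaler = torch.cuda.amp.GradScaler(
    enabled=True,            # driven by args.amp in the commit
    init_scale=2**17.5,      # ~185k starting loss scale
    growth_factor=1.5,       # scale *= 1.5 after `growth_interval` clean steps
    backoff_factor=1.0/1.5,  # scale /= 1.5 whenever a step sees inf/NaN grads
    growth_interval=50,
)
print(scaler.is_enabled(), scaler.get_scale())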
@@ -741,7 +730,7 @@ def main(args):
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
    del noise, latents, cuda_caption

    with autocast(enabled=args.amp or d_type != torch.float32):
    with autocast(enabled=args.amp):
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    del timesteps, encoder_hidden_states, noisy_latents

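autocast here is the torch.cuda.amp context manager (presumably imported as from torch.cuda.amp import autocast, which this hunk does not show): inside the block, eligible ops such as the UNet's matmuls and convolutions run in fp16 while reductions stay in fp32, and enabled=args.amp lets the same code path serve plain fp32 runs. A self-contained sketch with a toy module (model and data are illustrative):

import torch
from torch.cuda.amp import autocast

model = torch.nn.Linear(16, 16).cuda()    # stand-in for the unet
x = torch.randn(4, 16, device="cuda")
use_amp = True

with autocast(enabled=use_amp):
    out = model(x)      # matmul runs in fp16 under autocast
print(out.dtype)        # torch.float16 when use_amp, torch.float32 otherwise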
@@ -750,10 +739,10 @@ def main(args):

    del target, model_pred

    if args.amp:
    #if args.amp:
    scaler.scale(loss).backward()
    else:
        loss.backward()
    #else:
    # loss.backward()

    if args.clip_grad_norm is not None:
        if not args.disable_unet_training:

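Leaving scaler.scale(loss).backward() as the only backward path is safe because a disabled GradScaler is a pass-through: scale() returns the loss unchanged and step() simply calls optimizer.step(). A tiny sketch of that behaviour (values illustrative):

import torch

loss = torch.tensor(1.0, device="cuda")

on = torch.cuda.amp.GradScaler(enabled=True, init_scale=1024.0)
off = torch.cuda.amp.GradScaler(enabled=False)

print(on.scale(loss))    # loss multiplied by the current scale (1024 here)
print(off.scale(loss))   # disabled scaler returns the loss unchanged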
@@ -773,11 +762,11 @@ def main(args):
    param.grad *= grad_scale

    if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
        if args.amp and d_type == torch.float32:
        # if args.amp:
        scaler.step(optimizer)
        scaler.update()
        else:
            optimizer.step()
        # else:
        # optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    lr_scheduler.step()

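scaler.step(optimizer) unscales the gradients and skips the optimizer step entirely if any of them are inf/NaN, and scaler.update() then grows or backs off the loss scale; zero_grad(set_to_none=True) resets for the next accumulation window. A condensed sketch of the whole loop with illustrative model/optimizer names; dividing the loss by accum is just one common way to average over accumulation (the hunk above rescales param.grad directly instead):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-6)
scaler = GradScaler(enabled=True)
accum = 4

for step in range(100):
    x = torch.randn(8, 16, device="cuda")
    with autocast(enabled=True):
        loss = model(x).mean()
    scaler.scale(loss / accum).backward()      # scaled backward; grads accumulate
    if (step + 1) % accum == 0:
        scaler.step(optimizer)                 # unscales, skips step on inf/NaN grads
        scaler.update()                        # adjusts the loss scale
        optimizer.zero_grad(set_to_none=True)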
@@ -840,6 +829,13 @@ def main(args):

    del batch
    global_step += 1

    if global_step == 500:
        scaler.set_growth_factor(1.35)
        scaler.set_backoff_factor(1/1.35)
    if global_step == 1000:
        scaler.set_growth_factor(1.2)
        scaler.set_backoff_factor(1/1.2)
    # end of step

    steps_pbar.close()

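set_growth_factor and set_backoff_factor are real GradScaler setters, so the loss-scale schedule can be relaxed as training stabilises, exactly as the two step checks above do. A sketch of the same idea factored into a helper; the schedule dict is illustrative:

import torch

def relax_scaler_schedule(scaler: torch.cuda.amp.GradScaler, global_step: int) -> None:
    # Later in training the loss scale should move more gently.
    schedule = {500: 1.35, 1000: 1.2}      # step -> new growth factor (illustrative)
    factor = schedule.get(global_step)
    if factor is not None:
        scaler.set_growth_factor(factor)
        scaler.set_backoff_factor(1.0 / factor)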
@@ -893,9 +889,6 @@ def update_old_args(t_args):
    if not hasattr(t_args, "disable_unet_training"):
        print(f" Config json is missing 'disable_unet_training' flag")
        t_args.__dict__["disable_unet_training"] = False
    if not hasattr(t_args, "mixed_precision"):
        print(f" Config json is missing 'mixed_precision' flag")
        t_args.__dict__["mixed_precision"] = "fp32"
    if not hasattr(t_args, "rated_dataset"):
        print(f" Config json is missing 'rated_dataset' flag")
        t_args.__dict__["rated_dataset"] = False

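update_old_args back-fills flags that older config jsons predate. If more flags accumulate, the same behaviour can be expressed as a table of defaults; this is only an alternative sketch, not the commit's code, and the keys shown are illustrative:

def update_old_args(t_args):
    # Defaults for flags that older config jsons may not contain.
    backfill = {
        "disable_unet_training": False,
        "rated_dataset": False,
    }
    for key, default in backfill.items():
        if not hasattr(t_args, key):
            print(f" Config json is missing '{key}' flag")
            setattr(t_args, key, default)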
@@ -920,7 +913,6 @@ if __name__ == "__main__":
        update_old_args(t_args) # update args to support older configs
        print(t_args.__dict__)
        args = argparser.parse_args(namespace=t_args)
        print(f"mixed_precision: {args.mixed_precision}")
    else:
        print("No config file specified, using command line args")
        argparser = argparse.ArgumentParser(description="EveryDream2 Training options")

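The config branch relies on argparse's namespace= hook: values loaded from the JSON land on the Namespace first, and parse_args only overrides attributes for flags actually passed on the command line (argparse leaves pre-existing namespace attributes alone unless the flag is given). A self-contained sketch of that pattern; the file name and the two flags are illustrative:

import argparse
import json

argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--project_name", type=str, default="myproj")
argparser.add_argument("--max_epochs", type=int, default=300)

# Illustrative config file; in train.py the path comes from the config argument.
with open("train.json", "r") as f:
    t_args = argparse.Namespace(**json.load(f))

# CLI flags that were actually passed override the config values.
args = argparser.parse_args(namespace=t_args)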
@@ -947,7 +939,6 @@ if __name__ == "__main__":
    argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
    argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
    argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
    argparser.add_argument("--mixed_precision", type=str, default='fp32', help="precision for the model training", choices=supported_precisions)
    argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
    argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
    argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)


@@ -25,13 +25,9 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
    with open(unet_cfg_path, "r") as f:
        unet_cfg = json.load(f)


    if force_sd1attn:
        if low_vram:
            unet_cfg["attention_head_dim"] = [5, 8, 8, 8]
        else:
            unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
    else:
    else: # SD 2 attn for xformers
        unet_cfg["attention_head_dim"] = [5, 10, 20, 20]

    logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")

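For context on the values being patched: [8, 8, 8, 8] (or [5, 8, 8, 8] in the low-VRAM branch) is the SD1-style attention head layout, while [5, 10, 20, 20] matches SD2 configs, as the new comment notes. A minimal sketch of the json round-trip such a patch implies; the path and the write-back step are assumptions, since the hunk only shows the load:

import json
import logging

unet_cfg_path = "unet/config.json"   # illustrative path

with open(unet_cfg_path, "r") as f:
    unet_cfg = json.load(f)

unet_cfg["attention_head_dim"] = [8, 8, 8, 8]   # SD1-style heads; [5, 10, 20, 20] is the SD2 layout
logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")

# Assumed follow-up: write the patched config back out before the UNet is loaded.
with open(unet_cfg_path, "w") as f:
    json.dump(unet_cfg, f, indent=2)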