diff --git a/train.py b/train.py
index 9ae9119..a5de0b3 100644
--- a/train.py
+++ b/train.py
@@ -230,7 +230,7 @@ def setup_args(args):
         # find the last checkpoint in the logdir
         args.resume_ckpt = find_last_checkpoint(args.logdir)
 
-    if args.ed1_mode and args.mixed_precision == "fp32" and not args.disable_xformers:
+    if args.ed1_mode and not args.disable_xformers:
         args.disable_xformers = True
         logging.info(" ED1 mode: Overiding disable_xformers to True")
 
@@ -272,9 +272,6 @@ def setup_args(args):
     if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
         os.makedirs(args.save_ckpt_dir)
 
-    if args.mixed_precision != "fp32" and (args.clip_grad_norm is None or args.clip_grad_norm <= 0):
-        args.clip_grad_norm = 1.0
-
     if args.rated_dataset:
         args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
 
@@ -475,22 +472,13 @@ def main(args):
     default_lr = 2e-6
     curr_lr = args.lr if args.lr is not None else default_lr
 
-    d_type = torch.float32
-    if args.mixed_precision == "fp16":
-        d_type = torch.float16
-        logging.info(" * Using fp16 *")
-        args.amp = True
-    elif args.mixed_precision == "bf16":
-        d_type = torch.bfloat16
-        logging.info(" * Using bf16 *")
-        args.amp = True
+
+    vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
+    unet = unet.to(device, dtype=torch.float32)
+    if args.disable_textenc_training and args.amp:
+        text_encoder = text_encoder.to(device, dtype=torch.float16)
     else:
-        logging.info(" * Using FP32 *")
-
-
-    vae = vae.to(device, dtype=torch.float16 if (args.amp and d_type == torch.float32) else d_type)
-    unet = unet.to(device, dtype=d_type)
-    text_encoder = text_encoder.to(device, dtype=d_type)
+        text_encoder = text_encoder.to(device, dtype=torch.float32)
 
     if args.disable_textenc_training:
         logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
@@ -504,7 +492,7 @@ def main(args):
 
     betas = (0.9, 0.999)
     epsilon = 1e-8
-    if args.amp or args.mix_precision == "fp16":
+    if args.amp:
         epsilon = 1e-8
     weight_decay = 0.01
 
@@ -666,17 +654,18 @@ def main(args):
     logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
     logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
 
-    if args.amp or d_type != torch.float32:
-        #scaler = torch.cuda.amp.GradScaler()
-        scaler = torch.cuda.amp.GradScaler(
-            enabled=False,
-            #enabled=True,
-            init_scale=2048.0,
-            growth_factor=1.5,
-            backoff_factor=0.707,
-            growth_interval=50,
-        )
-        logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
+
+    #scaler = torch.cuda.amp.GradScaler()
+    scaler = torch.cuda.amp.GradScaler(
+        enabled=args.amp,
+        #enabled=True,
+        init_scale=2**17.5,
+        growth_factor=1.5,
+        backoff_factor=1.0/1.5,
+        growth_interval=50,
+    )
+    logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
+
     epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
     epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
 
@@ -741,7 +730,7 @@ def main(args):
                 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
             del noise, latents, cuda_caption
-            with autocast(enabled=args.amp or d_type != torch.float32):
+            with autocast(enabled=args.amp):
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
             del timesteps, encoder_hidden_states, noisy_latents
@@ -750,10 +739,10 @@ def main(args):
 
             del target, model_pred
 
-            if args.amp:
-                scaler.scale(loss).backward()
-            else:
-                loss.backward()
+            #if args.amp:
+            scaler.scale(loss).backward()
+            #else:
+            #    loss.backward()
 
             if args.clip_grad_norm is not None:
                 if not args.disable_unet_training:
@@ -773,11 +762,11 @@ def main(args):
                         param.grad *= grad_scale
 
             if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
-                if args.amp and d_type == torch.float32:
-                    scaler.step(optimizer)
-                    scaler.update()
-                else:
-                    optimizer.step()
+                # if args.amp:
+                scaler.step(optimizer)
+                scaler.update()
+                # else:
+                #     optimizer.step()
                 optimizer.zero_grad(set_to_none=True)
 
             lr_scheduler.step()
@@ -840,10 +829,17 @@ def main(args):
 
             del batch
             global_step += 1
+
+            if global_step == 500:
+                scaler.set_growth_factor(1.35)
+                scaler.set_backoff_factor(1/1.35)
+            if global_step == 1000:
+                scaler.set_growth_factor(1.2)
+                scaler.set_backoff_factor(1/1.2)
             # end of step
 
         steps_pbar.close()
-        
+
         elapsed_epoch_time = (time.time() - epoch_start_time) / 60
         epoch_times.append(dict(epoch=epoch, time=elapsed_epoch_time))
         log_writer.add_scalar("performance/minutes per epoch", elapsed_epoch_time, global_step)
@@ -851,7 +847,7 @@ def main(args):
         epoch_pbar.update(1)
         if epoch < args.max_epochs - 1:
             train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
-        
+
         loss_local = sum(loss_epoch) / len(loss_epoch)
         log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
         # end of epoch
@@ -893,9 +889,6 @@ def update_old_args(t_args):
     if not hasattr(t_args, "disable_unet_training"):
         print(f" Config json is missing 'disable_unet_training' flag")
         t_args.__dict__["disable_unet_training"] = False
-    if not hasattr(t_args, "mixed_precision"):
-        print(f" Config json is missing 'mixed_precision' flag")
-        t_args.__dict__["mixed_precision"] = "fp32"
     if not hasattr(t_args, "rated_dataset"):
         print(f" Config json is missing 'rated_dataset' flag")
         t_args.__dict__["rated_dataset"] = False
@@ -920,7 +913,6 @@ if __name__ == "__main__":
             update_old_args(t_args) # update args to support older configs
             print(t_args.__dict__)
         args = argparser.parse_args(namespace=t_args)
-        print(f"mixed_precision: {args.mixed_precision}")
     else:
         print("No config file specified, using command line args")
         argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
@@ -947,7 +939,6 @@
         argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
         argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
         argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
-        argparser.add_argument("--mixed_precision", type=str, default='fp32', help="precision for the model training", choices=supported_precisions)
        argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
        argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
        argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
diff --git a/utils/patch_unet.py b/utils/patch_unet.py
index c9392b0..4f66747 100644
--- a/utils/patch_unet.py
+++ b/utils/patch_unet.py
@@ -25,13 +25,10 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
 
     with open(unet_cfg_path, "r") as f:
         unet_cfg = json.load(f)
 
     if force_sd1attn:
-        if low_vram:
-            unet_cfg["attention_head_dim"] = [5, 8, 8, 8]
-        else:
-            unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
-    else:
+        unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
+    else: # SD 2 attn for xformers
         unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
 
     logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
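
Note on the pattern the train.py hunks above converge on: the GradScaler and autocast context are now always constructed, and enabled=args.amp alone decides whether they do anything. A disabled GradScaler passes the loss through unscaled and its step() simply calls optimizer.step(), so a single code path serves both full-precision and AMP runs. The following is a minimal standalone sketch of that idea, not code from this repo: use_amp stands in for args.amp, and the toy linear model, optimizer, and random data are placeholders.

import torch
import torch.nn.functional as F

# Sketch only: placeholders, not objects from train.py.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = torch.nn.Linear(8, 8).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

scaler = torch.cuda.amp.GradScaler(
    enabled=use_amp,          # disabled scaler is a no-op pass-through
    init_scale=2**17.5,
    growth_factor=1.5,
    backoff_factor=1.0 / 1.5,
    growth_interval=50,
)

for step in range(1200):
    x = torch.randn(4, 8, device=device)
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = F.mse_loss(model(x), torch.zeros_like(x))

    scaler.scale(loss).backward()   # identical to loss.backward() when the scaler is disabled
    scaler.step(optimizer)          # skips the optimizer step if inf/nan gradients were detected
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    # Mirrors the step-500/step-1000 adjustments added in the patch: factors
    # closer to 1 make the dynamic loss scale change more gently over time.
    if step == 500:
        scaler.set_growth_factor(1.35)
        scaler.set_backoff_factor(1 / 1.35)
    if step == 1000:
        scaler.set_growth_factor(1.2)
        scaler.set_backoff_factor(1 / 1.2)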