diff --git a/train.py b/train.py
index 21d0466..844b9cb 100644
--- a/train.py
+++ b/train.py
@@ -187,7 +187,7 @@ def set_args_12gb(args):
         logging.info(" - Overriding resolution to max 512")
         args.resolution = 512
 
-def find_last_checkpoint(logdir):
+def find_last_checkpoint(logdir, is_ema=False):
     """
     Finds the last checkpoint in the logdir, recursively
     """
@@ -197,6 +197,12 @@ def find_last_checkpoint(logdir):
     for root, dirs, files in os.walk(logdir):
         for file in files:
             if os.path.basename(file) == "model_index.json":
+
+                if is_ema and (not root.endswith("_ema")):
+                    continue
+                elif (not is_ema) and root.endswith("_ema"):
+                    continue
+
                 curr_date = os.path.getmtime(os.path.join(root,file))
 
                 if last_date is None or curr_date > last_date:
@@ -229,6 +235,11 @@ def setup_args(args):
         # find the last checkpoint in the logdir
         args.resume_ckpt = find_last_checkpoint(args.logdir)
 
+    if (args.ema_decay_resume_model != None) and (args.ema_decay_resume_model == "findlast"):
+        logging.info(f"{Fore.LIGHTCYAN_EX} Finding last EMA decay checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
+
+        args.ema_decay_resume_model = find_last_checkpoint(args.logdir, is_ema=True)
+
     if args.lowvram:
         set_args_12gb(args)
@@ -357,20 +368,17 @@ def log_args(log_writer, args):
         arglog += f"{arg}={value}, "
     log_writer.add_text("config", arglog)
 
-def update_ema(model, ema_model, decay, default_device):
+def update_ema(model, ema_model, decay, default_device, ema_device):
     with torch.no_grad():
         original_model_on_proper_device = model
         need_to_delete_original = False
 
-        if args.ema_decay_device != default_device:
+        if ema_device != default_device:
             original_model_on_other_device = deepcopy(model)
-            original_model_on_proper_device = original_model_on_other_device.to(args.ema_decay_device, dtype=model.dtype)
+            original_model_on_proper_device = original_model_on_other_device.to(ema_device, dtype=model.dtype)
             del original_model_on_other_device
             need_to_delete_original = True
 
-        # original_model_on_proper_device = type(model)().to(args.ema_decay_device, dtype=model.dtype)
-        # original_model_on_proper_device.load_state_dict(model.state_dict())
-
         params = dict(original_model_on_proper_device.named_parameters())
         ema_params = dict(ema_model.named_parameters())
@@ -418,9 +426,11 @@ def compute_snr(timesteps, noise_scheduler):
     snr[snr == 0] = minimal_value
     return snr
 
-def load_train_json_from_file(args):
+def load_train_json_from_file(args, report_load=False):
     try:
-        print(f"Loading training config from {args.config}.")
+        if report_load:
+            print(f"Loading training config from {args.config}.")
+
         with open(args.config, 'rt') as f:
             read_json = json.load(f)
@@ -461,6 +471,13 @@ def main(args):
     if not os.path.exists(log_folder):
         os.makedirs(log_folder)
 
+    def release_memory(model_to_delete, original_device):
+        del model_to_delete
+        gc.collect()
+
+        if 'cuda' in original_device.type:
+            torch.cuda.empty_cache()
+
     @torch.no_grad()
     def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
@@ -477,7 +494,6 @@ def main(args):
         def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
             nonlocal save_ckpt_dir
             nonlocal save_full_precision
-            #nonlocal save_path
             nonlocal yaml_name
 
             if save_ckpt_dir is not None:
@@ -555,6 +571,11 @@ def main(args):
 
         ed_optimizer.save(save_path)
 
+    use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
+    ema_decay_model_loaded_from_file = False
+
+    if use_ema_dacay_training:
+        ema_device = torch.device(args.ema_decay_device)
 
     optimizer_state_path = None
     try:
@@ -580,6 +601,42 @@ def main(args):
         unet = pipe.unet
         del pipe
 
+    if use_ema_dacay_training and args.ema_decay_resume_model:
+        print(f"Loading EMA model: {args.ema_decay_resume_model}")
+        ema_decay_model_loaded_from_file = True
+        hf_cache_path = get_hf_ckpt_cache_path(args.ema_decay_resume_model)
+
+        if os.path.exists(hf_cache_path) or os.path.exists(args.ema_decay_resume_model):
+            ema_model_root_folder, ema_is_sd1attn, ema_yaml = convert_to_hf(args.ema_decay_resume_model)
+            text_encoder_ema = CLIPTextModel.from_pretrained(ema_model_root_folder, subfolder="text_encoder")
+            unet_ema = UNet2DConditionModel.from_pretrained(ema_model_root_folder, subfolder="unet")
+
+        else:
+            # try to download from HF using ema_decay_resume_model as a repo id
+            ema_downloaded = try_download_model_from_hf(repo_id=args.ema_decay_resume_model)
+            if ema_downloaded is None:
+                raise ValueError(
+                    f"No local file/folder for ema_decay_resume_model {args.ema_decay_resume_model}, and no matching huggingface.co repo could be downloaded")
+            ema_pipe, ema_model_root_folder, ema_is_sd1attn, ema_yaml = ema_downloaded
+            text_encoder_ema = ema_pipe.text_encoder
+            unet_ema = ema_pipe.unet
+            del ema_pipe
+
+        # Make sure the EMA unet is on the proper device, and memory is released if it moved
+        unet_ema_current_device = next(unet_ema.parameters()).device
+        if ema_device != unet_ema_current_device:
+            unet_ema_on_wrong_device = unet_ema
+            unet_ema = unet_ema.to(ema_device)
+            release_memory(unet_ema_on_wrong_device, unet_ema_current_device)
+
+        # Make sure the EMA text encoder is on the proper device, and memory is released if it moved
+        text_encoder_ema_current_device = next(text_encoder_ema.parameters()).device
+        if ema_device != text_encoder_ema_current_device:
+            text_encoder_ema_on_wrong_device = text_encoder_ema
+            text_encoder_ema = text_encoder_ema.to(ema_device)
+            release_memory(text_encoder_ema_on_wrong_device, text_encoder_ema_current_device)
+
+
     if args.enable_zero_terminal_snr:
         # Use zero terminal SNR
         from utils.unet_utils import enforce_zero_terminal_snr
@@ -626,28 +683,27 @@ def main(args):
 
     text_encoder = text_encoder.to(device, dtype=torch.float32)
 
-    use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
+
     if use_ema_dacay_training:
-        logging.info(f"EMA decay enabled, creating EMA model.")
+        if not ema_decay_model_loaded_from_file:
+            logging.info(f"EMA decay enabled, creating EMA model.")
 
-        with torch.no_grad():
-            if args.ema_decay_device == device:
-                unet_ema = deepcopy(unet)
-                text_encoder_ema = deepcopy(text_encoder)
-            else:
-                # unet_ema = type(unet)().to(args.ema_decay_device, dtype=unet.dtype)
-                # unet_ema.load_state_dict(unet.state_dict())
-                #
-                # text_encoder_ema = type(unet)().to(args.ema_decay_device, dtype=text_encoder.dtype)
-                # text_encoder_ema.load_state_dict(text_encoder.state_dict())
-
-                unet_ema_first = deepcopy(unet)
-                text_encoder_ema_first = deepcopy(text_encoder)
-                unet_ema = unet_ema_first.to(args.ema_decay_device, dtype=unet.dtype)
-                text_encoder_ema = text_encoder_ema_first.to(args.ema_decay_device, dtype=text_encoder.dtype)
-                del unet_ema_first
-                del text_encoder_ema_first
+            with torch.no_grad():
+                if ema_device == device:
+                    unet_ema = deepcopy(unet)
+                    text_encoder_ema = deepcopy(text_encoder)
+                else:
+                    unet_ema_first = deepcopy(unet)
+                    text_encoder_ema_first = deepcopy(text_encoder)
+                    unet_ema = unet_ema_first.to(ema_device, dtype=unet.dtype)
+                    text_encoder_ema = text_encoder_ema_first.to(ema_device, dtype=text_encoder.dtype)
+                    del unet_ema_first
+                    del text_encoder_ema_first
+        else:
+            # Make sure the loaded EMA models use the correct device and dtype
+            unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
+            text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
 
     try:
@@ -921,7 +977,7 @@ def main(args):
                 models_info.append({"is_ema": False, "swap_required": False})
 
                 if (args.ema_decay_rate is not None) and args.ema_decay_sample_ema_model:
-                    models_info.append({"is_ema": True, "swap_required": args.ema_decay_device != device})
+                    models_info.append({"is_ema": True, "swap_required": ema_device != device})
 
                 for model_info in models_info:
@@ -935,9 +991,9 @@ def main(args):
 
                 if model_info["swap_required"]:
                     with torch.no_grad():
-                        unet_unloaded = unet.to(args.ema_decay_device)
+                        unet_unloaded = unet.to(ema_device)
                         del unet
-                        text_encoder_unloaded = text_encoder.to(args.ema_decay_device)
+                        text_encoder_unloaded = text_encoder.to(ema_device)
                         del text_encoder
 
                         current_unet = unet_ema.to(device)
@@ -966,9 +1022,9 @@ def main(args):
                         text_encoder = text_encoder_unloaded.to(device)
                         del text_encoder_unloaded
 
-                        unet_ema = current_unet.to(args.ema_decay_device)
+                        unet_ema = current_unet.to(ema_device)
                         del current_unet
-                        text_encoder_ema = current_text_encoder.to(args.ema_decay_device)
+                        text_encoder_ema = current_text_encoder.to(ema_device)
                         del current_text_encoder
 
                         gc.collect()
@@ -1047,17 +1103,17 @@ def main(args):
             if args.ema_decay_rate != None:
                 if ((global_step + 1) % args.ema_decay_interval) == 0:
-                    debug_start_time = time.time() # TODO: Remove time measurement when debug done
+                    # debug_start_time = time.time() # Measure time
 
                     if args.disable_unet_training != True:
-                        update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device)
+                        update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
 
                     if args.disable_textenc_training != True:
-                        update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device)
+                        update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
 
-                    debug_end_time = time.time()
-                    debug_elapsed_time = debug_end_time - debug_start_time
-                    print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.")
+                    # debug_end_time = time.time() # Measure time
+                    # debug_elapsed_time = debug_end_time - debug_start_time # Measure time
+                    # print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.") # Measure time
 
             loss_step = loss.detach().item()
@@ -1180,7 +1236,7 @@ if __name__ == "__main__":
     argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
     args, argv = argparser.parse_known_args()
 
-    load_train_json_from_file(args)
+    load_train_json_from_file(args, report_load=True)
 
     argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
     argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
@@ -1232,11 +1288,12 @@ if __name__ == "__main__":
     argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
     argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
     argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. Every interval, the EMA model is updated with weight (1 - ema_decay_rate) from the trained model and weight ema_decay_rate from the previous EMA model. Use values slightly less than 1. Using this parameter will enable the feature.")
-    argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide decay_target.")
+    argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from the equation: decay_rate^(total_steps/decay_interval) = decay_target. Using this parameter will enable the feature and override ema_decay_rate.")
     argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between every EMA decay update. The EMA model is updated whenever global_step is a multiple of ema_decay_interval.")
     argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device, one of: cpu, cuda. Using 'cpu' takes around 4 seconds per update vs a fraction of a second on 'cuda'. Using 'cuda' will reserve around 3.2GB of VRAM for the model; with 'cpu', system RAM is used.")
     argparser.add_argument("--ema_decay_sample_raw_training", action="store_true", default=False, help="Will show samples from the trained model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
     argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from the EMA model. Can be used with: --ema_decay_sample_raw_training")
+    argparser.add_argument("--ema_decay_resume_model", type=str, default=None, help="The EMA model checkpoint to resume from: either a local .ckpt file, a converted Diffusers format folder, a huggingface.co repo id such as stabilityai/stable-diffusion-2-1-ema-decay, or 'findlast' to use the most recent EMA checkpoint in the logdir")
     argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameter rebalances the loss by treating each timestep as an individual task. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
     argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
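
The per-parameter arithmetic behind update_ema is the standard exponential moving average described in the --ema_decay_rate help text. Below is a minimal sketch of that update, assuming both modules expose identical parameter names; the ema_update_sketch helper is illustrative only, since train.py's update_ema additionally deep-copies the trained model to ema_device before blending.

import torch

@torch.no_grad()
def ema_update_sketch(model: torch.nn.Module, ema_model: torch.nn.Module, decay: float):
    # Pair parameters by name, as update_ema does via named_parameters().
    params = dict(model.named_parameters())
    ema_params = dict(ema_model.named_parameters())
    for name, ema_param in ema_params.items():
        source = params[name].to(ema_param.device, dtype=ema_param.dtype)
        # new_ema = decay * old_ema + (1 - decay) * trained
        ema_param.mul_(decay).add_(source, alpha=1.0 - decay)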
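
The --ema_decay_target help text defines the rate implicitly through decay_rate^(total_steps/decay_interval) = decay_target, so the per-update rate follows by raising the target to the (decay_interval/total_steps)-th power. A short sketch, with a hypothetical helper name and illustrative numbers:

def ema_decay_rate_from_target(decay_target: float, total_steps: int, decay_interval: int) -> float:
    # decay_rate^(total_steps / decay_interval) = decay_target
    # => decay_rate = decay_target^(decay_interval / total_steps)
    return decay_target ** (decay_interval / total_steps)

# Example: a target of 0.5 over 10,000 steps at the default interval of 500
# gives 0.5 ** (500 / 10000) ≈ 0.9659 per update.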
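
Assuming the logdir layout this patch relies on, where EMA checkpoints are saved in folders ending in _ema, resuming both the trained and EMA models might look like the following hypothetical invocation ('findlast' for --resume_ckpt mirrors the existing setup_args behavior that this patch extends to --ema_decay_resume_model):

python train.py --config train.json --resume_ckpt findlast --ema_decay_resume_model findlast --ema_decay_target 0.5 --ema_decay_device cuda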