1. Added an argument ema_decay_resume_model to load EMA model - it's loaded alongside main model, instead of copying normal model. It's optional, without loaded EMA model, it will copy the regular model to me the first EMA model, just like before.
2. Fixed findlast option for regular models not to load EMA models by default. 3. findlast can be used to load EMA model too when used with ema_decay_resume_model. 4. Added ema_device variable to store the device in torch type. 5. Cleaned prints and comments.
This commit is contained in:
parent
cf4a082e11
commit
5b1760fff2
119
train.py
119
train.py
|
@ -187,7 +187,7 @@ def set_args_12gb(args):
|
|||
logging.info(" - Overiding resolution to max 512")
|
||||
args.resolution = 512
|
||||
|
||||
def find_last_checkpoint(logdir):
|
||||
def find_last_checkpoint(logdir, is_ema=False):
|
||||
"""
|
||||
Finds the last checkpoint in the logdir, recursively
|
||||
"""
|
||||
|
@ -197,6 +197,12 @@ def find_last_checkpoint(logdir):
|
|||
for root, dirs, files in os.walk(logdir):
|
||||
for file in files:
|
||||
if os.path.basename(file) == "model_index.json":
|
||||
|
||||
if is_ema and (not root.endswith("_ema")):
|
||||
continue
|
||||
elif (not is_ema) and root.endswith("_ema"):
|
||||
continue
|
||||
|
||||
curr_date = os.path.getmtime(os.path.join(root,file))
|
||||
|
||||
if last_date is None or curr_date > last_date:
|
||||
|
@ -229,6 +235,11 @@ def setup_args(args):
|
|||
# find the last checkpoint in the logdir
|
||||
args.resume_ckpt = find_last_checkpoint(args.logdir)
|
||||
|
||||
if (args.ema_decay_resume_model != None) and (args.ema_decay_resume_model == "findlast"):
|
||||
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last EMA decay checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
|
||||
|
||||
args.ema_decay_resume_model = find_last_checkpoint(args.logdir, is_ema=True)
|
||||
|
||||
if args.lowvram:
|
||||
set_args_12gb(args)
|
||||
|
||||
|
@ -357,20 +368,17 @@ def log_args(log_writer, args):
|
|||
arglog += f"{arg}={value}, "
|
||||
log_writer.add_text("config", arglog)
|
||||
|
||||
def update_ema(model, ema_model, decay, default_device):
|
||||
def update_ema(model, ema_model, decay, default_device, ema_device):
|
||||
|
||||
with torch.no_grad():
|
||||
original_model_on_proper_device = model
|
||||
need_to_delete_original = False
|
||||
if args.ema_decay_device != default_device:
|
||||
if ema_device != default_device:
|
||||
original_model_on_other_device = deepcopy(model)
|
||||
original_model_on_proper_device = original_model_on_other_device.to(args.ema_decay_device, dtype=model.dtype)
|
||||
original_model_on_proper_device = original_model_on_other_device.to(ema_device, dtype=model.dtype)
|
||||
del original_model_on_other_device
|
||||
need_to_delete_original = True
|
||||
|
||||
# original_model_on_proper_device = type(model)().to(args.ema_decay_device, dtype=model.dtype)
|
||||
# original_model_on_proper_device.load_state_dict(model.state_dict())
|
||||
|
||||
params = dict(original_model_on_proper_device.named_parameters())
|
||||
ema_params = dict(ema_model.named_parameters())
|
||||
|
||||
|
@ -418,9 +426,11 @@ def compute_snr(timesteps, noise_scheduler):
|
|||
snr[snr == 0] = minimal_value
|
||||
return snr
|
||||
|
||||
def load_train_json_from_file(args):
|
||||
def load_train_json_from_file(args, report_load = False):
|
||||
try:
|
||||
if report_load:
|
||||
print(f"Loading training config from {args.config}.")
|
||||
|
||||
with open(args.config, 'rt') as f:
|
||||
read_json = json.load(f)
|
||||
|
||||
|
@ -461,6 +471,13 @@ def main(args):
|
|||
if not os.path.exists(log_folder):
|
||||
os.makedirs(log_folder)
|
||||
|
||||
def release_memory(model_to_delete, original_device):
|
||||
del model_to_delete
|
||||
gc.collect()
|
||||
|
||||
if 'cuda' in original_device.type:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@torch.no_grad()
|
||||
def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
|
||||
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
|
||||
|
@ -477,7 +494,6 @@ def main(args):
|
|||
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
|
||||
nonlocal save_ckpt_dir
|
||||
nonlocal save_full_precision
|
||||
#nonlocal save_path
|
||||
nonlocal yaml_name
|
||||
|
||||
if save_ckpt_dir is not None:
|
||||
|
@ -555,6 +571,11 @@ def main(args):
|
|||
ed_optimizer.save(save_path)
|
||||
|
||||
|
||||
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
|
||||
ema_decay_model_loaded_from_file = False
|
||||
|
||||
if use_ema_dacay_training:
|
||||
ema_device = torch.device(args.ema_decay_device)
|
||||
|
||||
optimizer_state_path = None
|
||||
try:
|
||||
|
@ -580,6 +601,42 @@ def main(args):
|
|||
unet = pipe.unet
|
||||
del pipe
|
||||
|
||||
if use_ema_dacay_training and args.ema_decay_resume_model:
|
||||
print(f"Loading EMA model: {args.ema_decay_resume_model}")
|
||||
ema_decay_model_loaded_from_file=True
|
||||
hf_cache_path = get_hf_ckpt_cache_path(args.ema_decay_resume_model)
|
||||
|
||||
if os.path.exists(hf_cache_path) or os.path.exists(args.ema_decay_resume_model):
|
||||
ema_model_root_folder, ema_is_sd1attn, ema_yaml = convert_to_hf(args.resume_ckpt)
|
||||
text_encoder_ema = CLIPTextModel.from_pretrained(ema_model_root_folder, subfolder="text_encoder")
|
||||
unet_ema = UNet2DConditionModel.from_pretrained(ema_model_root_folder, subfolder="unet")
|
||||
|
||||
else:
|
||||
# try to download from HF using ema_decay_resume_model as a repo id
|
||||
ema_downloaded = try_download_model_from_hf(repo_id=args.ema_decay_resume_model)
|
||||
if ema_downloaded is None:
|
||||
raise ValueError(
|
||||
f"No local file/folder for ema_decay_resume_model {args.ema_decay_resume_model}, and no matching huggingface.co repo could be downloaded")
|
||||
ema_pipe, ema_model_root_folder, ema_is_sd1attn, ema_yaml = ema_downloaded
|
||||
text_encoder_ema = ema_pipe.text_encoder
|
||||
unet_ema = ema_pipe.unet
|
||||
del ema_pipe
|
||||
|
||||
# Make sure EMA model is on proper device, and memory released if moved
|
||||
unet_ema_current_device = next(unet_ema.parameters()).device
|
||||
if ema_device != unet_ema_current_device:
|
||||
unet_ema_on_wrong_device = unet_ema
|
||||
unet_ema = unet_ema.to(ema_device)
|
||||
release_memory(unet_ema_on_wrong_device, unet_ema_current_device)
|
||||
|
||||
# Make sure EMA model is on proper device, and memory released if moved
|
||||
text_encoder_ema_current_device = next(text_encoder_ema.parameters()).device
|
||||
if ema_device != text_encoder_ema_current_device:
|
||||
text_encoder_ema_on_wrong_device = text_encoder_ema
|
||||
text_encoder_ema = text_encoder_ema.to(ema_device)
|
||||
release_memory(text_encoder_ema_on_wrong_device, text_encoder_ema_current_device)
|
||||
|
||||
|
||||
if args.enable_zero_terminal_snr:
|
||||
# Use zero terminal SNR
|
||||
from utils.unet_utils import enforce_zero_terminal_snr
|
||||
|
@ -626,9 +683,10 @@ def main(args):
|
|||
text_encoder = text_encoder.to(device, dtype=torch.float32)
|
||||
|
||||
|
||||
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
|
||||
|
||||
|
||||
if use_ema_dacay_training:
|
||||
if not ema_decay_model_loaded_from_file:
|
||||
logging.info(f"EMA decay enabled, creating EMA model.")
|
||||
|
||||
with torch.no_grad():
|
||||
|
@ -636,18 +694,16 @@ def main(args):
|
|||
unet_ema = deepcopy(unet)
|
||||
text_encoder_ema = deepcopy(text_encoder)
|
||||
else:
|
||||
# unet_ema = type(unet)().to(args.ema_decay_device, dtype=unet.dtype)
|
||||
# unet_ema.load_state_dict(unet.state_dict())
|
||||
#
|
||||
# text_encoder_ema = type(unet)().to(args.ema_decay_device, dtype=text_encoder.dtype)
|
||||
# text_encoder_ema.load_state_dict(text_encoder.state_dict())
|
||||
|
||||
unet_ema_first = deepcopy(unet)
|
||||
text_encoder_ema_first = deepcopy(text_encoder)
|
||||
unet_ema = unet_ema_first.to(args.ema_decay_device, dtype=unet.dtype)
|
||||
text_encoder_ema = text_encoder_ema_first.to(args.ema_decay_device, dtype=text_encoder.dtype)
|
||||
unet_ema = unet_ema_first.to(ema_device, dtype=unet.dtype)
|
||||
text_encoder_ema = text_encoder_ema_first.to(ema_device, dtype=text_encoder.dtype)
|
||||
del unet_ema_first
|
||||
del text_encoder_ema_first
|
||||
else:
|
||||
# Make sure correct types are used for models
|
||||
unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
|
||||
text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
|
||||
|
||||
|
||||
try:
|
||||
|
@ -921,7 +977,7 @@ def main(args):
|
|||
models_info.append({"is_ema": False, "swap_required": False})
|
||||
|
||||
if (args.ema_decay_rate is not None) and args.ema_decay_sample_ema_model:
|
||||
models_info.append({"is_ema": True, "swap_required": args.ema_decay_device != device})
|
||||
models_info.append({"is_ema": True, "swap_required": ema_device != device})
|
||||
|
||||
for model_info in models_info:
|
||||
|
||||
|
@ -935,9 +991,9 @@ def main(args):
|
|||
|
||||
if model_info["swap_required"]:
|
||||
with torch.no_grad():
|
||||
unet_unloaded = unet.to(args.ema_decay_device)
|
||||
unet_unloaded = unet.to(ema_device)
|
||||
del unet
|
||||
text_encoder_unloaded = text_encoder.to(args.ema_decay_device)
|
||||
text_encoder_unloaded = text_encoder.to(ema_device)
|
||||
del text_encoder
|
||||
|
||||
current_unet = unet_ema.to(device)
|
||||
|
@ -966,9 +1022,9 @@ def main(args):
|
|||
text_encoder = text_encoder_unloaded.to(device)
|
||||
del text_encoder_unloaded
|
||||
|
||||
unet_ema = current_unet.to(args.ema_decay_device)
|
||||
unet_ema = current_unet.to(ema_device)
|
||||
del current_unet
|
||||
text_encoder_ema = current_text_encoder.to(args.ema_decay_device)
|
||||
text_encoder_ema = current_text_encoder.to(ema_device)
|
||||
del current_text_encoder
|
||||
|
||||
gc.collect()
|
||||
|
@ -1047,17 +1103,17 @@ def main(args):
|
|||
|
||||
if args.ema_decay_rate != None:
|
||||
if ((global_step + 1) % args.ema_decay_interval) == 0:
|
||||
debug_start_time = time.time() # TODO: Remove time measurement when debug done
|
||||
# debug_start_time = time.time() # Measure time
|
||||
|
||||
if args.disable_unet_training != True:
|
||||
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device)
|
||||
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
|
||||
|
||||
if args.disable_textenc_training != True:
|
||||
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device)
|
||||
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
|
||||
|
||||
debug_end_time = time.time()
|
||||
debug_elapsed_time = debug_end_time - debug_start_time
|
||||
print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.")
|
||||
# debug_end_time = time.time() # Measure time
|
||||
# debug_elapsed_time = debug_end_time - debug_start_time # Measure time
|
||||
# print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.") # Measure time
|
||||
|
||||
|
||||
loss_step = loss.detach().item()
|
||||
|
@ -1180,7 +1236,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
|
||||
args, argv = argparser.parse_known_args()
|
||||
|
||||
load_train_json_from_file(args)
|
||||
load_train_json_from_file(args, report_load=True)
|
||||
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
|
||||
|
@ -1232,11 +1288,12 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
|
||||
argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
|
||||
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_decay_rate) from training, and the ema_decay_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
|
||||
argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide decay_target.")
|
||||
argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide ema_decay_rate.")
|
||||
argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between every EMA decay update. EMA model will be update on every global_steps modulo decay_interval.")
|
||||
argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using 'cpu' is taking around 4 seconds per update vs fraction of a second on 'cuda'. Using 'cuda' will reserve around 3.2GB VRAM for a model, with 'cpu' RAM will be used.")
|
||||
argparser.add_argument("--ema_decay_sample_raw_training", action="store_true", default=False, help="Will show samples from trained model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
|
||||
argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. Can be used with: --ema_decay_sample_raw_training")
|
||||
argparser.add_argument("--ema_decay_resume_model", type=str, default=None, help="The EMA decay checkpoint to resume from for EMA decay, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1-ema-decay")
|
||||
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameteris the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
|
||||
argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue