1. Added an argument ema_decay_resume_model to load EMA model - it's loaded alongside main model, instead of copying normal model. It's optional, without loaded EMA model, it will copy the regular model to me the first EMA model, just like before.

2. Fixed findlast option for regular models not to load EMA models by default.
3. findlast can be used to load EMA model too when used with ema_decay_resume_model.
4. Added ema_device variable to store the device in torch type.
5. Cleaned prints and comments.
This commit is contained in:
alexds9 2023-09-07 19:53:20 +03:00
parent cf4a082e11
commit 5b1760fff2
1 changed files with 99 additions and 42 deletions

141
train.py
View File

@ -187,7 +187,7 @@ def set_args_12gb(args):
logging.info(" - Overiding resolution to max 512")
args.resolution = 512
def find_last_checkpoint(logdir):
def find_last_checkpoint(logdir, is_ema=False):
"""
Finds the last checkpoint in the logdir, recursively
"""
@ -197,6 +197,12 @@ def find_last_checkpoint(logdir):
for root, dirs, files in os.walk(logdir):
for file in files:
if os.path.basename(file) == "model_index.json":
if is_ema and (not root.endswith("_ema")):
continue
elif (not is_ema) and root.endswith("_ema"):
continue
curr_date = os.path.getmtime(os.path.join(root,file))
if last_date is None or curr_date > last_date:
@ -229,6 +235,11 @@ def setup_args(args):
# find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir)
if (args.ema_decay_resume_model != None) and (args.ema_decay_resume_model == "findlast"):
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last EMA decay checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
args.ema_decay_resume_model = find_last_checkpoint(args.logdir, is_ema=True)
if args.lowvram:
set_args_12gb(args)
@ -357,20 +368,17 @@ def log_args(log_writer, args):
arglog += f"{arg}={value}, "
log_writer.add_text("config", arglog)
def update_ema(model, ema_model, decay, default_device):
def update_ema(model, ema_model, decay, default_device, ema_device):
with torch.no_grad():
original_model_on_proper_device = model
need_to_delete_original = False
if args.ema_decay_device != default_device:
if ema_device != default_device:
original_model_on_other_device = deepcopy(model)
original_model_on_proper_device = original_model_on_other_device.to(args.ema_decay_device, dtype=model.dtype)
original_model_on_proper_device = original_model_on_other_device.to(ema_device, dtype=model.dtype)
del original_model_on_other_device
need_to_delete_original = True
# original_model_on_proper_device = type(model)().to(args.ema_decay_device, dtype=model.dtype)
# original_model_on_proper_device.load_state_dict(model.state_dict())
params = dict(original_model_on_proper_device.named_parameters())
ema_params = dict(ema_model.named_parameters())
@ -418,9 +426,11 @@ def compute_snr(timesteps, noise_scheduler):
snr[snr == 0] = minimal_value
return snr
def load_train_json_from_file(args):
def load_train_json_from_file(args, report_load = False):
try:
print(f"Loading training config from {args.config}.")
if report_load:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
read_json = json.load(f)
@ -461,6 +471,13 @@ def main(args):
if not os.path.exists(log_folder):
os.makedirs(log_folder)
def release_memory(model_to_delete, original_device):
del model_to_delete
gc.collect()
if 'cuda' in original_device.type:
torch.cuda.empty_cache()
@torch.no_grad()
def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
@ -477,7 +494,6 @@ def main(args):
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
nonlocal save_ckpt_dir
nonlocal save_full_precision
#nonlocal save_path
nonlocal yaml_name
if save_ckpt_dir is not None:
@ -555,6 +571,11 @@ def main(args):
ed_optimizer.save(save_path)
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
ema_decay_model_loaded_from_file = False
if use_ema_dacay_training:
ema_device = torch.device(args.ema_decay_device)
optimizer_state_path = None
try:
@ -580,6 +601,42 @@ def main(args):
unet = pipe.unet
del pipe
if use_ema_dacay_training and args.ema_decay_resume_model:
print(f"Loading EMA model: {args.ema_decay_resume_model}")
ema_decay_model_loaded_from_file=True
hf_cache_path = get_hf_ckpt_cache_path(args.ema_decay_resume_model)
if os.path.exists(hf_cache_path) or os.path.exists(args.ema_decay_resume_model):
ema_model_root_folder, ema_is_sd1attn, ema_yaml = convert_to_hf(args.resume_ckpt)
text_encoder_ema = CLIPTextModel.from_pretrained(ema_model_root_folder, subfolder="text_encoder")
unet_ema = UNet2DConditionModel.from_pretrained(ema_model_root_folder, subfolder="unet")
else:
# try to download from HF using ema_decay_resume_model as a repo id
ema_downloaded = try_download_model_from_hf(repo_id=args.ema_decay_resume_model)
if ema_downloaded is None:
raise ValueError(
f"No local file/folder for ema_decay_resume_model {args.ema_decay_resume_model}, and no matching huggingface.co repo could be downloaded")
ema_pipe, ema_model_root_folder, ema_is_sd1attn, ema_yaml = ema_downloaded
text_encoder_ema = ema_pipe.text_encoder
unet_ema = ema_pipe.unet
del ema_pipe
# Make sure EMA model is on proper device, and memory released if moved
unet_ema_current_device = next(unet_ema.parameters()).device
if ema_device != unet_ema_current_device:
unet_ema_on_wrong_device = unet_ema
unet_ema = unet_ema.to(ema_device)
release_memory(unet_ema_on_wrong_device, unet_ema_current_device)
# Make sure EMA model is on proper device, and memory released if moved
text_encoder_ema_current_device = next(text_encoder_ema.parameters()).device
if ema_device != text_encoder_ema_current_device:
text_encoder_ema_on_wrong_device = text_encoder_ema
text_encoder_ema = text_encoder_ema.to(ema_device)
release_memory(text_encoder_ema_on_wrong_device, text_encoder_ema_current_device)
if args.enable_zero_terminal_snr:
# Use zero terminal SNR
from utils.unet_utils import enforce_zero_terminal_snr
@ -626,28 +683,27 @@ def main(args):
text_encoder = text_encoder.to(device, dtype=torch.float32)
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
if use_ema_dacay_training:
logging.info(f"EMA decay enabled, creating EMA model.")
if not ema_decay_model_loaded_from_file:
logging.info(f"EMA decay enabled, creating EMA model.")
with torch.no_grad():
if args.ema_decay_device == device:
unet_ema = deepcopy(unet)
text_encoder_ema = deepcopy(text_encoder)
else:
# unet_ema = type(unet)().to(args.ema_decay_device, dtype=unet.dtype)
# unet_ema.load_state_dict(unet.state_dict())
#
# text_encoder_ema = type(unet)().to(args.ema_decay_device, dtype=text_encoder.dtype)
# text_encoder_ema.load_state_dict(text_encoder.state_dict())
unet_ema_first = deepcopy(unet)
text_encoder_ema_first = deepcopy(text_encoder)
unet_ema = unet_ema_first.to(args.ema_decay_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema_first.to(args.ema_decay_device, dtype=text_encoder.dtype)
del unet_ema_first
del text_encoder_ema_first
with torch.no_grad():
if args.ema_decay_device == device:
unet_ema = deepcopy(unet)
text_encoder_ema = deepcopy(text_encoder)
else:
unet_ema_first = deepcopy(unet)
text_encoder_ema_first = deepcopy(text_encoder)
unet_ema = unet_ema_first.to(ema_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema_first.to(ema_device, dtype=text_encoder.dtype)
del unet_ema_first
del text_encoder_ema_first
else:
# Make sure correct types are used for models
unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
try:
@ -921,7 +977,7 @@ def main(args):
models_info.append({"is_ema": False, "swap_required": False})
if (args.ema_decay_rate is not None) and args.ema_decay_sample_ema_model:
models_info.append({"is_ema": True, "swap_required": args.ema_decay_device != device})
models_info.append({"is_ema": True, "swap_required": ema_device != device})
for model_info in models_info:
@ -935,9 +991,9 @@ def main(args):
if model_info["swap_required"]:
with torch.no_grad():
unet_unloaded = unet.to(args.ema_decay_device)
unet_unloaded = unet.to(ema_device)
del unet
text_encoder_unloaded = text_encoder.to(args.ema_decay_device)
text_encoder_unloaded = text_encoder.to(ema_device)
del text_encoder
current_unet = unet_ema.to(device)
@ -966,9 +1022,9 @@ def main(args):
text_encoder = text_encoder_unloaded.to(device)
del text_encoder_unloaded
unet_ema = current_unet.to(args.ema_decay_device)
unet_ema = current_unet.to(ema_device)
del current_unet
text_encoder_ema = current_text_encoder.to(args.ema_decay_device)
text_encoder_ema = current_text_encoder.to(ema_device)
del current_text_encoder
gc.collect()
@ -1047,17 +1103,17 @@ def main(args):
if args.ema_decay_rate != None:
if ((global_step + 1) % args.ema_decay_interval) == 0:
debug_start_time = time.time() # TODO: Remove time measurement when debug done
# debug_start_time = time.time() # Measure time
if args.disable_unet_training != True:
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device)
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
if args.disable_textenc_training != True:
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device)
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
debug_end_time = time.time()
debug_elapsed_time = debug_end_time - debug_start_time
print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.")
# debug_end_time = time.time() # Measure time
# debug_elapsed_time = debug_end_time - debug_start_time # Measure time
# print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.") # Measure time
loss_step = loss.detach().item()
@ -1180,7 +1236,7 @@ if __name__ == "__main__":
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
args, argv = argparser.parse_known_args()
load_train_json_from_file(args)
load_train_json_from_file(args, report_load=True)
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
@ -1232,11 +1288,12 @@ if __name__ == "__main__":
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_decay_rate) from training, and the ema_decay_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide decay_target.")
argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide ema_decay_rate.")
argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between every EMA decay update. EMA model will be update on every global_steps modulo decay_interval.")
argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using 'cpu' is taking around 4 seconds per update vs fraction of a second on 'cuda'. Using 'cuda' will reserve around 3.2GB VRAM for a model, with 'cpu' RAM will be used.")
argparser.add_argument("--ema_decay_sample_raw_training", action="store_true", default=False, help="Will show samples from trained model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. Can be used with: --ema_decay_sample_raw_training")
argparser.add_argument("--ema_decay_resume_model", type=str, default=None, help="The EMA decay checkpoint to resume from for EMA decay, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1-ema-decay")
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameteris the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")