diff --git a/train.py b/train.py
index 88e4b51..eac4318 100644
--- a/train.py
+++ b/train.py
@@ -91,13 +91,13 @@ def convert_to_hf(ckpt_path):
         else:
             logging.info(f"Found cached checkpoint at {hf_cache}")
-        patch_unet(hf_cache, args.ed1_mode)
+        patch_unet(hf_cache, args.ed1_mode, args.lowvram)
         return hf_cache
     elif os.path.isdir(hf_cache):
-        patch_unet(hf_cache, args.ed1_mode)
+        patch_unet(hf_cache, args.ed1_mode, args.lowvram)
         return hf_cache
     else:
-        patch_unet(ckpt_path, args.ed1_mode)
+        patch_unet(ckpt_path, args.ed1_mode, args.lowvram)
         return ckpt_path
 
 
 def setup_local_logger(args):
@@ -436,7 +436,7 @@ def main(args):
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    if args.ed1_mode:
+    if args.ed1_mode and not args.lowvram:
         unet.set_attention_slice(4)
 
     if not args.disable_xformers and is_xformers_available():
@@ -449,7 +449,7 @@ def main(args):
     else:
         logging.info("xformers not available or disabled")
 
-    default_lr = 3e-6
+    default_lr = 2e-6
     curr_lr = args.lr if args.lr is not None else default_lr
 
     # vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
@@ -669,10 +669,9 @@ def main(args):
 
             step_start_time = time.time()
 
             with torch.no_grad():
-                #with autocast():
-                pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
-                #with autocast(enabled=args.amp):
-                latents = vae.encode(pixel_values, return_dict=False)
+                with autocast(enabled=args.amp):
+                    pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
+                    latents = vae.encode(pixel_values, return_dict=False)
                 del pixel_values
                 latents = latents[0].sample() * 0.18215
diff --git a/utils/patch_unet.py b/utils/patch_unet.py
index 7520e85..c9392b0 100644
--- a/utils/patch_unet.py
+++ b/utils/patch_unet.py
@@ -17,7 +17,7 @@ import os
 import json
 import logging
 
-def patch_unet(ckpt_path, force_sd1attn: bool = False):
+def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
     """
     Patch the UNet to use updated attention heads for xformers support in FP32
     """
@@ -27,7 +27,10 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False):
     if force_sd1attn:
-        unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
+        if low_vram:
+            unet_cfg["attention_head_dim"] = [5, 8, 8, 8]
+        else:
+            unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
     else:
         unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
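
For reference, a minimal sketch of the attention_head_dim branching this patch introduces in utils/patch_unet.py; the config load/save around it is not shown in the hunk, and the helper name below is only illustrative, not part of the patch:

    # Sketch only: mirrors the new force_sd1attn / low_vram branching above.
    def pick_attention_head_dim(force_sd1attn: bool, low_vram: bool) -> list:
        if force_sd1attn:
            # SD1-style attention heads; the low-VRAM variant shrinks the first block.
            return [5, 8, 8, 8] if low_vram else [8, 8, 8, 8]
        # Default SD2-style attention heads.
        return [5, 10, 20, 20]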