fix bug with amp

This commit is contained in:
Victor Hall 2023-01-11 11:49:20 -05:00
parent 97a8c69451
commit ca6cd6c4e0
2 changed files with 13 additions and 11 deletions

View File

@ -91,13 +91,13 @@ def convert_to_hf(ckpt_path):
else:
logging.info(f"Found cached checkpoint at {hf_cache}")
patch_unet(hf_cache, args.ed1_mode)
patch_unet(hf_cache, args.ed1_mode, args.lowvram)
return hf_cache
elif os.path.isdir(hf_cache):
patch_unet(hf_cache, args.ed1_mode)
patch_unet(hf_cache, args.ed1_mode, args.lowvram)
return hf_cache
else:
patch_unet(ckpt_path, args.ed1_mode)
patch_unet(ckpt_path, args.ed1_mode, args.lowvram)
return ckpt_path
def setup_local_logger(args):
@ -436,7 +436,7 @@ def main(args):
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if args.ed1_mode:
if args.ed1_mode and not args.lowvram:
unet.set_attention_slice(4)
if not args.disable_xformers and is_xformers_available():
@ -449,7 +449,7 @@ def main(args):
else:
logging.info("xformers not available or disabled")
default_lr = 3e-6
default_lr = 2e-6
curr_lr = args.lr if args.lr is not None else default_lr
# vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
@ -669,10 +669,9 @@ def main(args):
step_start_time = time.time()
with torch.no_grad():
#with autocast():
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
#with autocast(enabled=args.amp):
latents = vae.encode(pixel_values, return_dict=False)
with autocast(enabled=args.amp):
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
latents = vae.encode(pixel_values, return_dict=False)
del pixel_values
latents = latents[0].sample() * 0.18215

View File

@ -17,7 +17,7 @@ import os
import json
import logging
def patch_unet(ckpt_path, force_sd1attn: bool = False):
def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
"""
Patch the UNet to use updated attention heads for xformers support in FP32
"""
@ -27,7 +27,10 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False):
if force_sd1attn:
unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
if low_vram:
unet_cfg["attention_head_dim"] = [5, 8, 8, 8]
else:
unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
else:
unet_cfg["attention_head_dim"] = [5, 10, 20, 20]