fix bug with amp
This commit is contained in:
parent
97a8c69451
commit
ca6cd6c4e0
17
train.py
17
train.py
|
@ -91,13 +91,13 @@ def convert_to_hf(ckpt_path):
|
|||
else:
|
||||
logging.info(f"Found cached checkpoint at {hf_cache}")
|
||||
|
||||
patch_unet(hf_cache, args.ed1_mode)
|
||||
patch_unet(hf_cache, args.ed1_mode, args.lowvram)
|
||||
return hf_cache
|
||||
elif os.path.isdir(hf_cache):
|
||||
patch_unet(hf_cache, args.ed1_mode)
|
||||
patch_unet(hf_cache, args.ed1_mode, args.lowvram)
|
||||
return hf_cache
|
||||
else:
|
||||
patch_unet(ckpt_path, args.ed1_mode)
|
||||
patch_unet(ckpt_path, args.ed1_mode, args.lowvram)
|
||||
return ckpt_path
|
||||
|
||||
def setup_local_logger(args):
|
||||
|
@ -436,7 +436,7 @@ def main(args):
|
|||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if args.ed1_mode:
|
||||
if args.ed1_mode and not args.lowvram:
|
||||
unet.set_attention_slice(4)
|
||||
|
||||
if not args.disable_xformers and is_xformers_available():
|
||||
|
@ -449,7 +449,7 @@ def main(args):
|
|||
else:
|
||||
logging.info("xformers not available or disabled")
|
||||
|
||||
default_lr = 3e-6
|
||||
default_lr = 2e-6
|
||||
curr_lr = args.lr if args.lr is not None else default_lr
|
||||
|
||||
# vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
|
||||
|
@ -669,10 +669,9 @@ def main(args):
|
|||
step_start_time = time.time()
|
||||
|
||||
with torch.no_grad():
|
||||
#with autocast():
|
||||
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
|
||||
#with autocast(enabled=args.amp):
|
||||
latents = vae.encode(pixel_values, return_dict=False)
|
||||
with autocast(enabled=args.amp):
|
||||
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
|
||||
latents = vae.encode(pixel_values, return_dict=False)
|
||||
del pixel_values
|
||||
latents = latents[0].sample() * 0.18215
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ import os
|
|||
import json
|
||||
import logging
|
||||
|
||||
def patch_unet(ckpt_path, force_sd1attn: bool = False):
|
||||
def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
|
||||
"""
|
||||
Patch the UNet to use updated attention heads for xformers support in FP32
|
||||
"""
|
||||
|
@ -27,7 +27,10 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False):
|
|||
|
||||
|
||||
if force_sd1attn:
|
||||
unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
|
||||
if low_vram:
|
||||
unet_cfg["attention_head_dim"] = [5, 8, 8, 8]
|
||||
else:
|
||||
unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
|
||||
else:
|
||||
unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
|
||||
|
||||
|
|
Loading…
Reference in New Issue