fix bug with amp

parent 97a8c69451
commit ca6cd6c4e0

train.py (17 changed lines)
@@ -91,13 +91,13 @@ def convert_to_hf(ckpt_path):
         else:
             logging.info(f"Found cached checkpoint at {hf_cache}")
 
-        patch_unet(hf_cache, args.ed1_mode)
+        patch_unet(hf_cache, args.ed1_mode, args.lowvram)
         return hf_cache
     elif os.path.isdir(hf_cache):
-        patch_unet(hf_cache, args.ed1_mode)
+        patch_unet(hf_cache, args.ed1_mode, args.lowvram)
         return hf_cache
     else:
-        patch_unet(ckpt_path, args.ed1_mode)
+        patch_unet(ckpt_path, args.ed1_mode, args.lowvram)
         return ckpt_path
 
 def setup_local_logger(args):
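For orientation, a hedged sketch of how a --lowvram flag like the one these call sites now forward could be declared (only the args.lowvram name comes from the diff; the help text and default are illustrative assumptions):

import argparse

# Hypothetical declaration of the flag threaded into patch_unet above; only
# the name "lowvram" is taken from the diff, the rest is illustrative.
parser = argparse.ArgumentParser()
parser.add_argument("--lowvram", action="store_true", default=False,
                    help="use the reduced-memory attention head layout when patching the UNet")

args = parser.parse_args(["--lowvram"])  # example invocation
print(args.lowvram)  # True; passed on as patch_unet(..., args.lowvram)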
@@ -436,7 +436,7 @@ def main(args):
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    if args.ed1_mode:
+    if args.ed1_mode and not args.lowvram:
         unet.set_attention_slice(4)
 
     if not args.disable_xformers and is_xformers_available():
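A minimal sketch of the new gate, assuming a diffusers UNet2DConditionModel (the helper name is an invention here): attention slicing with a fixed slice size of 4 is only requested on the ed1 path, and skipped entirely when the low-VRAM layout is selected.

from diffusers import UNet2DConditionModel

def maybe_enable_attention_slicing(unet: UNet2DConditionModel, ed1_mode: bool, lowvram: bool) -> None:
    # Mirrors the gate in the hunk above: slice the attention computation into
    # 4 chunks only in ed1 mode, and leave it untouched when --lowvram is set.
    if ed1_mode and not lowvram:
        unet.set_attention_slice(4)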
@@ -449,7 +449,7 @@ def main(args):
     else:
         logging.info("xformers not available or disabled")
 
-    default_lr = 3e-6
+    default_lr = 2e-6
     curr_lr = args.lr if args.lr is not None else default_lr
 
     # vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
@@ -669,10 +669,9 @@ def main(args):
             step_start_time = time.time()
 
             with torch.no_grad():
-                #with autocast():
-                pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
-                #with autocast(enabled=args.amp):
-                latents = vae.encode(pixel_values, return_dict=False)
+                with autocast(enabled=args.amp):
+                    pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
+                    latents = vae.encode(pixel_values, return_dict=False)
                 del pixel_values
                 latents = latents[0].sample() * 0.18215
 
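For context, a minimal sketch of what the autocast change above buys (function and argument names here are illustrative, not from the repository): the same encode path now serves both FP32 and mixed-precision runs, gated by a single flag instead of commented-out code.

import torch
from torch.cuda.amp import autocast

def encode_to_latents(vae, images: torch.Tensor, use_amp: bool) -> torch.Tensor:
    # Encode images to SD latents; autocast(enabled=...) toggles mixed precision
    # at call time, matching the with autocast(enabled=args.amp) block above.
    with torch.no_grad():
        with autocast(enabled=use_amp):
            latents = vae.encode(images, return_dict=False)
    return latents[0].sample() * 0.18215  # standard SD latent scaling factor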
@@ -17,7 +17,7 @@ import os
 import json
 import logging
 
-def patch_unet(ckpt_path, force_sd1attn: bool = False):
+def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
     """
     Patch the UNet to use updated attention heads for xformers support in FP32
     """
@@ -27,7 +27,10 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False):
 
 
     if force_sd1attn:
-        unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
+        if low_vram:
+            unet_cfg["attention_head_dim"] = [5, 8, 8, 8]
+        else:
+            unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
     else:
         unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
 
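A rough, self-contained reimplementation of the patched branch, assuming patch_unet rewrites attention_head_dim in the UNet's config.json inside a diffusers-layout checkpoint (that unet/config.json location is an assumption, not confirmed by this diff):

import json
import os

def patch_unet_sketch(ckpt_path: str, force_sd1attn: bool = False, low_vram: bool = False) -> None:
    # Assumed location of the UNet config in a diffusers-style checkpoint.
    cfg_path = os.path.join(ckpt_path, "unet", "config.json")
    with open(cfg_path) as f:
        unet_cfg = json.load(f)

    if force_sd1attn:
        # Same branch as the hunk above: the low-VRAM variant drops the first
        # block's attention head dim from 8 to 5.
        unet_cfg["attention_head_dim"] = [5, 8, 8, 8] if low_vram else [8, 8, 8, 8]
    else:
        unet_cfg["attention_head_dim"] = [5, 10, 20, 20]

    with open(cfg_path, "w") as f:
        json.dump(unet_cfg, f, indent=2)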