save yaml with ckpt files for easier loading

Victor Hall 2023-01-18 13:07:05 -05:00
parent af503f413c
commit 23faf05512
2 changed files with 68 additions and 31 deletions

View File

@@ -22,9 +22,10 @@ import logging
import time
import gc
import random
import shutil
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
from colorama import Fore, Style, Cursor
@@ -92,14 +93,14 @@ def convert_to_hf(ckpt_path):
else:
logging.info(f"Found cached checkpoint at {hf_cache}")
is_sd1attn = patch_unet(hf_cache)
return hf_cache, is_sd1attn
is_sd1attn, yaml = patch_unet(hf_cache)
return hf_cache, is_sd1attn, yaml
elif os.path.isdir(hf_cache):
is_sd1attn = patch_unet(hf_cache)
return hf_cache, is_sd1attn
is_sd1attn, yaml = patch_unet(hf_cache)
return hf_cache, is_sd1attn, yaml
else:
is_sd1attn = patch_unet(ckpt_path)
return ckpt_path, is_sd1attn
is_sd1attn, yaml = patch_unet(ckpt_path)
return ckpt_path, is_sd1attn, yaml
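convert_to_hf (and patch_unet below) now hands back the inferred yaml name alongside the cached path and the attention flag. A rough illustration of the new return values (the checkpoint name and the concrete values are assumed for illustration, not taken from this commit):

    # hypothetical SD1.x resume checkpoint
    hf_path, is_sd1attn, yaml = convert_to_hf("sd_v1-5.ckpt")
    # hf_path    -> cached diffusers folder (or the path itself if it is already a folder)
    # is_sd1attn -> True when attention_head_dim is 8 / [8, 8, 8, 8]
    # yaml       -> inference yaml inferred by patch_unet, e.g. "v1-inference.yaml"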
def setup_local_logger(args):
"""
@@ -275,6 +276,28 @@ def setup_args(args):
return args
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
if global_step == 250 or (epoch >= 2 and step == 1):
factor = 1.8
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(50)
if global_step == 500 or (epoch >= 4 and step == 1):
factor = 1.6
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(50)
if global_step == 1000 or (epoch >= 8 and step == 1):
factor = 1.3
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(100)
if global_step == 3000 or (epoch >= 15 and step == 1):
factor = 1.15
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(100)
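For context, a minimal sketch of where the new helper sits inside a scaled AMP training step; model, batch, optimizer and the loop counters here are placeholders rather than lines from this commit:

    with autocast(enabled=args.amp):
        loss = compute_loss(model, batch)            # placeholder forward/loss
    scaler.scale(loss).backward()                    # scale the loss so fp16 grads stay finite
    scaler.step(optimizer)                           # unscales grads, skips the step on inf/NaN
    scaler.update()                                  # grows or backs off the loss scale
    update_grad_scaler(scaler, global_step, epoch, step)  # relax growth/backoff as training stabilizes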
def main(args):
"""
Main entry point
@@ -296,12 +319,12 @@ def main(args):
torch.backends.cudnn.benchmark = True
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
logging.info(f"Logging to {log_folder}")
if not os.path.exists(log_folder):
os.makedirs(log_folder)
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False):
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
"""
Save the model to disk
"""
@@ -322,6 +345,7 @@ def main(args):
)
pipeline.save_pretrained(save_path)
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
@@ -331,8 +355,13 @@ def main(args):
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
# optimizer_path = os.path.join(save_path, "optimizer.pt")
if yaml:
yaml_save_path = f"{os.path.basename(save_path)}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml, yaml_save_path)
# optimizer_path = os.path.join(save_path, "optimizer.pt")
# if self.save_optimizer_flag:
# logging.info(f" Saving optimizer state to {save_path}")
# self.save_optimizer(self.ctx.optimizer, optimizer_path)
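Two observations on the new yaml copy (notes on the diff as written, not changes to it): the body checks the enclosing-scope yaml rather than the new yaml_name parameter, so yaml_name is effectively unused, and yaml_save_path is only the checkpoint basename plus ".yaml", so the copy lands in the current working directory rather than next to sd_ckpt_full. A hedged sketch of a variant that would place the yaml beside the exported .ckpt (yaml_dest is a name introduced here for illustration only):

    yaml_dest = os.path.join(os.path.dirname(sd_ckpt_full), f"{os.path.basename(save_path)}.yaml")
    shutil.copyfile(yaml, yaml_dest)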
@@ -438,13 +467,14 @@ def main(args):
del images
try:
hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt)
hf_ckpt_path, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn)
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
logging.info(f" Inferred yaml: {yaml}, attention head type: {'sd1' if is_sd1attn else 'sd2'}")
except:
logging.ERROR(" * Failed to load checkpoint *")
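Aside on the unchanged context line above: logging.ERROR is the integer level constant (40), not a callable, so this handler would itself raise TypeError if the load ever fails; the intended call is presumably:

    logging.error(" * Failed to load checkpoint *")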
@@ -486,7 +516,7 @@ def main(args):
betas = (0.9, 0.999)
epsilon = 1e-8
if args.amp:
epsilon = 1e-8
epsilon = 2e-8
weight_decay = 0.01
if args.useadam8bit:
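The bump from 1e-8 to 2e-8 when AMP is on presumably gives Adam's denominator a little more headroom in reduced precision. The diff only shows the hyperparameters; a hedged sketch of how they are typically passed to the optimizer in either branch (the bitsandbytes import, the lr argument and the exact parameter list are assumptions, not lines from this commit):

    import bitsandbytes as bnb
    params = unet.parameters()  # and/or text_encoder.parameters(), depending on config
    if args.useadam8bit:
        optimizer = bnb.optim.AdamW8bit(params, lr=args.lr, betas=betas, eps=epsilon, weight_decay=weight_decay)
    else:
        optimizer = torch.optim.AdamW(params, lr=args.lr, betas=betas, eps=epsilon, weight_decay=weight_decay)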
@@ -562,7 +592,6 @@ def main(args):
log_args(log_writer, args)
"""
Train the model
@@ -649,13 +678,12 @@ def main(args):
#scaler = torch.cuda.amp.GradScaler()
scaler = torch.cuda.amp.GradScaler(
scaler = GradScaler(
enabled=args.amp,
#enabled=True,
init_scale=2**17.5,
growth_factor=1.8,
backoff_factor=1.0/1.8,
growth_interval=50,
growth_factor=2,
backoff_factor=1.0/2,
growth_interval=25,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
@@ -813,23 +841,16 @@ def main(args):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
del batch
global_step += 1
if global_step == 500:
scaler.set_growth_factor(1.4)
scaler.set_backoff_factor(1/1.4)
if global_step == 1000:
scaler.set_growth_factor(1.2)
scaler.set_backoff_factor(1/1.2)
scaler.set_growth_interval(100)
update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
# end of step
steps_pbar.close()
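A small note on the replacement above: the conditional expression is used purely for its side effect; an equivalent plain statement (same behavior, style alternative only, not part of this commit) would be:

    if args.amp:
        update_grad_scaler(scaler, global_step, epoch, step)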
@@ -849,7 +870,7 @@ def main(args):
# end of training
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -859,7 +880,7 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
raise ex
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")

View File

@@ -25,9 +25,25 @@ def patch_unet(ckpt_path):
with open(unet_cfg_path, "r") as f:
unet_cfg = json.load(f)
scheduler_cfg_path = os.path.join(ckpt_path, "scheduler", "scheduler_config.json")
with open(scheduler_cfg_path, "r") as f:
scheduler_cfg = json.load(f)
is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8]
is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn
prediction_type = scheduler_cfg["prediction_type"]
logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
return is_sd1attn
yaml = ''
if prediction_type in ["v_prediction","v-prediction"] and not is_sd1attn:
yaml = "v2-inference-v.yaml"
elif prediction_type == "epsilon" and not is_sd1attn:
yaml = "v2-inference.yaml"
elif prediction_type == "epsilon" and is_sd1attn:
yaml = "v2-inference.yaml"
else:
raise ValueError(f"Unknown model format for: {prediction_type} and attention_head_dim {unet_cfg['attention_head_dim']}")
return is_sd1attn, yaml
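As a rough guide to what the new return pair looks like for common base models (the attention_head_dim and prediction_type values below come from the usual diffusers configs and are assumptions, not part of this diff):

    # SD 1.x:         attention_head_dim 8,               "epsilon"      -> (True,  "v1-inference.yaml")
    # SD 2.x base:    attention_head_dim [5, 10, 20, 20], "epsilon"      -> (False, "v2-inference.yaml")
    # SD 2.1 (768-v): attention_head_dim [5, 10, 20, 20], "v_prediction" -> (False, "v2-inference-v.yaml")
    is_sd1attn, yaml = patch_unet("ckpt_cache/stable-diffusion-2-1")  # hypothetical cached diffusers folder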