save yaml with ckpt files for easier loading

parent af503f413c
commit 23faf05512

train.py: 81 lines changed
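
In brief: patch_unet now reads the scheduler config alongside the unet config and infers which Stable Diffusion inference yaml matches the checkpoint; convert_to_hf threads that name through to the trainer, and __save_model copies the yaml next to every exported .ckpt. The commit also factors the in-loop GradScaler tweaks out into an update_grad_scaler helper with a longer schedule, raises the AMP Adam epsilon to 2e-8, and starts the scaler at growth_factor=2 with a 25-step growth interval.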
@@ -22,9 +22,10 @@ import logging
 import time
 import gc
 import random
+import shutil
 
 import torch.nn.functional as F
-from torch.cuda.amp import autocast
+from torch.cuda.amp import autocast, GradScaler
 import torchvision.transforms as transforms
 
 from colorama import Fore, Style, Cursor
@@ -92,14 +93,14 @@ def convert_to_hf(ckpt_path):
         else:
             logging.info(f"Found cached checkpoint at {hf_cache}")
 
-        is_sd1attn = patch_unet(hf_cache)
-        return hf_cache, is_sd1attn
+        is_sd1attn, yaml = patch_unet(hf_cache)
+        return hf_cache, is_sd1attn, yaml
     elif os.path.isdir(hf_cache):
-        is_sd1attn = patch_unet(hf_cache)
-        return hf_cache, is_sd1attn
+        is_sd1attn, yaml = patch_unet(hf_cache)
+        return hf_cache, is_sd1attn, yaml
     else:
-        is_sd1attn = patch_unet(ckpt_path)
-        return ckpt_path, is_sd1attn
+        is_sd1attn, yaml = patch_unet(ckpt_path)
+        return ckpt_path, is_sd1attn, yaml
 
 def setup_local_logger(args):
     """
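
convert_to_hf now returns a three-tuple, with the extra element naming the SD inference yaml that patch_unet inferred from the cached diffusers config. A minimal sketch of the new contract (the checkpoint name here is a hypothetical example):

    # hypothetical call: the third element names the matching inference config
    hf_path, is_sd1attn, yaml = convert_to_hf("sd_v2-1_768.ckpt")
    # e.g. is_sd1attn == False and yaml == "v2-inference-v.yaml" for an SD 2.x
    # v-prediction checkpoint; see the patch_unet hunk below for the mapping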
@@ -275,6 +276,28 @@ def setup_args(args):
 
     return args
 
+def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
+    if global_step == 250 or (epoch >= 2 and step == 1):
+        factor = 1.8
+        scaler.set_growth_factor(factor)
+        scaler.set_backoff_factor(1/factor)
+        scaler.set_growth_interval(50)
+    if global_step == 500 or (epoch >= 4 and step == 1):
+        factor = 1.6
+        scaler.set_growth_factor(factor)
+        scaler.set_backoff_factor(1/factor)
+        scaler.set_growth_interval(50)
+    if global_step == 1000 or (epoch >= 8 and step == 1):
+        factor = 1.3
+        scaler.set_growth_factor(factor)
+        scaler.set_backoff_factor(1/factor)
+        scaler.set_growth_interval(100)
+    if global_step == 3000 or (epoch >= 15 and step == 1):
+        factor = 1.15
+        scaler.set_growth_factor(factor)
+        scaler.set_backoff_factor(1/factor)
+        scaler.set_growth_interval(100)
+
 def main(args):
     """
     Main entry point
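
Together with the new GradScaler defaults further down, update_grad_scaler implements a decaying loss-scale schedule: aggressive growth early (factor 2 initially, 1.8 at step 250), settling to gentle growth (1.15 past step 3000). The `epoch >= N and step == 1` clauses re-apply the setting at the start of each epoch, and since the if blocks run in order the last matching block wins, so repeated calls are harmless. A minimal sketch of where it slots into an AMP training loop; the model, optimizer, data, and max_epochs here are stand-ins, not code from this commit:

    import torch
    from torch.cuda.amp import autocast, GradScaler

    # stand-ins so the sketch runs on its own (assumes a CUDA device)
    model = torch.nn.Linear(4, 4).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6, eps=2e-8)
    dataloader = [torch.randn(8, 4, device="cuda") for _ in range(4)]
    max_epochs = 2

    scaler = GradScaler(enabled=True, init_scale=2**17.5,
                        growth_factor=2, backoff_factor=1.0/2, growth_interval=25)

    global_step = 0
    for epoch in range(max_epochs):
        for step, batch in enumerate(dataloader):
            optimizer.zero_grad(set_to_none=True)
            with autocast():                       # eligible ops run in fp16
                loss = model(batch).float().mean()
            scaler.scale(loss).backward()          # scale loss so fp16 grads don't underflow
            scaler.step(optimizer)                 # unscales grads; skips step on inf/nan
            scaler.update()                        # grow or back off the loss scale
            update_grad_scaler(scaler, global_step, epoch, step)  # relax schedule over time
            global_step += 1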
@@ -296,12 +319,12 @@ def main(args):
     torch.backends.cudnn.benchmark = True
 
     log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
     logging.info(f"Logging to {log_folder}")
 
     if not os.path.exists(log_folder):
         os.makedirs(log_folder)
 
     @torch.no_grad()
-    def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False):
+    def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
         """
         Save the model to disk
         """
@@ -322,6 +345,7 @@ def main(args):
         )
         pipeline.save_pretrained(save_path)
         sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
+
         if save_ckpt_dir is not None:
             sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
         else:
@@ -331,8 +355,13 @@ def main(args):
 
         logging.info(f" * Saving SD model to {sd_ckpt_full}")
         converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
-        # optimizer_path = os.path.join(save_path, "optimizer.pt")
+        if yaml_name:
+            yaml_save_path = f"{os.path.basename(save_path)}.yaml"
+            logging.info(f" * Saving yaml to {yaml_save_path}")
+            shutil.copyfile(yaml_name, yaml_save_path)
+
+        # optimizer_path = os.path.join(save_path, "optimizer.pt")
         # if self.save_optimizer_flag:
         #     logging.info(f"  Saving optimizer state to {save_path}")
         #     self.save_optimizer(self.ctx.optimizer, optimizer_path)
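
This is the core of the commit: whenever a .ckpt is exported, the matching inference yaml is copied so tools that expect a same-basename .yaml sidecar (the common convention for SD 2.x checkpoints) can pick the right config automatically. One caveat: sd_ckpt_path is joined onto a directory before use, while yaml_save_path is a bare basename, so the copy lands in the process working directory rather than next to the .ckpt. A variant that keeps the sidecar beside the final checkpoint might look like this (a sketch reusing sd_ckpt_full from the surrounding code, not the committed behavior):

    # sketch: derive the yaml path from the resolved .ckpt path so the
    # sidecar always lands in the same directory as the checkpoint
    yaml_save_path = os.path.splitext(sd_ckpt_full)[0] + ".yaml"
    logging.info(f" * Saving yaml to {yaml_save_path}")
    shutil.copyfile(yaml_name, yaml_save_path)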
@@ -438,13 +467,14 @@ def main(args):
             del images
 
     try:
-        hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt)
+        hf_ckpt_path, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
         text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
         vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
         unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn)
         sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
         noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
         tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
+        logging.info(f" Inferred yaml: {yaml}, attention head type: {'sd1' if is_sd1attn else 'sd2'}")
     except:
         logging.error(" * Failed to load checkpoint *")
@@ -486,7 +516,7 @@ def main(args):
     betas = (0.9, 0.999)
     epsilon = 1e-8
     if args.amp:
-        epsilon = 1e-8
+        epsilon = 2e-8
 
     weight_decay = 0.01
     if args.useadam8bit:
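
The only functional change here is Adam's epsilon under AMP. A plausible reading: with fp16 gradients the second-moment estimate can underflow to zero, and epsilon then caps the effective step size, so doubling it halves that worst case. A toy illustration with invented values:

    import math

    lr, m_hat, v_hat = 1e-6, 1e-4, 0.0     # second moment fully underflowed to zero
    for eps in (1e-8, 2e-8):
        update = lr * m_hat / (math.sqrt(v_hat) + eps)
        print(f"eps={eps:.0e} -> update={update:.1e}")
    # eps=1e-08 -> update=1.0e-02; eps=2e-08 -> update=5.0e-03
    # doubling epsilon halves the worst-case step when v_hat has underflowed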
@@ -562,7 +592,6 @@ def main(args):
 
     log_args(log_writer, args)
 
-
 
     """
     Train the model
@@ -649,13 +678,12 @@ def main(args):
 
     #scaler = torch.cuda.amp.GradScaler()
-    scaler = torch.cuda.amp.GradScaler(
+    scaler = GradScaler(
         enabled=args.amp,
         #enabled=True,
         init_scale=2**17.5,
-        growth_factor=1.8,
-        backoff_factor=1.0/1.8,
-        growth_interval=50,
+        growth_factor=2,
+        backoff_factor=1.0/2,
+        growth_interval=25,
     )
     logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
 
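
The scaler is now built from the GradScaler name imported at the top, with a higher starting point and a faster probe than before: init_scale=2**17.5 (about 185k, versus PyTorch's default 2**16) and doubling every 25 clean steps. PyTorch's documented behavior is that the scale multiplies by growth_factor after growth_interval consecutive finite steps and by backoff_factor whenever an inf/nan gradient appears (that optimizer step is skipped). A small standalone sketch of those dynamics, assuming a CUDA build so the scaler stays enabled:

    from torch.cuda.amp import GradScaler

    scaler = GradScaler(init_scale=2**17.5, growth_factor=2,
                        backoff_factor=1.0/2, growth_interval=25)
    print(scaler.get_scale())   # ~185363.8 while enabled (== 2**17.5)
    # after 25 consecutive steps with all-finite gradients: scale *= 2
    # on any step with inf/nan gradients: scale *= 0.5 and optimizer.step()
    # is skipped, so occasional overflows only cost one update

update_grad_scaler then walks these settings down (factor 2 to 1.8, 1.6, 1.3, 1.15; interval 25 to 50, then 100), so the probing gets gentler as training stabilizes.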
@@ -813,23 +841,16 @@ def main(args):
                     last_epoch_saved_time = time.time()
                     logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
                     save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
 
                 if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
                     logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
                     save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
 
                 del batch
                 global_step += 1
-
-                if global_step == 500:
-                    scaler.set_growth_factor(1.4)
-                    scaler.set_backoff_factor(1/1.4)
-                if global_step == 1000:
-                    scaler.set_growth_factor(1.2)
-                    scaler.set_backoff_factor(1/1.2)
-                    scaler.set_growth_interval(100)
+                update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
                 # end of step
 
             steps_pbar.close()
@@ -849,7 +870,7 @@ def main(args):
         # end of training
 
         save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
 
         total_elapsed_time = time.time() - training_start_time
         logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -859,7 +880,7 @@ def main(args):
     except Exception as ex:
         logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
         save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
         raise ex
 
     logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
@@ -25,9 +25,25 @@ def patch_unet(ckpt_path):
     with open(unet_cfg_path, "r") as f:
         unet_cfg = json.load(f)
 
+    scheduler_cfg_path = os.path.join(ckpt_path, "scheduler", "scheduler_config.json")
+    with open(scheduler_cfg_path, "r") as f:
+        scheduler_cfg = json.load(f)
+
     is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8]
     is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn
+
+    prediction_type = scheduler_cfg["prediction_type"]
 
     logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
 
-    return is_sd1attn
+    yaml = ''
+    if prediction_type in ["v_prediction", "v-prediction"] and not is_sd1attn:
+        yaml = "v2-inference-v.yaml"
+    elif prediction_type == "epsilon" and not is_sd1attn:
+        yaml = "v2-inference.yaml"
+    elif prediction_type == "epsilon" and is_sd1attn:
+        yaml = "v1-inference.yaml"
+    else:
+        raise ValueError(f"Unknown model format for: {prediction_type} and attention_head_dim {unet_cfg['attention_head_dim']}")
+
+    return is_sd1attn, yaml
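
The inferred yaml follows the convention of the original Stable Diffusion releases: v-prediction SD2 checkpoints pair with v2-inference-v.yaml, epsilon SD2 checkpoints with v2-inference.yaml, and epsilon SD1 checkpoints with v1-inference.yaml. A worked example of what patch_unet would return for a diffusers-style folder; the path and config values shown are typical of an SD 2.1-768 checkpoint, not read from this repo:

    # scheduler/scheduler_config.json contains: "prediction_type": "v_prediction"
    # unet/config.json contains: "attention_head_dim": [5, 10, 20, 20]
    is_sd1attn, yaml = patch_unet("ckpt_cache/sd_v2-1_768")
    # -> is_sd1attn == False (head dims are not the SD1 [8, 8, 8, 8] pattern)
    # -> yaml == "v2-inference-v.yaml"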