Merge pull request #137 from tjennings/main
support for saving the optimizer state
This commit is contained in:
commit
f9958320f6
|
@ -13,3 +13,4 @@
|
||||||
/.vscode/**
|
/.vscode/**
|
||||||
.ssh_config
|
.ssh_config
|
||||||
*inference*.yaml
|
*inference*.yaml
|
||||||
|
.idea
|
||||||
|
|
79
train.py
79
train.py
|
@ -77,7 +77,7 @@ def convert_to_hf(ckpt_path):
|
||||||
hf_cache = get_hf_ckpt_cache_path(ckpt_path)
|
hf_cache = get_hf_ckpt_cache_path(ckpt_path)
|
||||||
from utils.analyze_unet import get_attn_yaml
|
from utils.analyze_unet import get_attn_yaml
|
||||||
|
|
||||||
if os.path.isfile(ckpt_path):
|
if os.path.isfile(ckpt_path):
|
||||||
if not os.path.exists(hf_cache):
|
if not os.path.exists(hf_cache):
|
||||||
os.makedirs(hf_cache)
|
os.makedirs(hf_cache)
|
||||||
logging.info(f"Converting {ckpt_path} to Diffusers format")
|
logging.info(f"Converting {ckpt_path} to Diffusers format")
|
||||||
|
@ -89,7 +89,7 @@ def convert_to_hf(ckpt_path):
|
||||||
exit()
|
exit()
|
||||||
else:
|
else:
|
||||||
logging.info(f"Found cached checkpoint at {hf_cache}")
|
logging.info(f"Found cached checkpoint at {hf_cache}")
|
||||||
|
|
||||||
is_sd1attn, yaml = get_attn_yaml(hf_cache)
|
is_sd1attn, yaml = get_attn_yaml(hf_cache)
|
||||||
return hf_cache, is_sd1attn, yaml
|
return hf_cache, is_sd1attn, yaml
|
||||||
elif os.path.isdir(hf_cache):
|
elif os.path.isdir(hf_cache):
|
||||||
|
@ -180,7 +180,7 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
|
||||||
|
|
||||||
def set_args_12gb(args):
|
def set_args_12gb(args):
|
||||||
logging.info(" Setting args to 12GB mode")
|
logging.info(" Setting args to 12GB mode")
|
||||||
if not args.gradient_checkpointing:
|
if not args.gradient_checkpointing:
|
||||||
logging.info(" - Overiding gradient checkpointing to True")
|
logging.info(" - Overiding gradient checkpointing to True")
|
||||||
args.gradient_checkpointing = True
|
args.gradient_checkpointing = True
|
||||||
if args.batch_size > 2:
|
if args.batch_size > 2:
|
||||||
|
@ -279,7 +279,7 @@ def setup_args(args):
|
||||||
|
|
||||||
logging.info(logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}"))
|
logging.info(logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}"))
|
||||||
|
|
||||||
args.aspects = aspects.get_aspect_buckets(args.resolution)
|
args.aspects = aspects.get_aspect_buckets(args.resolution)
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
@ -304,13 +304,13 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
||||||
scaler.set_growth_factor(factor)
|
scaler.set_growth_factor(factor)
|
||||||
scaler.set_backoff_factor(1/factor)
|
scaler.set_backoff_factor(1/factor)
|
||||||
scaler.set_growth_interval(100)
|
scaler.set_growth_interval(100)
|
||||||
|
|
||||||
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
|
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
|
||||||
for item in items:
|
for item in items:
|
||||||
if item.error is not None:
|
if item.error is not None:
|
||||||
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
|
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
|
||||||
logging.error(f" *** exception: {item.error}")
|
logging.error(f" *** exception: {item.error}")
|
||||||
|
|
||||||
undersized_items = [item for item in items if item.is_undersized]
|
undersized_items = [item for item in items if item.is_undersized]
|
||||||
|
|
||||||
if len(undersized_items) > 0:
|
if len(undersized_items) > 0:
|
||||||
|
@ -322,21 +322,21 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
|
||||||
for undersized_item in undersized_items:
|
for undersized_item in undersized_items:
|
||||||
message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
|
message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
|
||||||
undersized_images_file.write(message)
|
undersized_images_file.write(message)
|
||||||
|
|
||||||
def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
|
def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
|
||||||
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
|
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
|
||||||
logging.info(" Preloading images...")
|
logging.info(" Preloading images...")
|
||||||
|
|
||||||
resolved_items = resolver.resolve(args.data_root, args)
|
resolved_items = resolver.resolve(args.data_root, args)
|
||||||
report_image_train_item_problems(log_folder, resolved_items)
|
report_image_train_item_problems(log_folder, resolved_items)
|
||||||
image_paths = set(map(lambda item: item.pathname, resolved_items))
|
image_paths = set(map(lambda item: item.pathname, resolved_items))
|
||||||
|
|
||||||
# Remove erroneous items
|
# Remove erroneous items
|
||||||
image_train_items = [item for item in resolved_items if item.error is None]
|
image_train_items = [item for item in resolved_items if item.error is None]
|
||||||
print (f" * Found {len(image_paths)} files in '{args.data_root}'")
|
print (f" * Found {len(image_paths)} files in '{args.data_root}'")
|
||||||
|
|
||||||
return image_train_items
|
return image_train_items
|
||||||
|
|
||||||
def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int):
|
def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int):
|
||||||
if args.write_schedule:
|
if args.write_schedule:
|
||||||
with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
|
with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
|
||||||
|
@ -365,7 +365,7 @@ def log_args(log_writer, args):
|
||||||
def main(args):
|
def main(args):
|
||||||
"""
|
"""
|
||||||
Main entry point
|
Main entry point
|
||||||
"""
|
"""
|
||||||
log_time = setup_local_logger(args)
|
log_time = setup_local_logger(args)
|
||||||
args = setup_args(args)
|
args = setup_args(args)
|
||||||
|
|
||||||
|
@ -394,7 +394,7 @@ def main(args):
|
||||||
os.makedirs(log_folder)
|
os.makedirs(log_folder)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
|
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
|
||||||
"""
|
"""
|
||||||
Save the model to disk
|
Save the model to disk
|
||||||
"""
|
"""
|
||||||
|
@ -415,13 +415,13 @@ def main(args):
|
||||||
)
|
)
|
||||||
pipeline.save_pretrained(save_path)
|
pipeline.save_pretrained(save_path)
|
||||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
||||||
|
|
||||||
if save_ckpt_dir is not None:
|
if save_ckpt_dir is not None:
|
||||||
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
||||||
else:
|
else:
|
||||||
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
||||||
save_ckpt_dir = os.curdir
|
save_ckpt_dir = os.curdir
|
||||||
|
|
||||||
half = not save_full_precision
|
half = not save_full_precision
|
||||||
|
|
||||||
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
||||||
|
@ -432,10 +432,11 @@ def main(args):
|
||||||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||||
shutil.copyfile(yaml_name, yaml_save_path)
|
shutil.copyfile(yaml_name, yaml_save_path)
|
||||||
|
|
||||||
# optimizer_path = os.path.join(save_path, "optimizer.pt")
|
|
||||||
# if self.save_optimizer_flag:
|
if save_optimizer_flag:
|
||||||
# logging.info(f" Saving optimizer state to {save_path}")
|
optimizer_path = os.path.join(save_path, "optimizer.pt")
|
||||||
# self.save_optimizer(self.ctx.optimizer, optimizer_path)
|
logging.info(f" Saving optimizer state to {save_path}")
|
||||||
|
save_optimizer(optimizer, optimizer_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
@ -446,6 +447,10 @@ def main(args):
|
||||||
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
||||||
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
||||||
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
||||||
|
|
||||||
|
optimizer_state_path = os.path.join(args.resume_ckpt, "optimizer.pt")
|
||||||
|
if not os.path.exists(optimizer_state_path):
|
||||||
|
optimizer_state_path = None
|
||||||
else:
|
else:
|
||||||
# try to download from HF using resume_ckpt as a repo id
|
# try to download from HF using resume_ckpt as a repo id
|
||||||
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
|
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
|
||||||
|
@ -572,7 +577,7 @@ def main(args):
|
||||||
betas=(betas[0], betas[1]),
|
betas=(betas[0], betas[1]),
|
||||||
weight_decay=weight_decay,
|
weight_decay=weight_decay,
|
||||||
)
|
)
|
||||||
elif optimizer_name in ["adamw"]:
|
elif optimizer_name in ["adamw"]:
|
||||||
opt_class = torch.optim.AdamW
|
opt_class = torch.optim.AdamW
|
||||||
else:
|
else:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
@ -588,6 +593,10 @@ def main(args):
|
||||||
amsgrad=False,
|
amsgrad=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if optimizer_state_path is not None:
|
||||||
|
logging.info(f"Loading optimizer state from {optimizer_state_path}")
|
||||||
|
load_optimizer(optimizer, optimizer_state_path)
|
||||||
|
|
||||||
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
|
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
|
||||||
|
|
||||||
image_train_items = resolve_image_train_items(args, log_folder)
|
image_train_items = resolve_image_train_items(args, log_folder)
|
||||||
|
@ -618,7 +627,7 @@ def main(args):
|
||||||
rated_dataset=args.rated_dataset,
|
rated_dataset=args.rated_dataset,
|
||||||
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
|
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.benchmark = False
|
torch.cuda.benchmark = False
|
||||||
|
|
||||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||||
|
@ -634,7 +643,7 @@ def main(args):
|
||||||
num_warmup_steps=lr_warmup_steps,
|
num_warmup_steps=lr_warmup_steps,
|
||||||
num_training_steps=args.lr_decay_steps,
|
num_training_steps=args.lr_decay_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
log_args(log_writer, args)
|
log_args(log_writer, args)
|
||||||
|
|
||||||
sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
|
sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
|
||||||
|
@ -673,14 +682,14 @@ def main(args):
|
||||||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
|
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer)
|
||||||
exit(_SIGTERM_EXIT_CODE)
|
exit(_SIGTERM_EXIT_CODE)
|
||||||
else:
|
else:
|
||||||
# non-main threads (i.e. dataloader workers) should exit cleanly
|
# non-main threads (i.e. dataloader workers) should exit cleanly
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, sigterm_handler)
|
signal.signal(signal.SIGINT, sigterm_handler)
|
||||||
|
|
||||||
if not os.path.exists(f"{log_folder}/samples/"):
|
if not os.path.exists(f"{log_folder}/samples/"):
|
||||||
os.makedirs(f"{log_folder}/samples/")
|
os.makedirs(f"{log_folder}/samples/")
|
||||||
|
|
||||||
|
@ -693,7 +702,7 @@ def main(args):
|
||||||
train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
|
train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
|
||||||
|
|
||||||
unet.train() if not args.disable_unet_training else unet.eval()
|
unet.train() if not args.disable_unet_training else unet.eval()
|
||||||
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
|
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
|
||||||
|
|
||||||
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
|
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
|
||||||
logging.info(f" text_encoder device: {text_encoder.device}, precision: {text_encoder.dtype}, training: {text_encoder.training}")
|
logging.info(f" text_encoder device: {text_encoder.device}, precision: {text_encoder.dtype}, training: {text_encoder.training}")
|
||||||
|
@ -701,7 +710,7 @@ def main(args):
|
||||||
logging.info(f" scheduler: {noise_scheduler.__class__}")
|
logging.info(f" scheduler: {noise_scheduler.__class__}")
|
||||||
|
|
||||||
logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
|
logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
|
||||||
logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
|
logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
|
||||||
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
|
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
|
||||||
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
|
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
@ -738,13 +747,13 @@ def main(args):
|
||||||
del pixel_values
|
del pixel_values
|
||||||
latents = latents[0].sample() * 0.18215
|
latents = latents[0].sample() * 0.18215
|
||||||
|
|
||||||
if zero_frequency_noise_ratio > 0.0:
|
if zero_frequency_noise_ratio > 0.0:
|
||||||
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||||
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
|
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
|
||||||
noise = torch.randn_like(latents) + zero_frequency_noise
|
noise = torch.randn_like(latents) + zero_frequency_noise
|
||||||
else:
|
else:
|
||||||
noise = torch.randn_like(latents)
|
noise = torch.randn_like(latents)
|
||||||
|
|
||||||
bsz = latents.shape[0]
|
bsz = latents.shape[0]
|
||||||
|
|
||||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||||
|
@ -808,7 +817,7 @@ def main(args):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
|
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
|
||||||
|
|
||||||
for epoch in range(args.max_epochs):
|
for epoch in range(args.max_epochs):
|
||||||
loss_epoch = []
|
loss_epoch = []
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
|
@ -887,12 +896,12 @@ def main(args):
|
||||||
last_epoch_saved_time = time.time()
|
last_epoch_saved_time = time.time()
|
||||||
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
||||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||||
|
|
||||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
||||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||||
|
|
||||||
del batch
|
del batch
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
@ -915,14 +924,14 @@ def main(args):
|
||||||
|
|
||||||
if validator:
|
if validator:
|
||||||
validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target)
|
validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target)
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
# end of training
|
# end of training
|
||||||
|
|
||||||
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||||
|
|
||||||
total_elapsed_time = time.time() - training_start_time
|
total_elapsed_time = time.time() - training_start_time
|
||||||
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
||||||
|
@ -932,7 +941,7 @@ def main(args):
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||||
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
||||||
|
|
Loading…
Reference in New Issue