clean line

This commit is contained in:
Victor Hall 2023-01-06 16:33:33 -05:00
parent 8fdd151991
commit 906d282023
1 changed file with 32 additions and 29 deletions

View File

@@ -91,12 +91,13 @@ def convert_to_hf(ckpt_path):
else:
logging.info(f"Found cached checkpoint at {hf_cache}")
patch_unet(hf_cache)
patch_unet(hf_cache, args.ed1_mode)
return hf_cache
elif os.path.isdir(hf_cache):
patch_unet(hf_cache)
patch_unet(hf_cache, args.ed1_mode)
return hf_cache
else:
patch_unet(ckpt_path, args.ed1_mode)
return ckpt_path
def setup_local_logger(args):
@@ -205,6 +206,9 @@ def main(args):
torch.backends.cudnn.benchmark = False
if args.ed1_mode:
args.disable_xformers = True
args.clip_skip = max(min(4, args.clip_skip), 0)
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
@@ -379,8 +383,8 @@ def main(args):
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet")
#unet.upcast_attention(True)
scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
except:
logging.ERROR(" * Failed to load checkpoint *")
@@ -417,18 +421,14 @@ def main(args):
epsilon = 1e-8 if not args.amp else 1e-8
weight_decay = 0.01
if args.useadam8bit:
logging.info(f"{Fore.CYAN} * Using AdamW 8-bit Optimizer *{Style.RESET_ALL}")
import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(
itertools.chain(params_to_train),
lr=curr_lr,
betas=betas,
eps=epsilon,
weight_decay=weight_decay,
)
opt_class = bnb.optim.AdamW8bit
logging.info(f"{Fore.CYAN} * Using AdamW 8-bit Optimizer *{Style.RESET_ALL}")
else:
opt_class = torch.optim.AdamW
logging.info(f"{Fore.CYAN} * Using AdamW standard Optimizer *{Style.RESET_ALL}")
optimizer = torch.optim.AdamW(
optimizer = opt_class(
itertools.chain(params_to_train),
lr=curr_lr,
betas=betas,
@@ -511,7 +511,7 @@ def main(args):
"""
global interrupted
if not interrupted:
interrupted=True
interrupted=True
global global_step
#TODO: save model on ctrl-c
interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
@@ -519,7 +519,8 @@ def main(args):
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
exit(_SIGTERM_EXIT_CODE)
signal.signal(signal.SIGINT, sigterm_handler)
@@ -577,7 +578,7 @@ def main(args):
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
logging.info(f" text_encoder device: {text_encoder.device}, precision: {text_encoder.dtype}, training: {text_encoder.training}")
logging.info(f" vae device: {vae.device}, precision: {vae.dtype}, training: {vae.training}")
logging.info(f" scheduler: {scheduler.__class__}")
logging.info(f" scheduler: {noise_scheduler.__class__}")
logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
@@ -622,7 +623,7 @@ def main(args):
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
cuda_caption = batch["tokens"].to(text_encoder.device)
@@ -634,15 +635,15 @@ def main(args):
encoder_hidden_states = encoder_hidden_states.hidden_states[-args.clip_skip]
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
if scheduler.config.prediction_type == "epsilon":
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
target = scheduler.get_velocity(latents, noise, timesteps)
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
del noise, latents, cuda_caption
with autocast(enabled=args.amp):
@@ -651,7 +652,7 @@ def main(args):
del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
del target, model_pred
if batch["runt_size"] > 0:
@@ -677,6 +678,7 @@ def main(args):
lr_scheduler.step()
steps_pbar.set_postfix({"gs": global_step})
steps_pbar.update(1)
global_step += 1
@@ -698,7 +700,7 @@ def main(args):
torch.cuda.empty_cache()
if (global_step + 1) % args.sample_steps == 0:
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, vae=vae)
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
pipe = pipe.to(device)
with torch.no_grad():
@@ -719,12 +721,12 @@ def main(args):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
del loss, batch
# end of step
@@ -741,7 +743,7 @@ def main(args):
# end of training
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -751,7 +753,7 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
raise ex
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
@@ -784,6 +786,7 @@ if __name__ == "__main__":
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")