clean line
This commit is contained in:
parent
8fdd151991
commit
906d282023
61
train.py
61
train.py
|
@ -91,12 +91,13 @@ def convert_to_hf(ckpt_path):
|
|||
else:
|
||||
logging.info(f"Found cached checkpoint at {hf_cache}")
|
||||
|
||||
patch_unet(hf_cache)
|
||||
patch_unet(hf_cache, args.ed1_mode)
|
||||
return hf_cache
|
||||
elif os.path.isdir(hf_cache):
|
||||
patch_unet(hf_cache)
|
||||
patch_unet(hf_cache, args.ed1_mode)
|
||||
return hf_cache
|
||||
else:
|
||||
patch_unet(ckpt_path, args.ed1_mode)
|
||||
return ckpt_path
|
||||
|
||||
def setup_local_logger(args):
|
||||
|
@ -205,6 +206,9 @@ def main(args):
|
|||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
if args.ed1_mode:
|
||||
args.disable_xformers = True
|
||||
|
||||
args.clip_skip = max(min(4, args.clip_skip), 0)
|
||||
|
||||
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
|
||||
|
@ -379,8 +383,8 @@ def main(args):
|
|||
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet")
|
||||
#unet.upcast_attention(True)
|
||||
scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
||||
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
|
||||
except:
|
||||
logging.ERROR(" * Failed to load checkpoint *")
|
||||
|
@ -417,18 +421,14 @@ def main(args):
|
|||
epsilon = 1e-8 if not args.amp else 1e-8
|
||||
weight_decay = 0.01
|
||||
if args.useadam8bit:
|
||||
logging.info(f"{Fore.CYAN} * Using AdamW 8-bit Optimizer *{Style.RESET_ALL}")
|
||||
import bitsandbytes as bnb
|
||||
optimizer = bnb.optim.AdamW8bit(
|
||||
itertools.chain(params_to_train),
|
||||
lr=curr_lr,
|
||||
betas=betas,
|
||||
eps=epsilon,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
opt_class = bnb.optim.AdamW8bit
|
||||
logging.info(f"{Fore.CYAN} * Using AdamW 8-bit Optimizer *{Style.RESET_ALL}")
|
||||
else:
|
||||
opt_class = torch.optim.AdamW
|
||||
logging.info(f"{Fore.CYAN} * Using AdamW standard Optimizer *{Style.RESET_ALL}")
|
||||
optimizer = torch.optim.AdamW(
|
||||
|
||||
optimizer = opt_class(
|
||||
itertools.chain(params_to_train),
|
||||
lr=curr_lr,
|
||||
betas=betas,
|
||||
|
@ -511,7 +511,7 @@ def main(args):
|
|||
"""
|
||||
global interrupted
|
||||
if not interrupted:
|
||||
interrupted=True
|
||||
interrupted=True
|
||||
global global_step
|
||||
#TODO: save model on ctrl-c
|
||||
interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
|
||||
|
@ -519,7 +519,8 @@ def main(args):
|
|||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
|
||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
|
||||
exit(_SIGTERM_EXIT_CODE)
|
||||
|
||||
signal.signal(signal.SIGINT, sigterm_handler)
|
||||
|
@ -577,7 +578,7 @@ def main(args):
|
|||
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
|
||||
logging.info(f" text_encoder device: {text_encoder.device}, precision: {text_encoder.dtype}, training: {text_encoder.training}")
|
||||
logging.info(f" vae device: {vae.device}, precision: {vae.dtype}, training: {vae.training}")
|
||||
logging.info(f" scheduler: {scheduler.__class__}")
|
||||
logging.info(f" scheduler: {noise_scheduler.__class__}")
|
||||
|
||||
logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
|
||||
logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
|
||||
|
@ -622,7 +623,7 @@ def main(args):
|
|||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
cuda_caption = batch["tokens"].to(text_encoder.device)
|
||||
|
@ -634,15 +635,15 @@ def main(args):
|
|||
encoder_hidden_states = encoder_hidden_states.hidden_states[-args.clip_skip]
|
||||
else:
|
||||
encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
||||
|
||||
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
if scheduler.config.prediction_type == "epsilon":
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
|
||||
target = scheduler.get_velocity(latents, noise, timesteps)
|
||||
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
del noise, latents, cuda_caption
|
||||
|
||||
with autocast(enabled=args.amp):
|
||||
|
@ -651,7 +652,7 @@ def main(args):
|
|||
del timesteps, encoder_hidden_states, noisy_latents
|
||||
#with autocast(enabled=args.amp):
|
||||
loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
|
||||
del target, model_pred
|
||||
|
||||
if batch["runt_size"] > 0:
|
||||
|
@ -677,6 +678,7 @@ def main(args):
|
|||
|
||||
lr_scheduler.step()
|
||||
|
||||
steps_pbar.set_postfix({"gs": global_step})
|
||||
steps_pbar.update(1)
|
||||
global_step += 1
|
||||
|
||||
|
@ -698,7 +700,7 @@ def main(args):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
if (global_step + 1) % args.sample_steps == 0:
|
||||
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, vae=vae)
|
||||
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
|
@ -719,12 +721,12 @@ def main(args):
|
|||
last_epoch_saved_time = time.time()
|
||||
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
|
||||
|
||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
|
||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
|
||||
|
||||
del loss, batch
|
||||
# end of step
|
||||
|
@ -741,7 +743,7 @@ def main(args):
|
|||
# end of training
|
||||
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
|
||||
|
||||
total_elapsed_time = time.time() - training_start_time
|
||||
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
||||
|
@ -751,7 +753,7 @@ def main(args):
|
|||
except Exception as ex:
|
||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
|
||||
raise ex
|
||||
|
||||
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
||||
|
@ -784,6 +786,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
|
||||
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
|
||||
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
|
||||
argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
|
||||
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
|
||||
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
|
||||
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
|
||||
|
|
Loading…
Reference in New Issue