parent
1273624a3c
commit
debcdd2506
100
train.py
100
train.py
|
@ -22,9 +22,10 @@ import logging
|
|||
import time
|
||||
import gc
|
||||
import random
|
||||
import shutil
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from colorama import Fore, Style, Cursor
|
||||
|
@ -46,12 +47,10 @@ from accelerate.utils import set_seed
|
|||
import wandb
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import keyboard
|
||||
|
||||
from data.every_dream import EveryDreamBatch
|
||||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
from utils.gpu import GPU
|
||||
forstepTime = time.time()
|
||||
|
||||
|
||||
_SIGTERM_EXIT_CODE = 130
|
||||
_VERY_LARGE_NUMBER = 1e9
|
||||
|
@ -87,19 +86,19 @@ def convert_to_hf(ckpt_path):
|
|||
import utils.convert_original_stable_diffusion_to_diffusers as convert
|
||||
convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
|
||||
except:
|
||||
logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
|
||||
logging.info("Please manually convert the checkpoint to Diffusers format (one time setup), see readme.")
|
||||
exit()
|
||||
else:
|
||||
logging.info(f"Found cached checkpoint at {hf_cache}")
|
||||
|
||||
is_sd1attn = patch_unet(hf_cache)
|
||||
return hf_cache, is_sd1attn
|
||||
is_sd1attn, yaml = patch_unet(hf_cache)
|
||||
return hf_cache, is_sd1attn, yaml
|
||||
elif os.path.isdir(hf_cache):
|
||||
is_sd1attn = patch_unet(hf_cache)
|
||||
return hf_cache, is_sd1attn
|
||||
is_sd1attn, yaml = patch_unet(hf_cache)
|
||||
return hf_cache, is_sd1attn, yaml
|
||||
else:
|
||||
is_sd1attn = patch_unet(ckpt_path)
|
||||
return ckpt_path, is_sd1attn
|
||||
is_sd1attn, yaml = patch_unet(ckpt_path)
|
||||
return ckpt_path, is_sd1attn, yaml
|
||||
|
||||
def setup_local_logger(args):
|
||||
"""
|
||||
|
@ -174,7 +173,6 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
|
|||
|
||||
if logs is not None:
|
||||
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
|
||||
print(f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step} | Elapsed : {time.time() - forstepTime}s")
|
||||
|
||||
|
||||
def set_args_12gb(args):
|
||||
|
@ -276,6 +274,28 @@ def setup_args(args):
|
|||
|
||||
return args
|
||||
|
||||
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
||||
if global_step == 250 or (epoch >= 2 and step == 1):
|
||||
factor = 1.8
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(50)
|
||||
if global_step == 500 or (epoch >= 4 and step == 1):
|
||||
factor = 1.6
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(50)
|
||||
if global_step == 1000 or (epoch >= 8 and step == 1):
|
||||
factor = 1.3
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
if global_step == 3000 or (epoch >= 15 and step == 1):
|
||||
factor = 1.15
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
Main entry point
|
||||
|
@ -297,12 +317,12 @@ def main(args):
|
|||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
||||
logging.info(f"Logging to {log_folder}")
|
||||
|
||||
if not os.path.exists(log_folder):
|
||||
os.makedirs(log_folder)
|
||||
|
||||
@torch.no_grad()
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False):
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
|
||||
"""
|
||||
Save the model to disk
|
||||
"""
|
||||
|
@ -323,17 +343,24 @@ def main(args):
|
|||
)
|
||||
pipeline.save_pretrained(save_path)
|
||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
||||
|
||||
if save_ckpt_dir is not None:
|
||||
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
||||
else:
|
||||
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
||||
save_ckpt_dir = os.curdir
|
||||
|
||||
half = not save_full_precision
|
||||
|
||||
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
||||
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
|
||||
# optimizer_path = os.path.join(save_path, "optimizer.pt")
|
||||
|
||||
if yaml_name:
|
||||
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
|
||||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
|
||||
# optimizer_path = os.path.join(save_path, "optimizer.pt")
|
||||
# if self.save_optimizer_flag:
|
||||
# logging.info(f" Saving optimizer state to {save_path}")
|
||||
# self.save_optimizer(self.ctx.optimizer, optimizer_path)
|
||||
|
@ -439,7 +466,7 @@ def main(args):
|
|||
del images
|
||||
|
||||
try:
|
||||
hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt)
|
||||
hf_ckpt_path, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
|
||||
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn)
|
||||
|
@ -453,7 +480,8 @@ def main(args):
|
|||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if not args.disable_xformers and (args.amp and is_sd1attn) or (not is_sd1attn):
|
||||
if not args.disable_xformers:
|
||||
if (args.amp and is_sd1attn) or (not is_sd1attn):
|
||||
try:
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
logging.info("Enabled xformers")
|
||||
|
@ -487,7 +515,7 @@ def main(args):
|
|||
betas = (0.9, 0.999)
|
||||
epsilon = 1e-8
|
||||
if args.amp:
|
||||
epsilon = 1e-8
|
||||
epsilon = 2e-8
|
||||
|
||||
weight_decay = 0.01
|
||||
if args.useadam8bit:
|
||||
|
@ -564,7 +592,6 @@ def main(args):
|
|||
log_args(log_writer, args)
|
||||
|
||||
|
||||
|
||||
"""
|
||||
Train the model
|
||||
|
||||
|
@ -650,13 +677,12 @@ def main(args):
|
|||
|
||||
|
||||
#scaler = torch.cuda.amp.GradScaler()
|
||||
scaler = torch.cuda.amp.GradScaler(
|
||||
scaler = GradScaler(
|
||||
enabled=args.amp,
|
||||
#enabled=True,
|
||||
init_scale=2**17.5,
|
||||
growth_factor=1.8,
|
||||
backoff_factor=1.0/1.8,
|
||||
growth_interval=50,
|
||||
growth_factor=2,
|
||||
backoff_factor=1.0/2,
|
||||
growth_interval=25,
|
||||
)
|
||||
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
|
||||
|
||||
|
@ -678,6 +704,8 @@ def main(args):
|
|||
|
||||
loss_log_step = []
|
||||
|
||||
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
|
||||
|
||||
try:
|
||||
for epoch in range(args.max_epochs):
|
||||
loss_epoch = []
|
||||
|
@ -727,16 +755,13 @@ def main(args):
|
|||
with autocast(enabled=args.amp):
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
del timesteps, encoder_hidden_states, noisy_latents
|
||||
#del timesteps, encoder_hidden_states, noisy_latents
|
||||
#with autocast(enabled=args.amp):
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
del target, model_pred
|
||||
|
||||
#if args.amp:
|
||||
scaler.scale(loss).backward()
|
||||
#else:
|
||||
# loss.backward()
|
||||
|
||||
if args.clip_grad_norm is not None:
|
||||
if not args.disable_unet_training:
|
||||
|
@ -792,7 +817,7 @@ def main(args):
|
|||
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0):
|
||||
if (global_step + 1) % args.sample_steps == 0:
|
||||
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
|
@ -814,23 +839,16 @@ def main(args):
|
|||
last_epoch_saved_time = time.time()
|
||||
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
|
||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
|
||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
|
||||
del batch
|
||||
global_step += 1
|
||||
|
||||
if global_step == 500:
|
||||
scaler.set_growth_factor(1.4)
|
||||
scaler.set_backoff_factor(1/1.4)
|
||||
if global_step == 1000:
|
||||
scaler.set_growth_factor(1.2)
|
||||
scaler.set_backoff_factor(1/1.2)
|
||||
scaler.set_growth_interval(100)
|
||||
update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
|
||||
# end of step
|
||||
|
||||
steps_pbar.close()
|
||||
|
@ -850,7 +868,7 @@ def main(args):
|
|||
# end of training
|
||||
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
|
||||
total_elapsed_time = time.time() - training_start_time
|
||||
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
||||
|
@ -860,7 +878,7 @@ def main(args):
|
|||
except Exception as ex:
|
||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
raise ex
|
||||
|
||||
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
||||
|
|
Loading…
Reference in New Issue