parent 1273624a3c
commit debcdd2506

train.py
@@ -22,9 +22,10 @@ import logging
 import time
 import gc
 import random
+import shutil

 import torch.nn.functional as F
-from torch.cuda.amp import autocast
+from torch.cuda.amp import autocast, GradScaler
 import torchvision.transforms as transforms

 from colorama import Fore, Style, Cursor
@@ -46,12 +47,10 @@ from accelerate.utils import set_seed
 import wandb
 from torch.utils.tensorboard import SummaryWriter

-import keyboard

 from data.every_dream import EveryDreamBatch
 from utils.convert_diff_to_ckpt import convert as converter
 from utils.gpu import GPU
-forstepTime = time.time()

 _SIGTERM_EXIT_CODE = 130
 _VERY_LARGE_NUMBER = 1e9
@@ -87,19 +86,19 @@ def convert_to_hf(ckpt_path):
                 import utils.convert_original_stable_diffusion_to_diffusers as convert
                 convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
             except:
-                logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
+                logging.info("Please manually convert the checkpoint to Diffusers format (one time setup), see readme.")
                 exit()
         else:
             logging.info(f"Found cached checkpoint at {hf_cache}")

-        is_sd1attn = patch_unet(hf_cache)
-        return hf_cache, is_sd1attn
+        is_sd1attn, yaml = patch_unet(hf_cache)
+        return hf_cache, is_sd1attn, yaml
     elif os.path.isdir(hf_cache):
-        is_sd1attn = patch_unet(hf_cache)
-        return hf_cache, is_sd1attn
+        is_sd1attn, yaml = patch_unet(hf_cache)
+        return hf_cache, is_sd1attn, yaml
     else:
-        is_sd1attn = patch_unet(ckpt_path)
-        return ckpt_path, is_sd1attn
+        is_sd1attn, yaml = patch_unet(ckpt_path)
+        return ckpt_path, is_sd1attn, yaml

 def setup_local_logger(args):
     """
@@ -174,7 +173,6 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):

     if logs is not None:
         epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
-        print(f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step} | Elapsed : {time.time() - forstepTime}s")


 def set_args_12gb(args):
@@ -276,6 +274,28 @@ def setup_args(args):

     return args

+def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
+    if global_step == 250 or (epoch >= 2 and step == 1):
+        factor = 1.8
+        scaler.set_growth_factor(factor)
+        scaler.set_backoff_factor(1/factor)
+        scaler.set_growth_interval(50)
+    if global_step == 500 or (epoch >= 4 and step == 1):
+        factor = 1.6
+        scaler.set_growth_factor(factor)
+        scaler.set_backoff_factor(1/factor)
+        scaler.set_growth_interval(50)
+    if global_step == 1000 or (epoch >= 8 and step == 1):
+        factor = 1.3
+        scaler.set_growth_factor(factor)
+        scaler.set_backoff_factor(1/factor)
+        scaler.set_growth_interval(100)
+    if global_step == 3000 or (epoch >= 15 and step == 1):
+        factor = 1.15
+        scaler.set_growth_factor(factor)
+        scaler.set_backoff_factor(1/factor)
+        scaler.set_growth_interval(100)
+
 def main(args):
     """
     Main entry point
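The added update_grad_scaler progressively relaxes the loss-scale dynamics as training stabilizes: aggressive growth early, gentler growth and a longer interval later. A minimal driver sketch (hypothetical, not part of the commit) showing the thresholds firing, assuming update_grad_scaler is imported from train.py or copied alongside, and a CUDA device is present so the scaler stays enabled:

from torch.cuda.amp import GradScaler

# Same construction the trainer now uses (see the GradScaler hunk further down).
scaler = GradScaler(enabled=True, init_scale=2**17.5,
                    growth_factor=2, backoff_factor=1.0/2, growth_interval=25)

for gs in (250, 500, 1000, 3000):
    # epoch/step picked so only the global_step clause triggers in this sketch
    update_grad_scaler(scaler, global_step=gs, epoch=0, step=0)
    print(gs, scaler.get_growth_factor(), scaler.get_backoff_factor(), scaler.get_growth_interval())
# Expected: growth factor 1.8 -> 1.6 -> 1.3 -> 1.15, growth interval 50 -> 50 -> 100 -> 100.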
@@ -297,12 +317,12 @@ def main(args):
     torch.backends.cudnn.benchmark = True

     log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
-    logging.info(f"Logging to {log_folder}")
     if not os.path.exists(log_folder):
         os.makedirs(log_folder)

     @torch.no_grad()
-    def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False):
+    def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
         """
         Save the model to disk
         """
@@ -323,17 +343,24 @@ def main(args):
         )
         pipeline.save_pretrained(save_path)
         sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"

         if save_ckpt_dir is not None:
             sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
         else:
             sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
+            save_ckpt_dir = os.curdir

         half = not save_full_precision

         logging.info(f" * Saving SD model to {sd_ckpt_full}")
         converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
-        # optimizer_path = os.path.join(save_path, "optimizer.pt")
+
+        if yaml_name:
+            yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
+            logging.info(f" * Saving yaml to {yaml_save_path}")
+            shutil.copyfile(yaml_name, yaml_save_path)
+
+        # optimizer_path = os.path.join(save_path, "optimizer.pt")
         # if self.save_optimizer_flag:
         # logging.info(f" Saving optimizer state to {save_path}")
         # self.save_optimizer(self.ctx.optimizer, optimizer_path)
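The yaml sidecar saved above lands next to the .ckpt under the same stem, in save_ckpt_dir (or the current directory when none was given). A quick sketch of the resulting paths; the directory and project names below are hypothetical:

import os

# Mirrors the path expressions in __save_model above, with made-up inputs.
save_path = "logs/myproj_20230101-000000/ckpts/myproj-ep02-gs00500"
save_ckpt_dir = "ckpts_out"

sd_ckpt_full = os.path.join(save_ckpt_dir, f"{os.path.basename(save_path)}.ckpt")
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"

print(sd_ckpt_full)    # ckpts_out/myproj-ep02-gs00500.ckpt
print(yaml_save_path)  # ckpts_out/myproj-ep02-gs00500.yaml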
@@ -439,7 +466,7 @@ def main(args):
            del images

    try:
-        hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt)
+        hf_ckpt_path, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
        text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn)
@@ -453,7 +480,8 @@ def main(args):
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()

-    if not args.disable_xformers and (args.amp and is_sd1attn) or (not is_sd1attn):
+    if not args.disable_xformers:
+        if (args.amp and is_sd1attn) or (not is_sd1attn):
             try:
                 unet.enable_xformers_memory_efficient_attention()
                 logging.info("Enabled xformers")
@@ -487,7 +515,7 @@ def main(args):
     betas = (0.9, 0.999)
     epsilon = 1e-8
     if args.amp:
-        epsilon = 1e-8
+        epsilon = 2e-8

     weight_decay = 0.01
     if args.useadam8bit:
@@ -564,7 +592,6 @@ def main(args):
     log_args(log_writer, args)

-

     """
     Train the model

@@ -650,13 +677,12 @@ def main(args):


     #scaler = torch.cuda.amp.GradScaler()
-    scaler = torch.cuda.amp.GradScaler(
+    scaler = GradScaler(
         enabled=args.amp,
-        #enabled=True,
         init_scale=2**17.5,
-        growth_factor=1.8,
-        backoff_factor=1.0/1.8,
-        growth_interval=50,
+        growth_factor=2,
+        backoff_factor=1.0/2,
+        growth_interval=25,
     )
     logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")

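The reconfigured scaler plugs into the usual torch.cuda.amp recipe: autocast forward, scaled backward (shown in the loop hunk below), then step and update. A self-contained sketch of that recipe with a stand-in model and optimizer, purely illustrative and requiring a CUDA device:

import torch
from torch.cuda.amp import autocast, GradScaler

# Stand-in model/optimizer for illustration only.
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler(enabled=True, init_scale=2**17.5,
                    growth_factor=2, backoff_factor=1.0/2, growth_interval=25)

for _ in range(10):
    x = torch.randn(4, 8, device="cuda")
    with autocast(enabled=True):
        loss = model(x).float().pow(2).mean()
    scaler.scale(loss).backward()   # scaled backward pass
    scaler.step(optimizer)          # unscales grads, skips the step on inf/nan
    scaler.update()                 # adjusts the scale per growth/backoff settings
    optimizer.zero_grad(set_to_none=True)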
@@ -678,6 +704,8 @@ def main(args):

     loss_log_step = []

+    assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
+
     try:
         for epoch in range(args.max_epochs):
             loss_epoch = []
@@ -727,16 +755,13 @@ def main(args):
                 with autocast(enabled=args.amp):
                     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

-                del timesteps, encoder_hidden_states, noisy_latents
+                #del timesteps, encoder_hidden_states, noisy_latents
                 #with autocast(enabled=args.amp):
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                 del target, model_pred

-                #if args.amp:
                 scaler.scale(loss).backward()
-                #else:
-                # loss.backward()

                 if args.clip_grad_norm is not None:
                     if not args.disable_unet_training:
@@ -792,7 +817,7 @@ def main(args):
                 append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
                 torch.cuda.empty_cache()

-                if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0):
+                if (global_step + 1) % args.sample_steps == 0:
                     pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
                     pipe = pipe.to(device)

@@ -814,23 +839,16 @@ def main(args):
                     last_epoch_saved_time = time.time()
                     logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
                     save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)

                 if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
                     logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
                     save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)

                 del batch
                 global_step += 1
-                if global_step == 500:
-                    scaler.set_growth_factor(1.4)
-                    scaler.set_backoff_factor(1/1.4)
-                if global_step == 1000:
-                    scaler.set_growth_factor(1.2)
-                    scaler.set_backoff_factor(1/1.2)
-                    scaler.set_growth_interval(100)
+                update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
                 # end of step

             steps_pbar.close()
@@ -850,7 +868,7 @@ def main(args):
         # end of training

         save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)

         total_elapsed_time = time.time() - training_start_time
         logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -860,7 +878,7 @@ def main(args):
     except Exception as ex:
         logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
         save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
         raise ex

     logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")