Update train.py

Revert
This commit is contained in:
nawnie 2023-01-22 00:02:35 -06:00 committed by GitHub
parent 1273624a3c
commit debcdd2506
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 65 additions and 47 deletions

100
train.py
View File

@ -22,9 +22,10 @@ import logging
import time import time
import gc import gc
import random import random
import shutil
import torch.nn.functional as F import torch.nn.functional as F
from torch.cuda.amp import autocast from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms import torchvision.transforms as transforms
from colorama import Fore, Style, Cursor from colorama import Fore, Style, Cursor
@ -46,12 +47,10 @@ from accelerate.utils import set_seed
import wandb import wandb
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import keyboard
from data.every_dream import EveryDreamBatch from data.every_dream import EveryDreamBatch
from utils.convert_diff_to_ckpt import convert as converter from utils.convert_diff_to_ckpt import convert as converter
from utils.gpu import GPU from utils.gpu import GPU
forstepTime = time.time()
_SIGTERM_EXIT_CODE = 130 _SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9 _VERY_LARGE_NUMBER = 1e9
@ -87,19 +86,19 @@ def convert_to_hf(ckpt_path):
import utils.convert_original_stable_diffusion_to_diffusers as convert import utils.convert_original_stable_diffusion_to_diffusers as convert
convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}") convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
except: except:
logging.info("Please manually convert the checkpoint to Diffusers format, see readme.") logging.info("Please manually convert the checkpoint to Diffusers format (one time setup), see readme.")
exit() exit()
else: else:
logging.info(f"Found cached checkpoint at {hf_cache}") logging.info(f"Found cached checkpoint at {hf_cache}")
is_sd1attn = patch_unet(hf_cache) is_sd1attn, yaml = patch_unet(hf_cache)
return hf_cache, is_sd1attn return hf_cache, is_sd1attn, yaml
elif os.path.isdir(hf_cache): elif os.path.isdir(hf_cache):
is_sd1attn = patch_unet(hf_cache) is_sd1attn, yaml = patch_unet(hf_cache)
return hf_cache, is_sd1attn return hf_cache, is_sd1attn, yaml
else: else:
is_sd1attn = patch_unet(ckpt_path) is_sd1attn, yaml = patch_unet(ckpt_path)
return ckpt_path, is_sd1attn return ckpt_path, is_sd1attn, yaml
def setup_local_logger(args): def setup_local_logger(args):
""" """
@ -174,7 +173,6 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
if logs is not None: if logs is not None:
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}") epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
print(f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step} | Elapsed : {time.time() - forstepTime}s")
def set_args_12gb(args): def set_args_12gb(args):
@ -276,6 +274,28 @@ def setup_args(args):
return args return args
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
if global_step == 250 or (epoch >= 2 and step == 1):
factor = 1.8
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(50)
if global_step == 500 or (epoch >= 4 and step == 1):
factor = 1.6
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(50)
if global_step == 1000 or (epoch >= 8 and step == 1):
factor = 1.3
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(100)
if global_step == 3000 or (epoch >= 15 and step == 1):
factor = 1.15
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(100)
def main(args): def main(args):
""" """
Main entry point Main entry point
@ -297,12 +317,12 @@ def main(args):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}") log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
logging.info(f"Logging to {log_folder}")
if not os.path.exists(log_folder): if not os.path.exists(log_folder):
os.makedirs(log_folder) os.makedirs(log_folder)
@torch.no_grad() @torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False): def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
""" """
Save the model to disk Save the model to disk
""" """
@ -323,17 +343,24 @@ def main(args):
) )
pipeline.save_pretrained(save_path) pipeline.save_pretrained(save_path)
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt" sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
if save_ckpt_dir is not None: if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path) sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else: else:
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path) sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
half = not save_full_precision half = not save_full_precision
logging.info(f" * Saving SD model to {sd_ckpt_full}") logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half) converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
# optimizer_path = os.path.join(save_path, "optimizer.pt")
if yaml_name:
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
# optimizer_path = os.path.join(save_path, "optimizer.pt")
# if self.save_optimizer_flag: # if self.save_optimizer_flag:
# logging.info(f" Saving optimizer state to {save_path}") # logging.info(f" Saving optimizer state to {save_path}")
# self.save_optimizer(self.ctx.optimizer, optimizer_path) # self.save_optimizer(self.ctx.optimizer, optimizer_path)
@ -439,7 +466,7 @@ def main(args):
del images del images
try: try:
hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt) hf_ckpt_path, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder") text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae") vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn) unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn)
@ -453,7 +480,8 @@ def main(args):
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable() text_encoder.gradient_checkpointing_enable()
if not args.disable_xformers and (args.amp and is_sd1attn) or (not is_sd1attn): if not args.disable_xformers:
if (args.amp and is_sd1attn) or (not is_sd1attn):
try: try:
unet.enable_xformers_memory_efficient_attention() unet.enable_xformers_memory_efficient_attention()
logging.info("Enabled xformers") logging.info("Enabled xformers")
@ -487,7 +515,7 @@ def main(args):
betas = (0.9, 0.999) betas = (0.9, 0.999)
epsilon = 1e-8 epsilon = 1e-8
if args.amp: if args.amp:
epsilon = 1e-8 epsilon = 2e-8
weight_decay = 0.01 weight_decay = 0.01
if args.useadam8bit: if args.useadam8bit:
@ -564,7 +592,6 @@ def main(args):
log_args(log_writer, args) log_args(log_writer, args)
""" """
Train the model Train the model
@ -650,13 +677,12 @@ def main(args):
#scaler = torch.cuda.amp.GradScaler() #scaler = torch.cuda.amp.GradScaler()
scaler = torch.cuda.amp.GradScaler( scaler = GradScaler(
enabled=args.amp, enabled=args.amp,
#enabled=True,
init_scale=2**17.5, init_scale=2**17.5,
growth_factor=1.8, growth_factor=2,
backoff_factor=1.0/1.8, backoff_factor=1.0/2,
growth_interval=50, growth_interval=25,
) )
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)") logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
@ -678,6 +704,8 @@ def main(args):
loss_log_step = [] loss_log_step = []
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
try: try:
for epoch in range(args.max_epochs): for epoch in range(args.max_epochs):
loss_epoch = [] loss_epoch = []
@ -727,16 +755,13 @@ def main(args):
with autocast(enabled=args.amp): with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
del timesteps, encoder_hidden_states, noisy_latents #del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp): #with autocast(enabled=args.amp):
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
del target, model_pred del target, model_pred
#if args.amp:
scaler.scale(loss).backward() scaler.scale(loss).backward()
#else:
# loss.backward()
if args.clip_grad_norm is not None: if args.clip_grad_norm is not None:
if not args.disable_unet_training: if not args.disable_unet_training:
@ -792,7 +817,7 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs) append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache() torch.cuda.empty_cache()
if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0): if (global_step + 1) % args.sample_steps == 0:
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae) pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
pipe = pipe.to(device) pipe = pipe.to(device)
@ -814,23 +839,16 @@ def main(args):
last_epoch_saved_time = time.time() last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}") logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1: if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}") logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
del batch del batch
global_step += 1 global_step += 1
update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
if global_step == 500:
scaler.set_growth_factor(1.4)
scaler.set_backoff_factor(1/1.4)
if global_step == 1000:
scaler.set_growth_factor(1.2)
scaler.set_backoff_factor(1/1.2)
scaler.set_growth_interval(100)
# end of step # end of step
steps_pbar.close() steps_pbar.close()
@ -850,7 +868,7 @@ def main(args):
# end of training # end of training
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}") save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
total_elapsed_time = time.time() - training_start_time total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}") logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@ -860,7 +878,7 @@ def main(args):
except Exception as ex: except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}") logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}") save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
raise ex raise ex
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}") logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")