Victor Hall 2023-01-15 22:07:37 -05:00
parent 6ba710d6f1
commit ba25992140
3 changed files with 132 additions and 54 deletions


@ -21,6 +21,7 @@
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 30,
"notebook": false,
"project_name": "project_abc",
"resolution": 512,
"resume_ckpt": "sd_v1-5_vae",

train.py

@ -222,12 +222,15 @@ def setup_args(args):
Sets defaults for missing args (possible if missing from json config)
Forces some args to be set based on others for compatibility reasons
"""
if args.disable_unet_training and args.disable_textenc_training:
raise ValueError("Both unet and textenc are disabled, nothing to train")
if args.resume_ckpt == "findlast":
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
# find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir)
if args.ed1_mode and not args.disable_xformers:
if args.ed1_mode and args.mixed_precision == "fp32" and not args.disable_xformers:
args.disable_xformers = True
logging.info(" ED1 mode: Overiding disable_xformers to True")
@ -238,7 +241,7 @@ def setup_args(args):
args.shuffle_tags = False
args.clip_skip = max(min(4, args.clip_skip), 0)
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
args.ckpt_every_n_minutes = 20
@ -248,7 +251,7 @@ def setup_args(args):
if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
args.save_every_n_epochs = _VERY_LARGE_NUMBER
if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
@ -269,6 +272,9 @@ def setup_args(args):
if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
os.makedirs(args.save_ckpt_dir)
if args.mixed_precision != "fp32" and (args.clip_grad_norm is None or args.clip_grad_norm <= 0):
args.clip_grad_norm = 1.0
if args.rated_dataset:
args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
@ -286,9 +292,11 @@ def main(args):
if args.notebook:
from tqdm.notebook import tqdm
else:
from tqdm.auto import tqdm
from tqdm.auto import tqdm
logging.info(f" Seed: {args.seed}")
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
logging.info(f" Seed: {seed}")
set_seed(seed)
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
@ -441,7 +449,7 @@ def main(args):
hf_ckpt_path = convert_to_hf(args.resume_ckpt)
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet")
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not args.ed1_mode)
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
@ -468,22 +476,38 @@ def main(args):
default_lr = 2e-6
curr_lr = args.lr if args.lr is not None else default_lr
# vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
# unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
# text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
unet = unet.to(device, dtype=torch.float32)
text_encoder = text_encoder.to(device, dtype=torch.float32)
d_type = torch.float32
if args.mixed_precision == "fp16":
d_type = torch.float16
logging.info(" * Using fp16 *")
args.amp = True
elif args.mixed_precision == "bf16":
d_type = torch.bfloat16
logging.info(" * Using bf16 *")
args.amp = True
else:
logging.info(" * Using FP32 *")
vae = vae.to(device, dtype=torch.float16 if (args.amp and d_type == torch.float32) else d_type)
unet = unet.to(device, dtype=d_type)
text_encoder = text_encoder.to(device, dtype=d_type)
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters())
elif args.disable_unet_training:
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
params_to_train = itertools.chain(text_encoder.parameters())
else:
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
betas = (0.9, 0.999)
epsilon = 1e-8 if not args.amp else 1e-8
epsilon = 1e-8
if args.amp or args.mixed_precision == "fp16":
epsilon = 1e-8
weight_decay = 0.01
if args.useadam8bit:
import bitsandbytes as bnb
@ -502,6 +526,8 @@ def main(args):
amsgrad=False,
)
log_optimizer(optimizer, betas, epsilon)
train_batch = EveryDreamBatch(
data_root=args.data_root,
flip_p=args.flip_p,
@ -540,11 +566,8 @@ def main(args):
sample_prompts.append(line.strip())
if False: #args.wandb is not None and args.wandb: # not yet supported
log_writer = wandb.init(project="EveryDream2FineTunes",
name=args.project_name,
dir=log_folder,
)
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, )
else:
log_writer = SummaryWriter(log_dir=log_folder,
flush_secs=5,
@ -602,7 +625,6 @@ def main(args):
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
def collate_fn(batch):
"""
Collates batches
@ -632,7 +654,7 @@ def main(args):
collate_fn=collate_fn
)
unet.train()
unet.train() if not args.disable_unet_training else unet.eval()
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
@ -643,9 +665,20 @@ def main(args):
logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
#logging.info(f" {Fore.GREEN}total_batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{total_batch_size}")
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
if args.amp or d_type != torch.float32:
#scaler = torch.cuda.amp.GradScaler()
scaler = torch.cuda.amp.GradScaler(
enabled=False,
#enabled=True,
init_scale=2048.0,
growth_factor=1.5,
backoff_factor=0.707,
growth_interval=50,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
@ -661,20 +694,6 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
#loss = torch.tensor(0.0, device=device, dtype=torch.float32)
if args.amp:
#scaler = torch.cuda.amp.GradScaler()
scaler = torch.cuda.amp.GradScaler(
#enabled=False,
enabled=True,
init_scale=1024.0,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=50,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
loss_log_step = []
try:
@ -723,8 +742,8 @@ def main(args):
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
del noise, latents, cuda_caption
#with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
with autocast(enabled=args.amp or d_type != torch.float32):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
@ -732,15 +751,17 @@ def main(args):
del target, model_pred
if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
if args.amp:
scaler.scale(loss).backward()
else:
loss.backward()
if args.clip_grad_norm is not None:
if not args.disable_unet_training:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
if not args.disable_textenc_training:
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
if batch["runt_size"] > 0:
grad_scale = batch["runt_size"] / args.batch_size
with torch.no_grad(): # not required? just in case for now, needs more testing
@ -753,7 +774,7 @@ def main(args):
param.grad *= grad_scale
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
if args.amp:
if args.amp and d_type == torch.float32:
scaler.step(optimizer)
scaler.update()
else:
@ -779,6 +800,7 @@ def main(args):
loss_log_step = []
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
sum_img = sum(images_per_sec_log_step)
avg = sum_img / len(images_per_sec_log_step)
images_per_sec_log_step = []
@ -861,17 +883,25 @@ def update_old_args(t_args):
Update old args to new args to deal with json config loading and missing args for compatibility
"""
if not hasattr(t_args, "shuffle_tags"):
print(f" Config json is missing 'shuffle_tags'")
print(f" Config json is missing 'shuffle_tags' flag")
t_args.__dict__["shuffle_tags"] = False
if not hasattr(t_args, "save_full_precision"):
print(f" Config json is missing 'save_full_precision'")
print(f" Config json is missing 'save_full_precision' flag")
t_args.__dict__["save_full_precision"] = False
if not hasattr(t_args, "notebook"):
print(f" Config json is missing 'notebook'")
print(f" Config json is missing 'notebook' flag")
t_args.__dict__["notebook"] = False
if not hasattr(t_args, "disable_unet_training"):
print(f" Config json is missing 'disable_unet_training' flag")
t_args.__dict__["disable_unet_training"] = False
if not hasattr(t_args, "mixed_precision"):
print(f" Config json is missing 'mixed_precision' flag")
t_args.__dict__["mixed_precision"] = "fp32"
if __name__ == "__main__":
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
supported_precisions = ['fp16', 'fp32']
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
args, _ = argparser.parse_known_args()
@ -881,9 +911,11 @@ if __name__ == "__main__":
with open(args.config, 'rt') as f:
t_args = argparse.Namespace()
t_args.__dict__.update(json.load(f))
print(t_args.__dict__)
update_old_args(t_args) # update args to support older configs
print(t_args.__dict__)
args = argparser.parse_args(namespace=t_args)
print(f"mixed_precision: {args.mixed_precision}")
else:
print("No config file specified, using command line args")
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
@ -894,7 +926,8 @@ if __name__ == "__main__":
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
@ -909,6 +942,8 @@ if __name__ == "__main__":
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--mixed_precision", type=str, default='fp32', help="precision for the model training", choices=supported_precisions)
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
@ -916,6 +951,7 @@ if __name__ == "__main__":
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
@ -923,8 +959,6 @@ if __name__ == "__main__":
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")


@ -16,24 +16,67 @@ limitations under the License.
import logging
import os
import time
from colorama import Fore, Style
class LogWrapper(object):
from torch.utils.tensorboard import SummaryWriter
import wandb
class LogWrapper():
"""
singleton for logging
"""
def __init__(self, log_dir, project_name):
self.log_dir = log_dir
def __init__(self, args, use_wandb=False):
self.logdir = args.logdir
self.use_wandb = use_wandb
if use_wandb:
wandb.init(project=args.project_name, sync_tensorboard=True)
else:
self.log_writer = SummaryWriter(log_dir=args.logdir,
flush_secs=5,
comment="EveryDream2FineTunes",
)
start_time = time.strftime("%Y%m%d-%H%M")
self.log_file = os.path.join(log_dir, f"log-{project_name}-{start_time}.txt")
log_file = os.path.join(args.logdir, f"log-{args.project_name}-{start_time}.txt")
self.logger = logging.getLogger(__name__)
console = logging.StreamHandler()
self.logger.addHandler(console)
file = logging.FileHandler(self.log_file, mode="a", encoding=None, delay=False)
file = logging.FileHandler(log_file, mode="a", encoding=None, delay=False)
self.logger.addHandler(file)
def __call__(self):
return self.logger
def add_image(self):
"""
log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
else:
log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
"""
pass
def add_scalar(self, tag: str, scalar_value: float, global_step: int):
if self.use_wandb:
wandb.log({tag: scalar_value}, step=global_step)
else:
self.log_writer.add_scalar(tag, scalar_value, global_step)
def append_epoch_log(self, global_step: int, epoch_pbar, gpu, log_writer, **logs):
"""
logs VRAM usage for the epoch and updates the progress bar postfix
"""
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
self.add_scalar("performance/vram", gpu_used_mem, global_step)
epoch_mem_color = Style.RESET_ALL
if gpu_used_mem > 0.93 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTRED_EX
elif gpu_used_mem > 0.85 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTYELLOW_EX
elif gpu_used_mem > 0.7 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTGREEN_EX
elif gpu_used_mem < 0.5 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTBLUE_EX
if logs is not None:
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
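A hedged usage sketch for the LogWrapper above, assuming it is constructed from the parsed training args (field values are placeholders and the logs directory is assumed to exist; the GPU helper used by append_epoch_log is not shown):

import argparse

# Illustrative construction only; logdir and project_name are placeholders.
args = argparse.Namespace(logdir="logs", project_name="project_abc")
log = LogWrapper(args, use_wandb=False)  # falls back to a TensorBoard SummaryWriter
log.add_scalar("loss/log_step", 0.137, global_step=100)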