Victor Hall 2023-01-15 22:07:37 -05:00
parent 6ba710d6f1
commit ba25992140
3 changed files with 132 additions and 54 deletions


@@ -21,6 +21,7 @@
     "lr_scheduler": "constant",
     "lr_warmup_steps": null,
     "max_epochs": 30,
+    "notebook": false,
     "project_name": "project_abc",
     "resolution": 512,
     "resume_ckpt": "sd_v1-5_vae",

train.py

@@ -222,12 +222,15 @@ def setup_args(args):
     Sets defaults for missing args (possible if missing from json config)
     Forces some args to be set based on others for compatibility reasons
     """
+    if args.disable_unet_training and args.disable_textenc_training:
+        raise ValueError("Both unet and textenc are disabled, nothing to train")
     if args.resume_ckpt == "findlast":
         logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
         # find the last checkpoint in the logdir
         args.resume_ckpt = find_last_checkpoint(args.logdir)
-    if args.ed1_mode and not args.disable_xformers:
+    if args.ed1_mode and args.mixed_precision == "fp32" and not args.disable_xformers:
         args.disable_xformers = True
         logging.info(" ED1 mode: Overiding disable_xformers to True")
@@ -238,7 +241,7 @@ def setup_args(args):
         args.shuffle_tags = False
     args.clip_skip = max(min(4, args.clip_skip), 0)
     if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
         logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
         args.ckpt_every_n_minutes = 20
@@ -248,7 +251,7 @@ def setup_args(args):
     if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
         args.save_every_n_epochs = _VERY_LARGE_NUMBER
     if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
         logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
         logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
@@ -269,6 +272,9 @@ def setup_args(args):
     if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
         os.makedirs(args.save_ckpt_dir)
+    if args.mixed_precision != "fp32" and (args.clip_grad_norm is None or args.clip_grad_norm <= 0):
+        args.clip_grad_norm = 1.0
     if args.rated_dataset:
         args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
@@ -286,9 +292,11 @@ def main(args):
     if args.notebook:
         from tqdm.notebook import tqdm
     else:
         from tqdm.auto import tqdm
-    logging.info(f" Seed: {args.seed}")
     seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
+    logging.info(f" Seed: {seed}")
     set_seed(seed)
     gpu = GPU()
     device = torch.device(f"cuda:{args.gpuid}")
@@ -441,7 +449,7 @@ def main(args):
         hf_ckpt_path = convert_to_hf(args.resume_ckpt)
         text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
         vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
-        unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet")
+        unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not args.ed1_mode)
         sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
         noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
         tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
@@ -468,22 +476,38 @@ def main(args):
     default_lr = 2e-6
     curr_lr = args.lr if args.lr is not None else default_lr
-    # vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    # unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    # text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    unet = unet.to(device, dtype=torch.float32)
-    text_encoder = text_encoder.to(device, dtype=torch.float32)
+    d_type = torch.float32
+    if args.mixed_precision == "fp16":
+        d_type = torch.float16
+        logging.info(" * Using fp16 *")
+        args.amp = True
+    elif args.mixed_precision == "bf16":
+        d_type = torch.bfloat16
+        logging.info(" * Using bf16 *")
+        args.amp = True
+    else:
+        logging.info(" * Using FP32 *")
+    vae = vae.to(device, dtype=torch.float16 if (args.amp and d_type == torch.float32) else d_type)
+    unet = unet.to(device, dtype=d_type)
+    text_encoder = text_encoder.to(device, dtype=d_type)
     if args.disable_textenc_training:
         logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
         params_to_train = itertools.chain(unet.parameters())
+    elif args.disable_unet_training:
+        logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
+        params_to_train = itertools.chain(text_encoder.parameters())
     else:
         logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
         params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
     betas = (0.9, 0.999)
-    epsilon = 1e-8 if not args.amp else 1e-8
+    epsilon = 1e-8
+    if args.amp or args.mix_precision == "fp16":
+        epsilon = 1e-8
     weight_decay = 0.01
     if args.useadam8bit:
         import bitsandbytes as bnb
@@ -502,6 +526,8 @@ def main(args):
         amsgrad=False,
     )
+    log_optimizer(optimizer, betas, epsilon)
     train_batch = EveryDreamBatch(
         data_root=args.data_root,
         flip_p=args.flip_p,
@@ -540,11 +566,8 @@ def main(args):
             sample_prompts.append(line.strip())
-    if False: #args.wandb is not None and args.wandb: # not yet supported
-        log_writer = wandb.init(project="EveryDream2FineTunes",
-                                name=args.project_name,
-                                dir=log_folder,
-                                )
+    if args.wandb is not None and args.wandb:
+        wandb.init(project=args.project_name, sync_tensorboard=True, )
     else:
         log_writer = SummaryWriter(log_dir=log_folder,
                                    flush_secs=5,
@@ -602,7 +625,6 @@ def main(args):
     logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
     def collate_fn(batch):
         """
         Collates batches
@@ -632,7 +654,7 @@ def main(args):
         collate_fn=collate_fn
     )
-    unet.train()
+    unet.train() if not args.disable_unet_training else unet.eval()
     text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
     logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
@@ -643,9 +665,20 @@ def main(args):
     logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
     logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
     logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
+    #logging.info(f" {Fore.GREEN}total_batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{total_batch_size}")
     logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
+    if args.amp or d_type != torch.float32:
+        #scaler = torch.cuda.amp.GradScaler()
+        scaler = torch.cuda.amp.GradScaler(
+            enabled=False,
+            #enabled=True,
+            init_scale=2048.0,
+            growth_factor=1.5,
+            backoff_factor=0.707,
+            growth_interval=50,
+        )
+        logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
     epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
     epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
@@ -661,20 +694,6 @@ def main(args):
     append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
-    #loss = torch.tensor(0.0, device=device, dtype=torch.float32)
-    if args.amp:
-        #scaler = torch.cuda.amp.GradScaler()
-        scaler = torch.cuda.amp.GradScaler(
-            #enabled=False,
-            enabled=True,
-            init_scale=1024.0,
-            growth_factor=2.0,
-            backoff_factor=0.5,
-            growth_interval=50,
-        )
-        logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
     loss_log_step = []
     try:
@@ -723,8 +742,8 @@ def main(args):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                 del noise, latents, cuda_caption
-                #with autocast(enabled=args.amp):
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                with autocast(enabled=args.amp or d_type != torch.float32):
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                 del timesteps, encoder_hidden_states, noisy_latents
                 #with autocast(enabled=args.amp):
@@ -732,15 +751,17 @@ def main(args):
                 del target, model_pred
-                if args.clip_grad_norm is not None:
-                    torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
-                    torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
                 if args.amp:
                     scaler.scale(loss).backward()
                 else:
                     loss.backward()
+                if args.clip_grad_norm is not None:
+                    if not args.disable_unet_training:
+                        torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
+                    if not args.disable_textenc_training:
+                        torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
                 if batch["runt_size"] > 0:
                     grad_scale = batch["runt_size"] / args.batch_size
                     with torch.no_grad(): # not required? just in case for now, needs more testing
@@ -753,7 +774,7 @@ def main(args):
                                 param.grad *= grad_scale
                 if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
-                    if args.amp:
+                    if args.amp and d_type == torch.float32:
                         scaler.step(optimizer)
                         scaler.update()
                     else:
@@ -779,6 +800,7 @@ def main(args):
                     loss_log_step = []
                     logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
                     log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
+                    log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
                     sum_img = sum(images_per_sec_log_step)
                     avg = sum_img / len(images_per_sec_log_step)
                     images_per_sec_log_step = []
@@ -861,17 +883,25 @@ def update_old_args(t_args):
     Update old args to new args to deal with json config loading and missing args for compatibility
     """
     if not hasattr(t_args, "shuffle_tags"):
-        print(f" Config json is missing 'shuffle_tags'")
+        print(f" Config json is missing 'shuffle_tags' flag")
         t_args.__dict__["shuffle_tags"] = False
     if not hasattr(t_args, "save_full_precision"):
-        print(f" Config json is missing 'save_full_precision'")
+        print(f" Config json is missing 'save_full_precision' flag")
        t_args.__dict__["save_full_precision"] = False
     if not hasattr(t_args, "notebook"):
-        print(f" Config json is missing 'notebook'")
+        print(f" Config json is missing 'notebook' flag")
         t_args.__dict__["notebook"] = False
+    if not hasattr(t_args, "disable_unet_training"):
+        print(f" Config json is missing 'disable_unet_training' flag")
+        t_args.__dict__["disable_unet_training"] = False
+    if not hasattr(t_args, "mixed_precision"):
+        print(f" Config json is missing 'mixed_precision' flag")
+        t_args.__dict__["mixed_precision"] = "fp32"
 if __name__ == "__main__":
     supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
+    supported_precisions = ['fp16', 'fp32']
     argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
     argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
     args, _ = argparser.parse_known_args()
@ -881,9 +911,11 @@ if __name__ == "__main__":
with open(args.config, 'rt') as f: with open(args.config, 'rt') as f:
t_args = argparse.Namespace() t_args = argparse.Namespace()
t_args.__dict__.update(json.load(f)) t_args.__dict__.update(json.load(f))
print(t_args.__dict__)
update_old_args(t_args) # update args to support older configs update_old_args(t_args) # update args to support older configs
print(t_args.__dict__) print(t_args.__dict__)
args = argparser.parse_args(namespace=t_args) args = argparser.parse_args(namespace=t_args)
print(f"mixed_precision: {args.mixed_precision}")
else: else:
print("No config file specified, using command line args") print("No config file specified, using command line args")
argparser = argparse.ArgumentParser(description="EveryDream2 Training options") argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
@ -894,7 +926,8 @@ if __name__ == "__main__":
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4]) argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)") argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are") argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED") argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)") argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!") argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)") argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
@ -909,6 +942,8 @@ if __name__ == "__main__":
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"]) argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant") argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for") argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--mixed_precision", type=str, default='fp32', help="precision for the model training", choices=supported_precisions)
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'") argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions) argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt") argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
@ -916,6 +951,7 @@ if __name__ == "__main__":
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)") argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)") argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)") argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later") argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)") argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random") argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
@ -923,8 +959,6 @@ if __name__ == "__main__":
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!") argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY") argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)") argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs") argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)") argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")


@@ -16,24 +16,67 @@ limitations under the License.
 import logging
 import os
 import time
-class LogWrapper(object):
+from colorama import Fore, Style
+from tensorboard import SummaryWriter
+import wandb
+class LogWrapper():
     """
     singleton for logging
     """
-    def __init__(self, log_dir, project_name):
-        self.log_dir = log_dir
+    def __init__(self, args, wandb=False):
+        self.logdir = args.logdir
+        self.wandb = wandb
+        if wandb:
+            wandb.init(project=args.project_name, sync_tensorboard=True)
+        else:
+            self.log_writer = SummaryWriter(log_dir=args.logdir,
+                                            flush_secs=5,
+                                            comment="EveryDream2FineTunes",
+                                           )
         start_time = time.strftime("%Y%m%d-%H%M")
-        self.log_file = os.path.join(log_dir, f"log-{project_name}-{start_time}.txt")
+        log_file = os.path.join(args.logdir, f"log-{args.project_name}-{start_time}.txt")
         self.logger = logging.getLogger(__name__)
         console = logging.StreamHandler()
         self.logger.addHandler(console)
-        file = logging.FileHandler(self.log_file, mode="a", encoding=None, delay=False)
+        file = logging.FileHandler(log_file, mode="a", encoding=None, delay=False)
         self.logger.addHandler(file)
-    def __call__(self):
-        return self.logger
+    def add_image():
+        """
+        log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
+        else:
+            log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
+        """
+        pass
+    def add_scalar(self, tag: str, img_tensor: float, global_step: int):
+        if self.wandb:
+            wandb.log({tag: img_tensor}, step=global_step)
+        else:
+            self.log_writer.add_image(tag, img_tensor, global_step)
+    def append_epoch_log(self, global_step: int, epoch_pbar, gpu, log_writer, **logs):
+        """
+        updates the vram usage for the epoch
+        """
+        gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
+        self.add_scalar("performance/vram", gpu_used_mem, global_step)
+        epoch_mem_color = Style.RESET_ALL
+        if gpu_used_mem > 0.93 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTRED_EX
+        elif gpu_used_mem > 0.85 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTYELLOW_EX
+        elif gpu_used_mem > 0.7 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTGREEN_EX
+        elif gpu_used_mem < 0.5 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTBLUE_EX
+        if logs is not None:
+            epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
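The reworked LogWrapper above is still a work in progress (add_image is a stub and the class depends on the project's args and gpu objects), but the idea it sketches is a single front-end that routes scalars either to tensorboard or to wandb. Below is a minimal, hypothetical stand-in for that dispatch pattern, using only the standard torch.utils.tensorboard and wandb APIs; MiniLogger and its arguments are illustrative names, not the project's class.

import time

from torch.utils.tensorboard import SummaryWriter


class MiniLogger:
    """Route scalar metrics to wandb when requested, otherwise to tensorboard."""

    def __init__(self, logdir: str, project_name: str, use_wandb: bool = False):
        self.use_wandb = use_wandb
        if use_wandb:
            import wandb                      # optional dependency, imported only when needed
            self._wandb = wandb
            wandb.init(project=project_name, sync_tensorboard=True)
        else:
            self.log_writer = SummaryWriter(log_dir=logdir, flush_secs=5)

    def add_scalar(self, tag: str, value: float, global_step: int):
        if self.use_wandb:
            self._wandb.log({tag: value}, step=global_step)
        else:
            self.log_writer.add_scalar(tag, value, global_step)


if __name__ == "__main__":
    # Demo usage: log a few fake VRAM readings to a timestamped tensorboard run.
    logger = MiniLogger(logdir=f"logs/demo-{time.strftime('%Y%m%d-%H%M')}",
                        project_name="EveryDream2FineTunes")
    for step in range(3):
        logger.add_scalar("performance/vram", 1024 + step, step)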