merge
This commit is contained in:
parent
6ba710d6f1
commit
ba25992140
|
@ -21,6 +21,7 @@
|
||||||
"lr_scheduler": "constant",
|
"lr_scheduler": "constant",
|
||||||
"lr_warmup_steps": null,
|
"lr_warmup_steps": null,
|
||||||
"max_epochs": 30,
|
"max_epochs": 30,
|
||||||
|
"notebook": false,
|
||||||
"project_name": "project_abc",
|
"project_name": "project_abc",
|
||||||
"resolution": 512,
|
"resolution": 512,
|
||||||
"resume_ckpt": "sd_v1-5_vae",
|
"resume_ckpt": "sd_v1-5_vae",
|
||||||
|
|
128
train.py
128
train.py
|
@ -222,12 +222,15 @@ def setup_args(args):
|
||||||
Sets defaults for missing args (possible if missing from json config)
|
Sets defaults for missing args (possible if missing from json config)
|
||||||
Forces some args to be set based on others for compatibility reasons
|
Forces some args to be set based on others for compatibility reasons
|
||||||
"""
|
"""
|
||||||
|
if args.disable_unet_training and args.disable_textenc_training:
|
||||||
|
raise ValueError("Both unet and textenc are disabled, nothing to train")
|
||||||
|
|
||||||
if args.resume_ckpt == "findlast":
|
if args.resume_ckpt == "findlast":
|
||||||
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
|
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
|
||||||
# find the last checkpoint in the logdir
|
# find the last checkpoint in the logdir
|
||||||
args.resume_ckpt = find_last_checkpoint(args.logdir)
|
args.resume_ckpt = find_last_checkpoint(args.logdir)
|
||||||
|
|
||||||
if args.ed1_mode and not args.disable_xformers:
|
if args.ed1_mode and args.mixed_precision == "fp32" and not args.disable_xformers:
|
||||||
args.disable_xformers = True
|
args.disable_xformers = True
|
||||||
logging.info(" ED1 mode: Overiding disable_xformers to True")
|
logging.info(" ED1 mode: Overiding disable_xformers to True")
|
||||||
|
|
||||||
|
@ -238,7 +241,7 @@ def setup_args(args):
|
||||||
args.shuffle_tags = False
|
args.shuffle_tags = False
|
||||||
|
|
||||||
args.clip_skip = max(min(4, args.clip_skip), 0)
|
args.clip_skip = max(min(4, args.clip_skip), 0)
|
||||||
|
|
||||||
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
|
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
|
||||||
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
|
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
|
||||||
args.ckpt_every_n_minutes = 20
|
args.ckpt_every_n_minutes = 20
|
||||||
|
@ -248,7 +251,7 @@ def setup_args(args):
|
||||||
|
|
||||||
if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
|
if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
|
||||||
args.save_every_n_epochs = _VERY_LARGE_NUMBER
|
args.save_every_n_epochs = _VERY_LARGE_NUMBER
|
||||||
|
|
||||||
if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
|
if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
|
||||||
logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
|
logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
|
||||||
logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
|
logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
|
||||||
|
@ -269,6 +272,9 @@ def setup_args(args):
|
||||||
if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
|
if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
|
||||||
os.makedirs(args.save_ckpt_dir)
|
os.makedirs(args.save_ckpt_dir)
|
||||||
|
|
||||||
|
if args.mixed_precision != "fp32" and (args.clip_grad_norm is None or args.clip_grad_norm <= 0):
|
||||||
|
args.clip_grad_norm = 1.0
|
||||||
|
|
||||||
if args.rated_dataset:
|
if args.rated_dataset:
|
||||||
args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
|
args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
|
||||||
|
|
||||||
|
@ -286,9 +292,11 @@ def main(args):
|
||||||
if args.notebook:
|
if args.notebook:
|
||||||
from tqdm.notebook import tqdm
|
from tqdm.notebook import tqdm
|
||||||
else:
|
else:
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
logging.info(f" Seed: {args.seed}")
|
||||||
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
|
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
|
||||||
|
logging.info(f" Seed: {seed}")
|
||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
gpu = GPU()
|
gpu = GPU()
|
||||||
device = torch.device(f"cuda:{args.gpuid}")
|
device = torch.device(f"cuda:{args.gpuid}")
|
||||||
|
@ -441,7 +449,7 @@ def main(args):
|
||||||
hf_ckpt_path = convert_to_hf(args.resume_ckpt)
|
hf_ckpt_path = convert_to_hf(args.resume_ckpt)
|
||||||
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
|
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
|
||||||
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
|
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
|
||||||
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet")
|
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not args.ed1_mode)
|
||||||
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
||||||
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
|
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
|
||||||
|
@ -468,22 +476,38 @@ def main(args):
|
||||||
default_lr = 2e-6
|
default_lr = 2e-6
|
||||||
curr_lr = args.lr if args.lr is not None else default_lr
|
curr_lr = args.lr if args.lr is not None else default_lr
|
||||||
|
|
||||||
# vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
|
d_type = torch.float32
|
||||||
# unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
|
if args.mixed_precision == "fp16":
|
||||||
# text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
|
d_type = torch.float16
|
||||||
vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
|
logging.info(" * Using fp16 *")
|
||||||
unet = unet.to(device, dtype=torch.float32)
|
args.amp = True
|
||||||
text_encoder = text_encoder.to(device, dtype=torch.float32)
|
elif args.mixed_precision == "bf16":
|
||||||
|
d_type = torch.bfloat16
|
||||||
|
logging.info(" * Using bf16 *")
|
||||||
|
args.amp = True
|
||||||
|
else:
|
||||||
|
logging.info(" * Using FP32 *")
|
||||||
|
|
||||||
|
|
||||||
|
vae = vae.to(device, dtype=torch.float16 if (args.amp and d_type == torch.float32) else d_type)
|
||||||
|
unet = unet.to(device, dtype=d_type)
|
||||||
|
text_encoder = text_encoder.to(device, dtype=d_type)
|
||||||
|
|
||||||
if args.disable_textenc_training:
|
if args.disable_textenc_training:
|
||||||
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
|
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
|
||||||
params_to_train = itertools.chain(unet.parameters())
|
params_to_train = itertools.chain(unet.parameters())
|
||||||
|
elif args.disable_unet_training:
|
||||||
|
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
|
||||||
|
params_to_train = itertools.chain(text_encoder.parameters())
|
||||||
else:
|
else:
|
||||||
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
|
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
|
||||||
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
|
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||||
|
|
||||||
betas = (0.9, 0.999)
|
betas = (0.9, 0.999)
|
||||||
epsilon = 1e-8 if not args.amp else 1e-8
|
epsilon = 1e-8
|
||||||
|
if args.amp or args.mix_precision == "fp16":
|
||||||
|
epsilon = 1e-8
|
||||||
|
|
||||||
weight_decay = 0.01
|
weight_decay = 0.01
|
||||||
if args.useadam8bit:
|
if args.useadam8bit:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
@ -502,6 +526,8 @@ def main(args):
|
||||||
amsgrad=False,
|
amsgrad=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log_optimizer(optimizer, betas, epsilon)
|
||||||
|
|
||||||
train_batch = EveryDreamBatch(
|
train_batch = EveryDreamBatch(
|
||||||
data_root=args.data_root,
|
data_root=args.data_root,
|
||||||
flip_p=args.flip_p,
|
flip_p=args.flip_p,
|
||||||
|
@ -540,11 +566,8 @@ def main(args):
|
||||||
sample_prompts.append(line.strip())
|
sample_prompts.append(line.strip())
|
||||||
|
|
||||||
|
|
||||||
if False: #args.wandb is not None and args.wandb: # not yet supported
|
if args.wandb is not None and args.wandb:
|
||||||
log_writer = wandb.init(project="EveryDream2FineTunes",
|
wandb.init(project=args.project_name, sync_tensorboard=True, )
|
||||||
name=args.project_name,
|
|
||||||
dir=log_folder,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
log_writer = SummaryWriter(log_dir=log_folder,
|
log_writer = SummaryWriter(log_dir=log_folder,
|
||||||
flush_secs=5,
|
flush_secs=5,
|
||||||
|
@ -602,7 +625,6 @@ def main(args):
|
||||||
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(batch):
|
def collate_fn(batch):
|
||||||
"""
|
"""
|
||||||
Collates batches
|
Collates batches
|
||||||
|
@ -632,7 +654,7 @@ def main(args):
|
||||||
collate_fn=collate_fn
|
collate_fn=collate_fn
|
||||||
)
|
)
|
||||||
|
|
||||||
unet.train()
|
unet.train() if not args.disable_unet_training else unet.eval()
|
||||||
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
|
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
|
||||||
|
|
||||||
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
|
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
|
||||||
|
@ -643,9 +665,20 @@ def main(args):
|
||||||
logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
|
logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
|
||||||
logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
|
logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
|
||||||
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
|
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
|
||||||
#logging.info(f" {Fore.GREEN}total_batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{total_batch_size}")
|
|
||||||
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
|
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
if args.amp or d_type != torch.float32:
|
||||||
|
#scaler = torch.cuda.amp.GradScaler()
|
||||||
|
scaler = torch.cuda.amp.GradScaler(
|
||||||
|
enabled=False,
|
||||||
|
#enabled=True,
|
||||||
|
init_scale=2048.0,
|
||||||
|
growth_factor=1.5,
|
||||||
|
backoff_factor=0.707,
|
||||||
|
growth_interval=50,
|
||||||
|
)
|
||||||
|
logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
|
||||||
|
|
||||||
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
|
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
|
||||||
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
|
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
@ -661,20 +694,6 @@ def main(args):
|
||||||
|
|
||||||
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
|
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
|
||||||
|
|
||||||
#loss = torch.tensor(0.0, device=device, dtype=torch.float32)
|
|
||||||
|
|
||||||
if args.amp:
|
|
||||||
#scaler = torch.cuda.amp.GradScaler()
|
|
||||||
scaler = torch.cuda.amp.GradScaler(
|
|
||||||
#enabled=False,
|
|
||||||
enabled=True,
|
|
||||||
init_scale=1024.0,
|
|
||||||
growth_factor=2.0,
|
|
||||||
backoff_factor=0.5,
|
|
||||||
growth_interval=50,
|
|
||||||
)
|
|
||||||
logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
|
|
||||||
|
|
||||||
loss_log_step = []
|
loss_log_step = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -723,8 +742,8 @@ def main(args):
|
||||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||||
del noise, latents, cuda_caption
|
del noise, latents, cuda_caption
|
||||||
|
|
||||||
#with autocast(enabled=args.amp):
|
with autocast(enabled=args.amp or d_type != torch.float32):
|
||||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
del timesteps, encoder_hidden_states, noisy_latents
|
del timesteps, encoder_hidden_states, noisy_latents
|
||||||
#with autocast(enabled=args.amp):
|
#with autocast(enabled=args.amp):
|
||||||
|
@ -732,15 +751,17 @@ def main(args):
|
||||||
|
|
||||||
del target, model_pred
|
del target, model_pred
|
||||||
|
|
||||||
if args.clip_grad_norm is not None:
|
|
||||||
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
|
|
||||||
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
|
|
||||||
|
|
||||||
if args.amp:
|
if args.amp:
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
if args.clip_grad_norm is not None:
|
||||||
|
if not args.disable_unet_training:
|
||||||
|
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
|
||||||
|
if not args.disable_textenc_training:
|
||||||
|
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
|
||||||
|
|
||||||
if batch["runt_size"] > 0:
|
if batch["runt_size"] > 0:
|
||||||
grad_scale = batch["runt_size"] / args.batch_size
|
grad_scale = batch["runt_size"] / args.batch_size
|
||||||
with torch.no_grad(): # not required? just in case for now, needs more testing
|
with torch.no_grad(): # not required? just in case for now, needs more testing
|
||||||
|
@ -753,7 +774,7 @@ def main(args):
|
||||||
param.grad *= grad_scale
|
param.grad *= grad_scale
|
||||||
|
|
||||||
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
|
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
|
||||||
if args.amp:
|
if args.amp and d_type == torch.float32:
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
else:
|
else:
|
||||||
|
@ -779,6 +800,7 @@ def main(args):
|
||||||
loss_log_step = []
|
loss_log_step = []
|
||||||
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
|
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
|
||||||
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
|
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
|
||||||
|
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
|
||||||
sum_img = sum(images_per_sec_log_step)
|
sum_img = sum(images_per_sec_log_step)
|
||||||
avg = sum_img / len(images_per_sec_log_step)
|
avg = sum_img / len(images_per_sec_log_step)
|
||||||
images_per_sec_log_step = []
|
images_per_sec_log_step = []
|
||||||
|
@ -861,17 +883,25 @@ def update_old_args(t_args):
|
||||||
Update old args to new args to deal with json config loading and missing args for compatibility
|
Update old args to new args to deal with json config loading and missing args for compatibility
|
||||||
"""
|
"""
|
||||||
if not hasattr(t_args, "shuffle_tags"):
|
if not hasattr(t_args, "shuffle_tags"):
|
||||||
print(f" Config json is missing 'shuffle_tags'")
|
print(f" Config json is missing 'shuffle_tags' flag")
|
||||||
t_args.__dict__["shuffle_tags"] = False
|
t_args.__dict__["shuffle_tags"] = False
|
||||||
if not hasattr(t_args, "save_full_precision"):
|
if not hasattr(t_args, "save_full_precision"):
|
||||||
print(f" Config json is missing 'save_full_precision'")
|
print(f" Config json is missing 'save_full_precision' flag")
|
||||||
t_args.__dict__["save_full_precision"] = False
|
t_args.__dict__["save_full_precision"] = False
|
||||||
if not hasattr(t_args, "notebook"):
|
if not hasattr(t_args, "notebook"):
|
||||||
print(f" Config json is missing 'notebook'")
|
print(f" Config json is missing 'notebook' flag")
|
||||||
t_args.__dict__["notebook"] = False
|
t_args.__dict__["notebook"] = False
|
||||||
|
if not hasattr(t_args, "disable_unet_training"):
|
||||||
|
print(f" Config json is missing 'disable_unet_training' flag")
|
||||||
|
t_args.__dict__["disable_unet_training"] = False
|
||||||
|
if not hasattr(t_args, "mixed_precision"):
|
||||||
|
print(f" Config json is missing 'mixed_precision' flag")
|
||||||
|
t_args.__dict__["mixed_precision"] = "fp32"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
|
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
|
||||||
|
supported_precisions = ['fp16', 'fp32']
|
||||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||||
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
|
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
|
||||||
args, _ = argparser.parse_known_args()
|
args, _ = argparser.parse_known_args()
|
||||||
|
@ -881,9 +911,11 @@ if __name__ == "__main__":
|
||||||
with open(args.config, 'rt') as f:
|
with open(args.config, 'rt') as f:
|
||||||
t_args = argparse.Namespace()
|
t_args = argparse.Namespace()
|
||||||
t_args.__dict__.update(json.load(f))
|
t_args.__dict__.update(json.load(f))
|
||||||
|
print(t_args.__dict__)
|
||||||
update_old_args(t_args) # update args to support older configs
|
update_old_args(t_args) # update args to support older configs
|
||||||
print(t_args.__dict__)
|
print(t_args.__dict__)
|
||||||
args = argparser.parse_args(namespace=t_args)
|
args = argparser.parse_args(namespace=t_args)
|
||||||
|
print(f"mixed_precision: {args.mixed_precision}")
|
||||||
else:
|
else:
|
||||||
print("No config file specified, using command line args")
|
print("No config file specified, using command line args")
|
||||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||||
|
@ -894,7 +926,8 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
|
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
|
||||||
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
|
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
|
||||||
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
|
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
|
||||||
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
|
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
|
||||||
|
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
|
||||||
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
|
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
|
||||||
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
|
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
|
||||||
argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
|
argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
|
||||||
|
@ -909,6 +942,8 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
|
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
|
||||||
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
|
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
|
||||||
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
|
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
|
||||||
|
argparser.add_argument("--mixed_precision", type=str, default='fp32', help="precision for the model training", choices=supported_precisions)
|
||||||
|
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
|
||||||
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
|
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
|
||||||
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
|
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
|
||||||
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
|
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
|
||||||
|
@ -916,6 +951,7 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
|
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
|
||||||
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
|
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
|
||||||
argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
|
argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
|
||||||
|
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
|
||||||
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
|
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
|
||||||
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
|
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
|
||||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||||
|
@ -923,8 +959,6 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
||||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||||
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
|
|
||||||
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
|
|
||||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||||
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
||||||
|
|
||||||
|
|
|
@ -16,24 +16,67 @@ limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from colorama import Fore, Style
|
||||||
|
|
||||||
class LogWrapper(object):
|
from tensorboard import SummaryWriter
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
class LogWrapper():
|
||||||
"""
|
"""
|
||||||
singleton for logging
|
singleton for logging
|
||||||
"""
|
"""
|
||||||
def __init__(self, log_dir, project_name):
|
def __init__(self, args, wandb=False):
|
||||||
self.log_dir = log_dir
|
self.logdir = args.logdir
|
||||||
|
self.wandb = wandb
|
||||||
|
|
||||||
|
if wandb:
|
||||||
|
wandb.init(project=args.project_name, sync_tensorboard=True)
|
||||||
|
else:
|
||||||
|
self.log_writer = SummaryWriter(log_dir=args.logdir,
|
||||||
|
flush_secs=5,
|
||||||
|
comment="EveryDream2FineTunes",
|
||||||
|
)
|
||||||
|
|
||||||
start_time = time.strftime("%Y%m%d-%H%M")
|
start_time = time.strftime("%Y%m%d-%H%M")
|
||||||
self.log_file = os.path.join(log_dir, f"log-{project_name}-{start_time}.txt")
|
log_file = os.path.join(args.logdir, f"log-{args.project_name}-{start_time}.txt")
|
||||||
|
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
console = logging.StreamHandler()
|
console = logging.StreamHandler()
|
||||||
self.logger.addHandler(console)
|
self.logger.addHandler(console)
|
||||||
|
|
||||||
file = logging.FileHandler(self.log_file, mode="a", encoding=None, delay=False)
|
file = logging.FileHandler(log_file, mode="a", encoding=None, delay=False)
|
||||||
self.logger.addHandler(file)
|
self.logger.addHandler(file)
|
||||||
|
|
||||||
def __call__(self):
|
def add_image():
|
||||||
return self.logger
|
"""
|
||||||
|
log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
|
||||||
|
else:
|
||||||
|
log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def add_scalar(self, tag: str, img_tensor: float, global_step: int):
|
||||||
|
if self.wandb:
|
||||||
|
wandb.log({tag: img_tensor}, step=global_step)
|
||||||
|
else:
|
||||||
|
self.log_writer.add_image(tag, img_tensor, global_step)
|
||||||
|
|
||||||
|
def append_epoch_log(self, global_step: int, epoch_pbar, gpu, log_writer, **logs):
|
||||||
|
"""
|
||||||
|
updates the vram usage for the epoch
|
||||||
|
"""
|
||||||
|
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||||
|
self.add_scalar("performance/vram", gpu_used_mem, global_step)
|
||||||
|
epoch_mem_color = Style.RESET_ALL
|
||||||
|
if gpu_used_mem > 0.93 * gpu_total_mem:
|
||||||
|
epoch_mem_color = Fore.LIGHTRED_EX
|
||||||
|
elif gpu_used_mem > 0.85 * gpu_total_mem:
|
||||||
|
epoch_mem_color = Fore.LIGHTYELLOW_EX
|
||||||
|
elif gpu_used_mem > 0.7 * gpu_total_mem:
|
||||||
|
epoch_mem_color = Fore.LIGHTGREEN_EX
|
||||||
|
elif gpu_used_mem < 0.5 * gpu_total_mem:
|
||||||
|
epoch_mem_color = Fore.LIGHTBLUE_EX
|
||||||
|
|
||||||
|
if logs is not None:
|
||||||
|
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
|
Loading…
Reference in New Issue