merge
parent 6ba710d6f1
commit ba25992140
@@ -21,6 +21,7 @@
    "lr_scheduler": "constant",
    "lr_warmup_steps": null,
    "max_epochs": 30,
    "notebook": false,
    "project_name": "project_abc",
    "resolution": 512,
    "resume_ckpt": "sd_v1-5_vae",
train.py (128 changed lines)
@@ -222,12 +222,15 @@ def setup_args(args):
    Sets defaults for missing args (possible if missing from json config)
    Forces some args to be set based on others for compatibility reasons
    """
    if args.disable_unet_training and args.disable_textenc_training:
        raise ValueError("Both unet and textenc are disabled, nothing to train")

    if args.resume_ckpt == "findlast":
        logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
        # find the last checkpoint in the logdir
        args.resume_ckpt = find_last_checkpoint(args.logdir)

-   if args.ed1_mode and not args.disable_xformers:
+   if args.ed1_mode and args.mixed_precision == "fp32" and not args.disable_xformers:
        args.disable_xformers = True
        logging.info(" ED1 mode: Overiding disable_xformers to True")
@@ -238,7 +241,7 @@ def setup_args(args):
        args.shuffle_tags = False

    args.clip_skip = max(min(4, args.clip_skip), 0)

    if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
        logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
        args.ckpt_every_n_minutes = 20
@@ -248,7 +251,7 @@ def setup_args(args):

    if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
        args.save_every_n_epochs = _VERY_LARGE_NUMBER

    if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
        logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
        logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
@@ -269,6 +272,9 @@ def setup_args(args):
    if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
        os.makedirs(args.save_ckpt_dir)

+   if args.mixed_precision != "fp32" and (args.clip_grad_norm is None or args.clip_grad_norm <= 0):
+       args.clip_grad_norm = 1.0

    if args.rated_dataset:
        args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
@@ -286,9 +292,11 @@ def main(args):
    if args.notebook:
        from tqdm.notebook import tqdm
    else:
-       from tqdm.auto import tqdm
+       from tqdm.auto import tqdm

-   logging.info(f" Seed: {args.seed}")
+   seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
+   logging.info(f" Seed: {seed}")
    set_seed(seed)
    gpu = GPU()
    device = torch.device(f"cuda:{args.gpuid}")
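
Note: the seed handling above falls back to a random value when --seed is -1 and then calls set_seed, which here presumably comes from a helper module. A minimal sketch of what such a helper typically does (not part of this diff, the function name is illustrative):

    import random
    import numpy as np
    import torch

    def set_seed_sketch(seed: int):
        # seed every RNG the training loop may touch
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)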
@@ -441,7 +449,7 @@ def main(args):
    hf_ckpt_path = convert_to_hf(args.resume_ckpt)
    text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
-   unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet")
+   unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not args.ed1_mode)
    sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
    noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
@@ -468,22 +476,38 @@ def main(args):
    default_lr = 2e-6
    curr_lr = args.lr if args.lr is not None else default_lr

    # vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
    # unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
    # text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-   vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-   unet = unet.to(device, dtype=torch.float32)
-   text_encoder = text_encoder.to(device, dtype=torch.float32)
+   d_type = torch.float32
+   if args.mixed_precision == "fp16":
+       d_type = torch.float16
+       logging.info(" * Using fp16 *")
+       args.amp = True
+   elif args.mixed_precision == "bf16":
+       d_type = torch.bfloat16
+       logging.info(" * Using bf16 *")
+       args.amp = True
+   else:
+       logging.info(" * Using FP32 *")
+
+   vae = vae.to(device, dtype=torch.float16 if (args.amp and d_type == torch.float32) else d_type)
+   unet = unet.to(device, dtype=d_type)
+   text_encoder = text_encoder.to(device, dtype=d_type)

    if args.disable_textenc_training:
        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
        params_to_train = itertools.chain(unet.parameters())
    elif args.disable_unet_training:
        logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
        params_to_train = itertools.chain(text_encoder.parameters())
    else:
        logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
        params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())

    betas = (0.9, 0.999)
-   epsilon = 1e-8 if not args.amp else 1e-8
+   epsilon = 1e-8
+   if args.amp or args.mix_precision == "fp16":
+       epsilon = 1e-8

    weight_decay = 0.01
    if args.useadam8bit:
        import bitsandbytes as bnb
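
Note: the new d_type block maps the --mixed_precision string onto a torch dtype and switches args.amp on for fp16/bf16. A self-contained sketch of the same idea, using only standard torch (the helper name is illustrative, not from this commit):

    from typing import Tuple
    import torch

    def pick_dtype(mixed_precision: str) -> Tuple[torch.dtype, bool]:
        # returns (model dtype, whether autocast/AMP should be enabled)
        if mixed_precision == "fp16":
            return torch.float16, True
        if mixed_precision == "bf16":
            return torch.bfloat16, True
        return torch.float32, False

    d_type, use_amp = pick_dtype("fp16")
    # unet = unet.to(device, dtype=d_type)  # as the diff does for unet/text_encoder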
@@ -502,6 +526,8 @@ def main(args):
        amsgrad=False,
    )

    log_optimizer(optimizer, betas, epsilon)

    train_batch = EveryDreamBatch(
        data_root=args.data_root,
        flip_p=args.flip_p,
@@ -540,11 +566,8 @@ def main(args):
            sample_prompts.append(line.strip())

-   if False: #args.wandb is not None and args.wandb: # not yet supported
-       log_writer = wandb.init(project="EveryDream2FineTunes",
-                               name=args.project_name,
-                               dir=log_folder,
-                               )
+   if args.wandb is not None and args.wandb:
+       wandb.init(project=args.project_name, sync_tensorboard=True, )
    else:
        log_writer = SummaryWriter(log_dir=log_folder,
                                   flush_secs=5,
@@ -602,7 +625,6 @@ def main(args):
    logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")


    def collate_fn(batch):
        """
        Collates batches
@@ -632,7 +654,7 @@ def main(args):
        collate_fn=collate_fn
    )

-   unet.train()
+   unet.train() if not args.disable_unet_training else unet.eval()
    text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()

    logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
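
Note: .train()/.eval() only toggles behaviour of layers such as dropout; it does not stop gradient computation for the disabled module. If the intent were to freeze it completely, the usual extra step is to turn off requires_grad, sketched here as an assumption rather than something this commit does:

    # freeze the text encoder entirely when it is not being trained
    if args.disable_textenc_training:
        text_encoder.eval()
        text_encoder.requires_grad_(False)  # no gradients computed or stored for it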
@@ -643,9 +665,20 @@ def main(args):
    logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
    logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
    logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
    #logging.info(f" {Fore.GREEN}total_batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{total_batch_size}")
    logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")

+   if args.amp or d_type != torch.float32:
+       #scaler = torch.cuda.amp.GradScaler()
+       scaler = torch.cuda.amp.GradScaler(
+           enabled=False,
+           #enabled=True,
+           init_scale=2048.0,
+           growth_factor=1.5,
+           backoff_factor=0.707,
+           growth_interval=50,
+       )
+       logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")

    epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
    epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
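
Note: this creates the GradScaler once, before the epoch loop (the per-epoch construction further down is removed by this commit). For reference, the canonical PyTorch AMP loop shape looks roughly like the sketch below; model, loader and optimizer are placeholders, not names from this repository:

    import torch
    from torch.cuda.amp import autocast, GradScaler

    def train_amp(model, loader, optimizer, device="cuda"):
        scaler = GradScaler()                     # one scaler for the whole run
        for batch in loader:
            optimizer.zero_grad()
            with autocast():                      # forward pass runs in mixed precision
                loss = model(batch.to(device)).mean()
            scaler.scale(loss).backward()         # scaled loss -> scaled gradients
            scaler.step(optimizer)                # unscales, skips step on inf/nan grads
            scaler.update()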
@@ -661,20 +694,6 @@ def main(args):

        append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)

-       #loss = torch.tensor(0.0, device=device, dtype=torch.float32)
-
-       if args.amp:
-           #scaler = torch.cuda.amp.GradScaler()
-           scaler = torch.cuda.amp.GradScaler(
-               #enabled=False,
-               enabled=True,
-               init_scale=1024.0,
-               growth_factor=2.0,
-               backoff_factor=0.5,
-               growth_interval=50,
-           )
-           logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")

        loss_log_step = []

        try:
@@ -723,8 +742,8 @@ def main(args):
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
            del noise, latents, cuda_caption

-           #with autocast(enabled=args.amp):
-           model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+           with autocast(enabled=args.amp or d_type != torch.float32):
+               model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            del timesteps, encoder_hidden_states, noisy_latents
            #with autocast(enabled=args.amp):
@@ -732,15 +751,17 @@ def main(args):

            del target, model_pred

-           if args.clip_grad_norm is not None:
-               torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
-               torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)

            if args.amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

+           if args.clip_grad_norm is not None:
+               if not args.disable_unet_training:
+                   torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
+               if not args.disable_textenc_training:
+                   torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)

            if batch["runt_size"] > 0:
                grad_scale = batch["runt_size"] / args.batch_size
                with torch.no_grad(): # not required? just in case for now, needs more testing
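
Note: the reordered clipping now runs after backward() and respects the disable_*_training flags. When the GradScaler path is active, the pattern documented by PyTorch is to unscale the gradients before clipping so max_norm is compared against the true gradient magnitudes. A self-contained sketch of that pattern (illustrative only, not what this diff adds):

    import torch

    def clipped_amp_step(loss, model, optimizer, scaler, max_norm):
        # scale -> backward -> unscale -> clip -> step -> update
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)                 # gradients back in their real scale
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        scaler.step(optimizer)                     # skips the step if inf/nan grads were found
        scaler.update()
        optimizer.zero_grad()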
@@ -753,7 +774,7 @@ def main(args):
                        param.grad *= grad_scale

            if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
-               if args.amp:
+               if args.amp and d_type == torch.float32:
                    scaler.step(optimizer)
                    scaler.update()
                else:
@@ -779,6 +800,7 @@ def main(args):
                loss_log_step = []
                logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
                log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
                log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
                sum_img = sum(images_per_sec_log_step)
                avg = sum_img / len(images_per_sec_log_step)
                images_per_sec_log_step = []
@@ -861,17 +883,25 @@ def update_old_args(t_args):
    Update old args to new args to deal with json config loading and missing args for compatibility
    """
    if not hasattr(t_args, "shuffle_tags"):
-       print(f" Config json is missing 'shuffle_tags'")
+       print(f" Config json is missing 'shuffle_tags' flag")
        t_args.__dict__["shuffle_tags"] = False
    if not hasattr(t_args, "save_full_precision"):
-       print(f" Config json is missing 'save_full_precision'")
+       print(f" Config json is missing 'save_full_precision' flag")
        t_args.__dict__["save_full_precision"] = False
    if not hasattr(t_args, "notebook"):
-       print(f" Config json is missing 'notebook'")
+       print(f" Config json is missing 'notebook' flag")
        t_args.__dict__["notebook"] = False
+   if not hasattr(t_args, "disable_unet_training"):
+       print(f" Config json is missing 'disable_unet_training' flag")
+       t_args.__dict__["disable_unet_training"] = False
+   if not hasattr(t_args, "mixed_precision"):
+       print(f" Config json is missing 'mixed_precision' flag")
+       t_args.__dict__["mixed_precision"] = "fp32"


if __name__ == "__main__":
    supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
+   supported_precisions = ['fp16', 'fp32']
    argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
    argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
    args, _ = argparser.parse_known_args()
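
Note: update_old_args backfills keys that older JSON configs predate, one hasattr check per key. Since an argparse.Namespace is a thin wrapper over __dict__, the same effect can be had by iterating a defaults table; a sketch of that alternative (the dict and function name are illustrative, not from this commit):

    _NEW_ARG_DEFAULTS = {
        "shuffle_tags": False,
        "save_full_precision": False,
        "notebook": False,
        "disable_unet_training": False,
        "mixed_precision": "fp32",
    }

    def update_old_args_sketch(t_args):
        for key, default in _NEW_ARG_DEFAULTS.items():
            if not hasattr(t_args, key):
                print(f" Config json is missing '{key}' flag")
                t_args.__dict__[key] = default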
@@ -881,9 +911,11 @@ if __name__ == "__main__":
        with open(args.config, 'rt') as f:
            t_args = argparse.Namespace()
            t_args.__dict__.update(json.load(f))
            print(t_args.__dict__)
            update_old_args(t_args)  # update args to support older configs
            print(t_args.__dict__)
            args = argparser.parse_args(namespace=t_args)
            print(f"mixed_precision: {args.mixed_precision}")
    else:
        print("No config file specified, using command line args")
        argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
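
Note: passing the pre-filled Namespace to parse_args(namespace=t_args) means values coming from the JSON behave like defaults, while any flag actually given on the command line still overrides them. A minimal standalone sketch of the mechanism (the file name and option are illustrative):

    import argparse, json

    parser = argparse.ArgumentParser()
    parser.add_argument("--max_epochs", type=int, default=300)

    ns = argparse.Namespace()
    with open("train.json", "rt") as f:           # file name is illustrative
        ns.__dict__.update(json.load(f))          # e.g. {"max_epochs": 30}
    args = parser.parse_args(namespace=ns)        # CLI flags still override JSON values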
@@ -894,7 +926,8 @@ if __name__ == "__main__":
    argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
    argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
    argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
-   argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
+   argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
+   argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
    argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
    argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
    argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
@@ -909,6 +942,8 @@ if __name__ == "__main__":
    argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
    argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
    argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
+   argparser.add_argument("--mixed_precision", type=str, default='fp32', help="precision for the model training", choices=supported_precisions)
+   argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
    argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
    argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
    argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
@@ -916,6 +951,7 @@ if __name__ == "__main__":
    argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
    argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
    argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
+   argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
    argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
    argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
    argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
@@ -923,8 +959,6 @@ if __name__ == "__main__":
    argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
    argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
    argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
-   argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
-   argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
    argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
    argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
@@ -16,24 +16,67 @@ limitations under the License.
import logging
import os
import time
from colorama import Fore, Style

-class LogWrapper(object):
+from tensorboard import SummaryWriter
+import wandb
+
+class LogWrapper():
    """
    singleton for logging
    """
-   def __init__(self, log_dir, project_name):
-       self.log_dir = log_dir
+   def __init__(self, args, wandb=False):
+       self.logdir = args.logdir
+       self.wandb = wandb
+
+       if wandb:
+           wandb.init(project=args.project_name, sync_tensorboard=True)
+       else:
+           self.log_writer = SummaryWriter(log_dir=args.logdir,
+                                           flush_secs=5,
+                                           comment="EveryDream2FineTunes",
+                                           )

        start_time = time.strftime("%Y%m%d-%H%M")
-       self.log_file = os.path.join(log_dir, f"log-{project_name}-{start_time}.txt")
+       log_file = os.path.join(args.logdir, f"log-{args.project_name}-{start_time}.txt")

        self.logger = logging.getLogger(__name__)

        console = logging.StreamHandler()
        self.logger.addHandler(console)

-       file = logging.FileHandler(self.log_file, mode="a", encoding=None, delay=False)
+       file = logging.FileHandler(log_file, mode="a", encoding=None, delay=False)
        self.logger.addHandler(file)

-   def __call__(self):
-       return self.logger
+   def add_image():
+       """
+       log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
+       else:
+           log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
+       """
+       pass
+
+   def add_scalar(self, tag: str, img_tensor: float, global_step: int):
+       if self.wandb:
+           wandb.log({tag: img_tensor}, step=global_step)
+       else:
+           self.log_writer.add_image(tag, img_tensor, global_step)
+
+   def append_epoch_log(self, global_step: int, epoch_pbar, gpu, log_writer, **logs):
+       """
+       updates the vram usage for the epoch
+       """
+       gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
+       self.add_scalar("performance/vram", gpu_used_mem, global_step)
+       epoch_mem_color = Style.RESET_ALL
+       if gpu_used_mem > 0.93 * gpu_total_mem:
+           epoch_mem_color = Fore.LIGHTRED_EX
+       elif gpu_used_mem > 0.85 * gpu_total_mem:
+           epoch_mem_color = Fore.LIGHTYELLOW_EX
+       elif gpu_used_mem > 0.7 * gpu_total_mem:
+           epoch_mem_color = Fore.LIGHTGREEN_EX
+       elif gpu_used_mem < 0.5 * gpu_total_mem:
+           epoch_mem_color = Fore.LIGHTBLUE_EX
+
+       if logs is not None:
+           epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
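
Note: the new LogWrapper routes scalars either to wandb or to a TensorBoard SummaryWriter. For reference, a small self-contained version of that dispatch, written against the standard torch.utils.tensorboard import and add_scalar call (a sketch with illustrative names, not the file as committed):

    import wandb
    from torch.utils.tensorboard import SummaryWriter

    class ScalarLogger:
        def __init__(self, logdir: str, project: str, use_wandb: bool = False):
            self.use_wandb = use_wandb
            if use_wandb:
                wandb.init(project=project, sync_tensorboard=True)
            else:
                self.writer = SummaryWriter(log_dir=logdir, flush_secs=5)

        def add_scalar(self, tag: str, value: float, global_step: int):
            # send the scalar to whichever backend was selected at construction time
            if self.use_wandb:
                wandb.log({tag: value}, step=global_step)
            else:
                self.writer.add_scalar(tag, value, global_step)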