logging and progress bar improvements

This commit is contained in:
damian 2023-03-02 18:29:28 +01:00
parent 8100e42159
commit 61558be2ae
3 changed files with 20 additions and 10 deletions

View File

@ -105,7 +105,7 @@ class EveryDreamValidator:
[Any, Any], tuple[torch.Tensor, torch.Tensor]]): [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
with torch.no_grad(), isolate_rng(): with torch.no_grad(), isolate_rng():
loss_validation_epoch = [] loss_validation_epoch = []
steps_pbar = tqdm(range(len(dataloader)), position=1) steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}") steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
for step, batch in enumerate(dataloader): for step, batch in enumerate(dataloader):

View File

@ -691,7 +691,7 @@ def main(args):
) )
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)") logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True) epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}") epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
epoch_times = [] epoch_times = []
@ -754,7 +754,12 @@ def main(args):
def generate_samples(global_step: int, batch): def generate_samples(global_step: int, batch):
with isolate_rng(): with isolate_rng():
prev_sample_steps = sample_generator.sample_steps
sample_generator.reload_config() sample_generator.reload_config()
if prev_sample_steps != sample_generator.sample_steps:
next_sample_step = math.ceil((global_step + 1) / sample_generator.sample_steps) * sample_generator.sample_steps
print(f" * SampleGenerator config changed, now generating images samples every " +
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
sample_generator.update_random_captions(batch["captions"]) sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet, inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder, text_encoder=text_encoder,
@ -787,7 +792,7 @@ def main(args):
images_per_sec_log_step = [] images_per_sec_log_step = []
epoch_len = math.ceil(len(train_batch) / args.batch_size) epoch_len = math.ceil(len(train_batch) / args.batch_size)
steps_pbar = tqdm(range(epoch_len), position=1) steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}") steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):

View File

@ -12,6 +12,7 @@ from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistep
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm
def clean_filename(filename): def clean_filename(filename):
@ -89,7 +90,7 @@ class SampleGenerator:
self.sample_requests = None self.sample_requests = None
self.reload_config() self.reload_config()
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps") print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps")
if not os.path.exists(f"{log_folder}/samples/"): if not os.path.exists(f"{log_folder}/samples/"):
os.makedirs(f"{log_folder}/samples/") os.makedirs(f"{log_folder}/samples/")
@ -169,9 +170,7 @@ class SampleGenerator:
""" """
generates samples at different cfg scales and saves them to disk generates samples at different cfg scales and saves them to disk
""" """
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}") disable_progress_bars = not self.show_progress_bars
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
try: try:
font = ImageFont.truetype(font="arial.ttf", size=20) font = ImageFont.truetype(font="arial.ttf", size=20)
@ -183,10 +182,13 @@ class SampleGenerator:
batch: list[SampleRequest] batch: list[SampleRequest]
def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool: def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
return a.size == b.size return a.size == b.size
for batch in chunk_list(self.sample_requests, self.batch_size, batches = list(chunk_list(self.sample_requests, self.batch_size,
compatibility_test=sample_compatibility_test): compatibility_test=sample_compatibility_test))
#print("batch: ", batch) pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
for batch in batches:
prompts = [p.prompt for p in batch] prompts = [p.prompt for p in batch]
pbar.set_postfix(postfix={'prompts': prompts})
negative_prompts = [p.negative_prompt for p in batch] negative_prompts = [p.negative_prompt for p in batch]
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30)) seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
for p in batch] for p in batch]
@ -196,6 +198,8 @@ class SampleGenerator:
batch_images = [] batch_images = []
for cfg in self.cfgs: for cfg in self.cfgs:
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
images = pipe(prompt=prompts, images = pipe(prompt=prompts,
negative_prompt=negative_prompts, negative_prompt=negative_prompts,
num_inference_steps=self.num_inference_steps, num_inference_steps=self.num_inference_steps,
@ -257,6 +261,7 @@ class SampleGenerator:
del tfimage del tfimage
del batch_images del batch_images
pbar.update(1)
@torch.no_grad() @torch.no_grad()
def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict): def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):