logging and progress bar improvements

damian 2023-03-02 18:29:28 +01:00
parent 8100e42159
commit 61558be2ae
3 changed files with 20 additions and 10 deletions
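The core of the change is a nested tqdm layout: a persistent outer epochs bar plus transient inner bars that clear themselves when done and resize with the terminal. A minimal standalone sketch of that pattern (illustrative only, not code from this repo):

from tqdm.auto import tqdm
import time

# Outer bar: pinned to row 0, stays on screen when finished (leave=True).
epoch_pbar = tqdm(range(3), position=0, leave=True, dynamic_ncols=True, desc="Epochs")
for _ in epoch_pbar:
    # Inner bar: row 1, erased when the epoch ends (leave=False);
    # dynamic_ncols=True lets both bars adapt if the terminal is resized.
    steps_pbar = tqdm(range(5), position=1, leave=False, dynamic_ncols=True, desc="Steps")
    for _ in steps_pbar:
        time.sleep(0.05)  # stand-in for one training step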

View File

@@ -105,7 +105,7 @@ class EveryDreamValidator:
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
with torch.no_grad(), isolate_rng():
loss_validation_epoch = []
-steps_pbar = tqdm(range(len(dataloader)), position=1)
+steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
for step, batch in enumerate(dataloader):

View File

@@ -691,7 +691,7 @@ def main(args):
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
-epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
+epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
epoch_times = []
@@ -754,7 +754,12 @@ def main(args):
def generate_samples(global_step: int, batch):
with isolate_rng():
+prev_sample_steps = sample_generator.sample_steps
sample_generator.reload_config()
+if prev_sample_steps != sample_generator.sample_steps:
+next_sample_step = math.ceil((global_step + 1) / sample_generator.sample_steps) * sample_generator.sample_steps
+print(f" * SampleGenerator config changed, now generating images samples every " +
+f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
@@ -787,7 +792,7 @@ def main(args):
images_per_sec_log_step = []
epoch_len = math.ceil(len(train_batch) / args.batch_size)
-steps_pbar = tqdm(range(epoch_len), position=1)
+steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
for step, batch in enumerate(train_dataloader):
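
To make the new next_sample_step log line above concrete: the expression rounds the upcoming step up to the next multiple of the freshly reloaded sample_steps. A worked example with made-up numbers (not from a real run):

import math

global_step = 799      # hypothetical current step
sample_steps = 250     # hypothetical value just reloaded from config

# Round (global_step + 1) up to the next multiple of sample_steps.
next_sample_step = math.ceil((global_step + 1) / sample_steps) * sample_steps
print(next_sample_step)  # ceil(800 / 250) * 250 = 4 * 250 = 1000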

View File

@@ -12,6 +12,7 @@ from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistep
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
+from tqdm.auto import tqdm
def clean_filename(filename):
@@ -89,7 +90,7 @@ class SampleGenerator:
self.sample_requests = None
self.reload_config()
-print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
+print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps")
if not os.path.exists(f"{log_folder}/samples/"):
os.makedirs(f"{log_folder}/samples/")
@@ -169,9 +170,7 @@ class SampleGenerator:
"""
generates samples at different cfg scales and saves them to disk
"""
-logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
-pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
+disable_progress_bars = not self.show_progress_bars
try:
font = ImageFont.truetype(font="arial.ttf", size=20)
@@ -183,10 +182,13 @@
batch: list[SampleRequest]
def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
return a.size == b.size
-for batch in chunk_list(self.sample_requests, self.batch_size,
-compatibility_test=sample_compatibility_test):
-#print("batch: ", batch)
+batches = list(chunk_list(self.sample_requests, self.batch_size,
+compatibility_test=sample_compatibility_test))
+pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
+desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
+for batch in batches:
prompts = [p.prompt for p in batch]
+pbar.set_postfix(postfix={'prompts': prompts})
negative_prompts = [p.negative_prompt for p in batch]
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
for p in batch]
@@ -196,6 +198,8 @@ class SampleGenerator:
batch_images = []
for cfg in self.cfgs:
+pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
+desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
images = pipe(prompt=prompts,
negative_prompt=negative_prompts,
num_inference_steps=self.num_inference_steps,
@@ -257,6 +261,7 @@ class SampleGenerator:
del tfimage
del batch_images
+pbar.update(1)
@torch.no_grad()
def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):
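
The batching above goes through a chunk_list helper that is not shown in this diff. Purely as an illustration of the call pattern (the repo's actual helper may differ), a compatible-chunking generator could look like this:

def chunk_list(items, chunk_size, compatibility_test=lambda a, b: True):
    # Group consecutive items into chunks of at most chunk_size, starting a new
    # chunk whenever the next item is incompatible with the current chunk's first
    # item (here: SampleRequests with different sizes are assumed unbatchable).
    chunk = []
    for item in items:
        if chunk and (len(chunk) >= chunk_size or not compatibility_test(chunk[0], item)):
            yield chunk
            chunk = []
        chunk.append(item)
    if chunk:
        yield chunk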