diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index 95f5afd..2793ae5 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -105,7 +105,7 @@ class EveryDreamValidator: [Any, Any], tuple[torch.Tensor, torch.Tensor]]): with torch.no_grad(), isolate_rng(): loss_validation_epoch = [] - steps_pbar = tqdm(range(len(dataloader)), position=1) + steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False) steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}") for step, batch in enumerate(dataloader): diff --git a/train.py b/train.py index 6c5d6b6..d5eeff0 100644 --- a/train.py +++ b/train.py @@ -691,7 +691,7 @@ def main(args): ) logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)") - epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True) + epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True) epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}") epoch_times = [] @@ -754,7 +754,12 @@ def main(args): def generate_samples(global_step: int, batch): with isolate_rng(): + prev_sample_steps = sample_generator.sample_steps sample_generator.reload_config() + if prev_sample_steps != sample_generator.sample_steps: + next_sample_step = math.ceil((global_step + 1) / sample_generator.sample_steps) * sample_generator.sample_steps + print(f" * SampleGenerator config changed, now generating images samples every " + + f"{sample_generator.sample_steps} training steps (next={next_sample_step})") sample_generator.update_random_captions(batch["captions"]) inference_pipe = sample_generator.create_inference_pipe(unet=unet, text_encoder=text_encoder, @@ -787,7 +792,7 @@ def main(args): images_per_sec_log_step = [] epoch_len = math.ceil(len(train_batch) / args.batch_size) - steps_pbar = tqdm(range(epoch_len), position=1) + steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True) steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}") for step, batch in enumerate(train_dataloader): diff --git a/utils/sample_generator.py b/utils/sample_generator.py index 049a628..c9c6ee6 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -12,6 +12,7 @@ from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistep from torch.cuda.amp import autocast from torch.utils.tensorboard import SummaryWriter from torchvision import transforms +from tqdm.auto import tqdm def clean_filename(filename): @@ -89,7 +90,7 @@ class SampleGenerator: self.sample_requests = None self.reload_config() - print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps") + print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps") if not os.path.exists(f"{log_folder}/samples/"): os.makedirs(f"{log_folder}/samples/") @@ -169,9 +170,7 @@ class SampleGenerator: """ generates samples at different cfg scales and saves them to disk """ - logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}") - - pipe.set_progress_bar_config(disable=(not self.show_progress_bars)) + disable_progress_bars = not self.show_progress_bars try: font = ImageFont.truetype(font="arial.ttf", size=20) @@ -183,10 +182,13 @@ class SampleGenerator: batch: list[SampleRequest] def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool: return a.size == b.size - for batch in chunk_list(self.sample_requests, self.batch_size, - compatibility_test=sample_compatibility_test): - #print("batch: ", batch) + batches = list(chunk_list(self.sample_requests, self.batch_size, + compatibility_test=sample_compatibility_test)) + pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False, + desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}") + for batch in batches: prompts = [p.prompt for p in batch] + pbar.set_postfix(postfix={'prompts': prompts}) negative_prompts = [p.negative_prompt for p in batch] seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30)) for p in batch] @@ -196,6 +198,8 @@ class SampleGenerator: batch_images = [] for cfg in self.cfgs: + pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False, + desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}") images = pipe(prompt=prompts, negative_prompt=negative_prompts, num_inference_steps=self.num_inference_steps, @@ -257,6 +261,7 @@ class SampleGenerator: del tfimage del batch_images + pbar.update(1) @torch.no_grad() def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):