logging and progress bar improvements
This commit is contained in:
parent
8100e42159
commit
61558be2ae
|
@ -105,7 +105,7 @@ class EveryDreamValidator:
|
|||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
with torch.no_grad(), isolate_rng():
|
||||
loss_validation_epoch = []
|
||||
steps_pbar = tqdm(range(len(dataloader)), position=1)
|
||||
steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
|
||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
|
||||
|
||||
for step, batch in enumerate(dataloader):
|
||||
|
|
9
train.py
9
train.py
|
@ -691,7 +691,7 @@ def main(args):
|
|||
)
|
||||
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
|
||||
|
||||
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
|
||||
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
|
||||
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
|
||||
epoch_times = []
|
||||
|
||||
|
@ -754,7 +754,12 @@ def main(args):
|
|||
|
||||
def generate_samples(global_step: int, batch):
|
||||
with isolate_rng():
|
||||
prev_sample_steps = sample_generator.sample_steps
|
||||
sample_generator.reload_config()
|
||||
if prev_sample_steps != sample_generator.sample_steps:
|
||||
next_sample_step = math.ceil((global_step + 1) / sample_generator.sample_steps) * sample_generator.sample_steps
|
||||
print(f" * SampleGenerator config changed, now generating images samples every " +
|
||||
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
|
||||
sample_generator.update_random_captions(batch["captions"])
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -787,7 +792,7 @@ def main(args):
|
|||
images_per_sec_log_step = []
|
||||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
steps_pbar = tqdm(range(epoch_len), position=1)
|
||||
steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
|
||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
|
|
|
@ -12,6 +12,7 @@ from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistep
|
|||
from torch.cuda.amp import autocast
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
def clean_filename(filename):
|
||||
|
@ -89,7 +90,7 @@ class SampleGenerator:
|
|||
|
||||
self.sample_requests = None
|
||||
self.reload_config()
|
||||
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
|
||||
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps")
|
||||
if not os.path.exists(f"{log_folder}/samples/"):
|
||||
os.makedirs(f"{log_folder}/samples/")
|
||||
|
||||
|
@ -169,9 +170,7 @@ class SampleGenerator:
|
|||
"""
|
||||
generates samples at different cfg scales and saves them to disk
|
||||
"""
|
||||
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
|
||||
|
||||
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
|
||||
disable_progress_bars = not self.show_progress_bars
|
||||
|
||||
try:
|
||||
font = ImageFont.truetype(font="arial.ttf", size=20)
|
||||
|
@ -183,10 +182,13 @@ class SampleGenerator:
|
|||
batch: list[SampleRequest]
|
||||
def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
|
||||
return a.size == b.size
|
||||
for batch in chunk_list(self.sample_requests, self.batch_size,
|
||||
compatibility_test=sample_compatibility_test):
|
||||
#print("batch: ", batch)
|
||||
batches = list(chunk_list(self.sample_requests, self.batch_size,
|
||||
compatibility_test=sample_compatibility_test))
|
||||
pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
|
||||
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
|
||||
for batch in batches:
|
||||
prompts = [p.prompt for p in batch]
|
||||
pbar.set_postfix(postfix={'prompts': prompts})
|
||||
negative_prompts = [p.negative_prompt for p in batch]
|
||||
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
|
||||
for p in batch]
|
||||
|
@ -196,6 +198,8 @@ class SampleGenerator:
|
|||
|
||||
batch_images = []
|
||||
for cfg in self.cfgs:
|
||||
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
|
||||
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
|
||||
images = pipe(prompt=prompts,
|
||||
negative_prompt=negative_prompts,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
|
@ -257,6 +261,7 @@ class SampleGenerator:
|
|||
del tfimage
|
||||
del batch_images
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
@torch.no_grad()
|
||||
def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):
|
||||
|
|
Loading…
Reference in New Issue