logging and progress bar improvements
This commit is contained in:
parent
8100e42159
commit
61558be2ae
|
@ -105,7 +105,7 @@ class EveryDreamValidator:
|
||||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||||
with torch.no_grad(), isolate_rng():
|
with torch.no_grad(), isolate_rng():
|
||||||
loss_validation_epoch = []
|
loss_validation_epoch = []
|
||||||
steps_pbar = tqdm(range(len(dataloader)), position=1)
|
steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
|
||||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
|
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
|
||||||
|
|
||||||
for step, batch in enumerate(dataloader):
|
for step, batch in enumerate(dataloader):
|
||||||
|
|
9
train.py
9
train.py
|
@ -691,7 +691,7 @@ def main(args):
|
||||||
)
|
)
|
||||||
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
|
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
|
||||||
|
|
||||||
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
|
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
|
||||||
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
|
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
|
||||||
epoch_times = []
|
epoch_times = []
|
||||||
|
|
||||||
|
@ -754,7 +754,12 @@ def main(args):
|
||||||
|
|
||||||
def generate_samples(global_step: int, batch):
|
def generate_samples(global_step: int, batch):
|
||||||
with isolate_rng():
|
with isolate_rng():
|
||||||
|
prev_sample_steps = sample_generator.sample_steps
|
||||||
sample_generator.reload_config()
|
sample_generator.reload_config()
|
||||||
|
if prev_sample_steps != sample_generator.sample_steps:
|
||||||
|
next_sample_step = math.ceil((global_step + 1) / sample_generator.sample_steps) * sample_generator.sample_steps
|
||||||
|
print(f" * SampleGenerator config changed, now generating images samples every " +
|
||||||
|
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
|
||||||
sample_generator.update_random_captions(batch["captions"])
|
sample_generator.update_random_captions(batch["captions"])
|
||||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
|
@ -787,7 +792,7 @@ def main(args):
|
||||||
images_per_sec_log_step = []
|
images_per_sec_log_step = []
|
||||||
|
|
||||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||||
steps_pbar = tqdm(range(epoch_len), position=1)
|
steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
|
||||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
|
|
@ -12,6 +12,7 @@ from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistep
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
|
||||||
def clean_filename(filename):
|
def clean_filename(filename):
|
||||||
|
@ -89,7 +90,7 @@ class SampleGenerator:
|
||||||
|
|
||||||
self.sample_requests = None
|
self.sample_requests = None
|
||||||
self.reload_config()
|
self.reload_config()
|
||||||
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
|
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps")
|
||||||
if not os.path.exists(f"{log_folder}/samples/"):
|
if not os.path.exists(f"{log_folder}/samples/"):
|
||||||
os.makedirs(f"{log_folder}/samples/")
|
os.makedirs(f"{log_folder}/samples/")
|
||||||
|
|
||||||
|
@ -169,9 +170,7 @@ class SampleGenerator:
|
||||||
"""
|
"""
|
||||||
generates samples at different cfg scales and saves them to disk
|
generates samples at different cfg scales and saves them to disk
|
||||||
"""
|
"""
|
||||||
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
|
disable_progress_bars = not self.show_progress_bars
|
||||||
|
|
||||||
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
font = ImageFont.truetype(font="arial.ttf", size=20)
|
font = ImageFont.truetype(font="arial.ttf", size=20)
|
||||||
|
@ -183,10 +182,13 @@ class SampleGenerator:
|
||||||
batch: list[SampleRequest]
|
batch: list[SampleRequest]
|
||||||
def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
|
def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
|
||||||
return a.size == b.size
|
return a.size == b.size
|
||||||
for batch in chunk_list(self.sample_requests, self.batch_size,
|
batches = list(chunk_list(self.sample_requests, self.batch_size,
|
||||||
compatibility_test=sample_compatibility_test):
|
compatibility_test=sample_compatibility_test))
|
||||||
#print("batch: ", batch)
|
pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
|
||||||
|
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
|
||||||
|
for batch in batches:
|
||||||
prompts = [p.prompt for p in batch]
|
prompts = [p.prompt for p in batch]
|
||||||
|
pbar.set_postfix(postfix={'prompts': prompts})
|
||||||
negative_prompts = [p.negative_prompt for p in batch]
|
negative_prompts = [p.negative_prompt for p in batch]
|
||||||
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
|
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
|
||||||
for p in batch]
|
for p in batch]
|
||||||
|
@ -196,6 +198,8 @@ class SampleGenerator:
|
||||||
|
|
||||||
batch_images = []
|
batch_images = []
|
||||||
for cfg in self.cfgs:
|
for cfg in self.cfgs:
|
||||||
|
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
|
||||||
|
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
|
||||||
images = pipe(prompt=prompts,
|
images = pipe(prompt=prompts,
|
||||||
negative_prompt=negative_prompts,
|
negative_prompt=negative_prompts,
|
||||||
num_inference_steps=self.num_inference_steps,
|
num_inference_steps=self.num_inference_steps,
|
||||||
|
@ -257,6 +261,7 @@ class SampleGenerator:
|
||||||
del tfimage
|
del tfimage
|
||||||
del batch_images
|
del batch_images
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):
|
def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):
|
||||||
|
|
Loading…
Reference in New Issue