1. Samples format change to make sure global step appear before "ema" indication.

This commit is contained in:
alexds9 2023-09-11 00:13:26 +03:00
parent 39b3082bf4
commit 7259ce873b
2 changed files with 8 additions and 5 deletions

View File

@ -985,10 +985,13 @@ def main(args):
if model_info["is_ema"]: if model_info["is_ema"]:
current_unet, current_text_encoder = unet_ema, text_encoder_ema current_unet, current_text_encoder = unet_ema, text_encoder_ema
extra_info = "ema_" extra_info = "_ema"
else: else:
current_unet, current_text_encoder = unet, text_encoder current_unet, current_text_encoder = unet, text_encoder
torch.cuda.empty_cache()
if model_info["swap_required"]: if model_info["swap_required"]:
with torch.no_grad(): with torch.no_grad():
unet_unloaded = unet.to(ema_device) unet_unloaded = unet.to(ema_device)

View File

@ -269,15 +269,15 @@ class SampleGenerator:
prompt = prompts[prompt_idx] prompt = prompts[prompt_idx]
clean_prompt = clean_filename(prompt) clean_prompt = clean_filename(prompt)
result.save(f"{self.log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False) result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f: with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
f.write(str(batch[prompt_idx])) f.write(str(batch[prompt_idx]))
tfimage = transforms.ToTensor()(result) tfimage = transforms.ToTensor()(result)
if batch[prompt_idx].wants_random_caption: if batch[prompt_idx].wants_random_caption:
self.log_writer.add_image(tag=f"{extra_info}sample_{sample_index}", img_tensor=tfimage, global_step=global_step) self.log_writer.add_image(tag=f"sample_{sample_index}{extra_info}", img_tensor=tfimage, global_step=global_step)
else: else:
self.log_writer.add_image(tag=f"{extra_info}sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step) self.log_writer.add_image(tag=f"sample_{sample_index}_{extra_info}{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
sample_index += 1 sample_index += 1
del result del result