1. Samples format change to make sure global step appear before "ema" indication.
This commit is contained in:
parent
39b3082bf4
commit
7259ce873b
5
train.py
5
train.py
|
@ -985,10 +985,13 @@ def main(args):
|
||||||
|
|
||||||
if model_info["is_ema"]:
|
if model_info["is_ema"]:
|
||||||
current_unet, current_text_encoder = unet_ema, text_encoder_ema
|
current_unet, current_text_encoder = unet_ema, text_encoder_ema
|
||||||
extra_info = "ema_"
|
extra_info = "_ema"
|
||||||
else:
|
else:
|
||||||
current_unet, current_text_encoder = unet, text_encoder
|
current_unet, current_text_encoder = unet, text_encoder
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
if model_info["swap_required"]:
|
if model_info["swap_required"]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
unet_unloaded = unet.to(ema_device)
|
unet_unloaded = unet.to(ema_device)
|
||||||
|
|
|
@ -269,15 +269,15 @@ class SampleGenerator:
|
||||||
prompt = prompts[prompt_idx]
|
prompt = prompts[prompt_idx]
|
||||||
clean_prompt = clean_filename(prompt)
|
clean_prompt = clean_filename(prompt)
|
||||||
|
|
||||||
result.save(f"{self.log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
||||||
with open(f"{self.log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
||||||
f.write(str(batch[prompt_idx]))
|
f.write(str(batch[prompt_idx]))
|
||||||
|
|
||||||
tfimage = transforms.ToTensor()(result)
|
tfimage = transforms.ToTensor()(result)
|
||||||
if batch[prompt_idx].wants_random_caption:
|
if batch[prompt_idx].wants_random_caption:
|
||||||
self.log_writer.add_image(tag=f"{extra_info}sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
|
self.log_writer.add_image(tag=f"sample_{sample_index}{extra_info}", img_tensor=tfimage, global_step=global_step)
|
||||||
else:
|
else:
|
||||||
self.log_writer.add_image(tag=f"{extra_info}sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
|
self.log_writer.add_image(tag=f"sample_{sample_index}_{extra_info}{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
|
||||||
sample_index += 1
|
sample_index += 1
|
||||||
|
|
||||||
del result
|
del result
|
||||||
|
|
Loading…
Reference in New Issue