diff --git a/train.py b/train.py index 88d5ef5..3e94a29 100644 --- a/train.py +++ b/train.py @@ -985,10 +985,13 @@ def main(args): if model_info["is_ema"]: current_unet, current_text_encoder = unet_ema, text_encoder_ema - extra_info = "ema_" + extra_info = "_ema" else: current_unet, current_text_encoder = unet, text_encoder + torch.cuda.empty_cache() + + if model_info["swap_required"]: with torch.no_grad(): unet_unloaded = unet.to(ema_device) diff --git a/utils/sample_generator.py b/utils/sample_generator.py index ca84b6b..929d3b8 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -269,15 +269,15 @@ class SampleGenerator: prompt = prompts[prompt_idx] clean_prompt = clean_filename(prompt) - result.save(f"{self.log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False) - with open(f"{self.log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f: + result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False) + with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f: f.write(str(batch[prompt_idx])) tfimage = transforms.ToTensor()(result) if batch[prompt_idx].wants_random_caption: - self.log_writer.add_image(tag=f"{extra_info}sample_{sample_index}", img_tensor=tfimage, global_step=global_step) + self.log_writer.add_image(tag=f"sample_{sample_index}{extra_info}", img_tensor=tfimage, global_step=global_step) else: - self.log_writer.add_image(tag=f"{extra_info}sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step) + self.log_writer.add_image(tag=f"sample_{sample_index}_{extra_info}{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step) sample_index += 1 del result