Update sample_generator.py

Changed {sample_index} to {sample_index:03} in file names and logs for sample generator to allow for proper sorting in wandb and other places when there are more than 10 samples as currently wandb sorts as follows: 0, 1, 10, 11, 2, etc. With this change it would be 000, 001, 002, ... , 010, 011, etc.
This commit is contained in:
scottshireman 2024-12-03 08:47:28 -05:00 committed by GitHub
parent 684849ca6e
commit dc756ce22a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 4 deletions

View File

@ -272,15 +272,15 @@ class SampleGenerator:
prompt = prompts[prompt_idx] prompt = prompts[prompt_idx]
clean_prompt = clean_filename(prompt) clean_prompt = clean_filename(prompt)
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False) result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index:03}-{extra_info}{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f: with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index:03}-{extra_info}{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
f.write(str(batch[prompt_idx])) f.write(str(batch[prompt_idx]))
tfimage = transforms.ToTensor()(result) tfimage = transforms.ToTensor()(result)
if batch[prompt_idx].wants_random_caption: if batch[prompt_idx].wants_random_caption:
self.log_writer.add_image(tag=f"sample_{sample_index}{extra_info}", img_tensor=tfimage, global_step=global_step) self.log_writer.add_image(tag=f"sample_{sample_index:03}{extra_info}", img_tensor=tfimage, global_step=global_step)
else: else:
self.log_writer.add_image(tag=f"sample_{sample_index}_{extra_info}{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step) self.log_writer.add_image(tag=f"sample_{sample_index:03}_{extra_info}{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
sample_index += 1 sample_index += 1
del result del result