diff --git a/diffusers_trainer.py b/diffusers_trainer.py index 2171a8d..3a59c5a 100644 --- a/diffusers_trainer.py +++ b/diffusers_trainer.py @@ -20,6 +20,7 @@ import time import itertools import numpy as np import json +import re try: pynvml.nvmlInit() @@ -171,7 +172,7 @@ class ImageStore: # gets caption by removing the extension from the filename and replacing it with .txt def get_caption(self, index: int) -> str: - filename = self.image_files[index].split('.')[0] + '.txt' + filename = re.sub('\.[^/.]+$', '', self.image_files[index]) + '.txt' with open(filename, 'r') as f: return f.read()