diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 0e08746..37447ed 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -274,9 +274,13 @@ class ImageStore: return len(self.image_files) # iterator returns images as PIL images and their index in the store - def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]: + def entries_iterator(self) -> Generator[Tuple[Tuple[int, int], int], None, None]: for f in range(len(self)): - yield Image.open(self.image_files[f]), f + i = Image.open(self.image_files[f]) + width = i.width + height = i.height + del i + yield (width, height), f # get image by index def get_image(self, ref: Tuple[int, int, int]) -> Img: @@ -457,8 +461,8 @@ class AspectBucket: self.total_dropped = total_dropped - def _process_entry(self, entry: Image.Image, index: int) -> bool: - aspect = entry.width / entry.height + def _process_entry(self, entry: Tuple[int, int], index: int) -> bool: + aspect = entry[0] / entry[1] # width / height if aspect > self.max_ratio or (1 / aspect) > self.max_ratio: return False @@ -472,8 +476,6 @@ class AspectBucket: self.bucket_data[bucket].append(index) - del entry - return True class AspectBucketSampler(torch.utils.data.Sampler):