initial db loading rewrites

This commit is contained in:
cafeai 2022-12-12 15:40:52 +09:00
parent 27d301c5b9
commit dcc28145df
1 changed files with 8 additions and 6 deletions

View File

@ -274,9 +274,13 @@ class ImageStore:
return len(self.image_files)
# iterator returns images as PIL images and their index in the store
def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]:
def entries_iterator(self) -> Generator[Tuple[Tuple[int, int], int], None, None]:
for f in range(len(self)):
yield Image.open(self.image_files[f]), f
i = Image.open(self.image_files[f])
width = i.width
height = i.height
del i
yield (width, height), f
# get image by index
def get_image(self, ref: Tuple[int, int, int]) -> Img:
@ -457,8 +461,8 @@ class AspectBucket:
self.total_dropped = total_dropped
def _process_entry(self, entry: Image.Image, index: int) -> bool:
aspect = entry.width / entry.height
def _process_entry(self, entry: Tuple[int, int], index: int) -> bool:
aspect = entry[0] / entry[1] # width / height
if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
return False
@ -472,8 +476,6 @@ class AspectBucket:
self.bucket_data[bucket].append(index)
del entry
return True
class AspectBucketSampler(torch.utils.data.Sampler):