Use DirectoryResolver.recurse_data_root
This commit is contained in:
parent
aa0a2a1765
commit
08813eabb5
|
@ -23,6 +23,8 @@ from PIL import Image
|
|||
import random
|
||||
from data.image_train_item import ImageTrainItem, ImageCaption
|
||||
import data.aspects as aspects
|
||||
import data.resolver as resolver
|
||||
from data.resolver import DirectoryResolver
|
||||
from colorama import Fore, Style
|
||||
import zipfile
|
||||
import tqdm
|
||||
|
@ -55,7 +57,9 @@ class DataLoaderMultiAspect():
|
|||
|
||||
self.unzip_all(data_root)
|
||||
|
||||
self.__recurse_data_root(self=self, recurse_root=data_root)
|
||||
for image_path in DirectoryResolver.recurse_data_root(data_root):
|
||||
self.image_paths.append(image_path)
|
||||
|
||||
random.Random(seed).shuffle(self.image_paths)
|
||||
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
|
||||
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
|
||||
|
@ -349,23 +353,3 @@ class DataLoaderMultiAspect():
|
|||
prepared_train_data.pop(pos)
|
||||
|
||||
return picked_images
|
||||
|
||||
@staticmethod
|
||||
def __recurse_data_root(self, recurse_root):
|
||||
for f in os.listdir(recurse_root):
|
||||
current = os.path.join(recurse_root, f)
|
||||
|
||||
if os.path.isfile(current):
|
||||
ext = os.path.splitext(f)[1].lower()
|
||||
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
|
||||
self.image_paths.append(current)
|
||||
|
||||
sub_dirs = []
|
||||
|
||||
for d in os.listdir(recurse_root):
|
||||
current = os.path.join(recurse_root, d)
|
||||
if os.path.isdir(current):
|
||||
sub_dirs.append(current)
|
||||
|
||||
for dir in sub_dirs:
|
||||
self.__recurse_data_root(self=self, recurse_root=dir)
|
||||
|
|
Loading…
Reference in New Issue