Use DirectoryResolver.recurse_data_root

This commit is contained in:
Joel Holdbrooks 2023-01-22 23:13:05 -08:00
parent aa0a2a1765
commit 08813eabb5
1 changed files with 5 additions and 21 deletions

View File

@ -23,6 +23,8 @@ from PIL import Image
import random
from data.image_train_item import ImageTrainItem, ImageCaption
import data.aspects as aspects
import data.resolver as resolver
from data.resolver import DirectoryResolver
from colorama import Fore, Style
import zipfile
import tqdm
@ -55,7 +57,9 @@ class DataLoaderMultiAspect():
self.unzip_all(data_root)
self.__recurse_data_root(self=self, recurse_root=data_root)
for image_path in DirectoryResolver.recurse_data_root(data_root):
self.image_paths.append(image_path)
random.Random(seed).shuffle(self.image_paths)
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
@ -349,23 +353,3 @@ class DataLoaderMultiAspect():
prepared_train_data.pop(pos)
return picked_images
@staticmethod
def __recurse_data_root(self, recurse_root):
for f in os.listdir(recurse_root):
current = os.path.join(recurse_root, f)
if os.path.isfile(current):
ext = os.path.splitext(f)[1].lower()
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
self.image_paths.append(current)
sub_dirs = []
for d in os.listdir(recurse_root):
current = os.path.join(recurse_root, d)
if os.path.isdir(current):
sub_dirs.append(current)
for dir in sub_dirs:
self.__recurse_data_root(self=self, recurse_root=dir)