Merge pull request #29 from noprompt/refactor-follow-up

Refactor follow up
This commit is contained in:
Victor Hall 2023-01-24 09:39:45 -08:00 committed by GitHub
commit 09347779be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 8 deletions

View File

@ -166,7 +166,7 @@ class DataLoaderMultiAspect():
print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")
self.prepared_train_data = items
self.prepared_train_data = [item for item in items if item.error is None]
random.Random(self.seed).shuffle(self.prepared_train_data)
self.__report_errors(items)

View File

@ -351,6 +351,7 @@ class ImageTrainItem:
return self
def __compute_target_width_height(self):
self.target_wh = None
try:
with Image.open(self.pathname) as image:
width, height = image.size

View File

@ -191,17 +191,13 @@ def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555)
:param flip_p: The probability of flipping the image
"""
if os.path.isfile(path) and path.endswith('.json'):
resolver = JSONResolver(aspects, flip_p, seed)
return JSONResolver(aspects, flip_p, seed).image_train_items(path)
if os.path.isdir(path):
resolver = DirectoryResolver(aspects, flip_p, seed)
return DirectoryResolver(aspects, flip_p, seed).image_train_items(path)
if not resolver:
raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
items = resolver.image_train_items(path)
return items
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]:
"""
Resolve the training data from the value.