Move target_wh calculation to ImageTrainItem

This commit is contained in:
Joel Holdbrooks 2023-01-23 12:00:42 -08:00
parent 1dfda8d6d4
commit 1a0b7994f4
3 changed files with 45 additions and 58 deletions

View File

@ -161,27 +161,34 @@ class DataLoaderMultiAspect():
logging.info(" Preloading images...")
items, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
image_paths = set(map(lambda item: item.pathname, items))
print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")
self.prepared_train_data = items
random.Random(self.seed).shuffle(self.prepared_train_data)
self.__report_undersized_images(events)
self.__report_errors(items)
def __report_undersized_images(self, events: list[resolver.Event]):
events = [event for event in events if isinstance(event, resolver.UndersizedImageEvent)]
def __report_errors(self, items: list[ImageTrainItem]):
for item in items:
if item.error is not None:
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
logging.error(f" *** exception: {item.error}")
undersized_items = [item for item in items if item.is_undersized]
if len(events) > 0:
if len(undersized_items) > 0:
underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
with open(underized_log_path, "w") as undersized_images_file:
undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
for event in events:
for event in undersized_items:
message = f" *** {event.image_path} with size: {event.image_size} is smaller than target size: {event.target_size}, consider using larger images"
undersized_images_file.write(message)
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
"""

View File

@ -22,6 +22,7 @@ import typing
import yaml
import PIL
import PIL.Image as Image
import numpy as np
from torchvision import transforms
@ -256,9 +257,9 @@ class ImageTrainItem:
flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images.
"""
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0, multiplier: float=1.0):
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0):
self.caption = caption
self.target_wh = target_wh
self.aspects = aspects
self.pathname = pathname
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.cropped_img = None
@ -269,6 +270,9 @@ class ImageTrainItem:
self.image = []
else:
self.image = image
self.error = None
self.__compute_target_width_height()
def hydrate(self, crop=False, save=False, crop_jitter=20):
"""
@ -345,6 +349,18 @@ class ImageTrainItem:
# print(self.image.shape)
return self
def __compute_target_width_height(self):
try:
with Image.open(self.image_path) as image:
width, height = image.size
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
self.is_undersized = width * height < target_wh[0] * target_wh[1]
self.target_wh = target_wh
except Exception as e:
self.error = e
@staticmethod
def __autocrop(image: PIL.Image, q=.404):

View File

@ -11,22 +11,10 @@ from colorama import Fore, Style
from data.image_train_item import ImageCaption, ImageTrainItem
class Event:
def __init__(self, name: str):
self.name = name
class UndersizedImageEvent(Event):
def __init__(self, image_path: str, image_size: typing.Tuple[int, int], target_size: typing.Tuple[int, int]):
super().__init__('undersized_image')
self.image_path = image_path
self.image_size = image_size
self.target_size = target_size
class DataResolver:
def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555):
self.aspects = aspects
self.flip_p = flip_p
self.events = []
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
"""
@ -37,35 +25,15 @@ class DataResolver:
"""
raise NotImplementedError()
def compute_target_width_height(self, image_path: str) -> typing.Optional[typing.Tuple[int, int]]:
# Compute the target width and height for the image based on the aspect ratio.
with Image.open(image_path) as image:
width, height = image.size
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
if width * height < target_wh[0] * target_wh[1]:
event = UndersizedImageEvent(image_path, (width, height), target_wh)
self.events.append(event)
return target_wh
def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
try:
target_wh = self.compute_target_width_height(image_path)
return ImageTrainItem(
image=None,
caption=caption,
target_wh=target_wh,
pathname=image_path,
flip_p=self.flip_p,
multiplier=multiplier
)
# TODO: This should only handle Image errors.
except Exception as e:
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
logging.error(f" *** exception: {e}")
return ImageTrainItem(
image=None,
caption=caption,
aspects=self.aspects,
pathname=image_path,
flip_p=self.flip_p,
multiplier=multiplier
)
class JSONResolver(DataResolver):
def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
@ -219,7 +187,7 @@ def strategy(data_root: str):
raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")
def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> list[ImageTrainItem]:
"""
:param data_root: Directory or JSON file.
:param aspects: The list of aspect ratios to use
@ -235,10 +203,9 @@ def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed) ->
raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
items = resolver.image_train_items(path)
events = resolver.events
return items, events
return items
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]:
"""
Resolve the training data from the value.
:param value: The value to resolve, either a dict or a string.
@ -255,12 +222,9 @@ def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=
path = value.get('path', None)
return resolve_root(path, aspects, flip_p, seed)
case 'multi':
resolved_items = []
resolved_events = []
items = []
for resolver in value.get('resolvers', []):
items, events = resolve(resolver, aspects, flip_p, seed)
resolved_items.extend(items)
resolved_events.extend(events)
return resolved_items, resolved_events
items.extend(resolve(resolver, aspects, flip_p, seed))
return items
case _:
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")