Move target_wh calculation to ImageTrainItem
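Summary, as read from the diff below: the target width/height computation moves out of DataResolver.compute_target_width_height and into a new private ImageTrainItem.__compute_target_width_height; the Event and UndersizedImageEvent classes are removed in favor of per-item error and is_undersized fields; and resolve()/resolve_root() now return a plain list[ImageTrainItem] instead of an (items, events) tuple, with DataLoaderMultiAspect reporting problems via a new __report_errors method.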
This commit is contained in:
parent 1dfda8d6d4
commit 1a0b7994f4
@@ -161,27 +161,34 @@ class DataLoaderMultiAspect():
         logging.info(" Preloading images...")

-        items, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
+        items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
         image_paths = set(map(lambda item: item.pathname, items))

         print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")

         self.prepared_train_data = items
         random.Random(self.seed).shuffle(self.prepared_train_data)
-        self.__report_undersized_images(events)
+        self.__report_errors(items)

-    def __report_undersized_images(self, events: list[resolver.Event]):
-        events = [event for event in events if isinstance(event, resolver.UndersizedImageEvent)]
-        if len(events) > 0:
+    def __report_errors(self, items: list[ImageTrainItem]):
+        for item in items:
+            if item.error is not None:
+                logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
+                logging.error(f" *** exception: {item.error}")
+
+        undersized_items = [item for item in items if item.is_undersized]
+
+        if len(undersized_items) > 0:
             underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
             logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
             logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
             with open(underized_log_path, "w") as undersized_images_file:
                 undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
-                for event in events:
+                for event in undersized_items:
                     message = f" *** {event.image_path} with size: {event.image_size} is smaller than target size: {event.target_size}, consider using larger images"
                     undersized_images_file.write(message)

     def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
         """
@@ -22,6 +22,7 @@ import typing
 import yaml

 import PIL
+import PIL.Image as Image
 import numpy as np
 from torchvision import transforms

@@ -256,9 +257,9 @@ class ImageTrainItem:
     flip_p: probability of flipping image (0.0 to 1.0)
     rating: the relative rating of the images. The rating is measured in comparison to the other images.
     """
-    def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0, multiplier: float=1.0):
+    def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0):
         self.caption = caption
-        self.target_wh = target_wh
+        self.aspects = aspects
         self.pathname = pathname
         self.flip = transforms.RandomHorizontalFlip(p=flip_p)
         self.cropped_img = None
@@ -269,6 +270,9 @@ class ImageTrainItem:
             self.image = []
         else:
             self.image = image

+        self.error = None
+        self.__compute_target_width_height()
+
     def hydrate(self, crop=False, save=False, crop_jitter=20):
         """
@@ -345,6 +349,18 @@ class ImageTrainItem:
         # print(self.image.shape)

         return self

+    def __compute_target_width_height(self):
+        try:
+            with Image.open(self.image_path) as image:
+                width, height = image.size
+                image_aspect = width / height
+                target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
+
+                self.is_undersized = width * height < target_wh[0] * target_wh[1]
+                self.target_wh = target_wh
+        except Exception as e:
+            self.error = e
+
     @staticmethod
     def __autocrop(image: PIL.Image, q=.404):
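For context, a minimal standalone sketch of the bucket-selection logic that __compute_target_width_height introduces above. The bucket list is illustrative only (hypothetical values, not taken from this repository); the selection and undersized tests mirror the committed code.

    # Hypothetical bucket list; in the commit the real values come from self.aspects.
    ASPECTS = [(512, 512), (576, 448), (448, 576), (640, 384), (384, 640)]

    def target_wh_for(width: int, height: int) -> tuple[int, int]:
        image_aspect = width / height
        # Pick the bucket whose w/h ratio is nearest to the image's ratio.
        return min(ASPECTS, key=lambda wh: abs(wh[0] / wh[1] - image_aspect))

    bucket = target_wh_for(1200, 800)                    # 3:2 image -> (640, 384)
    is_undersized = 1200 * 800 < bucket[0] * bucket[1]   # area test as in the commit -> False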
@@ -11,22 +11,10 @@ from colorama import Fore, Style

 from data.image_train_item import ImageCaption, ImageTrainItem

-class Event:
-    def __init__(self, name: str):
-        self.name = name
-
-class UndersizedImageEvent(Event):
-    def __init__(self, image_path: str, image_size: typing.Tuple[int, int], target_size: typing.Tuple[int, int]):
-        super().__init__('undersized_image')
-        self.image_path = image_path
-        self.image_size = image_size
-        self.target_size = target_size
-
 class DataResolver:
     def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555):
         self.aspects = aspects
         self.flip_p = flip_p
-        self.events = []

     def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
         """
@@ -37,35 +25,15 @@ class DataResolver:
         """
         raise NotImplementedError()

-    def compute_target_width_height(self, image_path: str) -> typing.Optional[typing.Tuple[int, int]]:
-        # Compute the target width and height for the image based on the aspect ratio.
-        with Image.open(image_path) as image:
-            width, height = image.size
-            image_aspect = width / height
-            target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
-
-            if width * height < target_wh[0] * target_wh[1]:
-                event = UndersizedImageEvent(image_path, (width, height), target_wh)
-                self.events.append(event)
-
-        return target_wh
-
     def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
-        try:
-            target_wh = self.compute_target_width_height(image_path)
-            return ImageTrainItem(
-                image=None,
-                caption=caption,
-                target_wh=target_wh,
-                pathname=image_path,
-                flip_p=self.flip_p,
-                multiplier=multiplier
-            )
-        # TODO: This should only handle Image errors.
-        except Exception as e:
-            logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
-            logging.error(f" *** exception: {e}")
-
+        return ImageTrainItem(
+            image=None,
+            caption=caption,
+            aspects=self.aspects,
+            pathname=image_path,
+            flip_p=self.flip_p,
+            multiplier=multiplier
+        )

 class JSONResolver(DataResolver):
     def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
@@ -219,7 +187,7 @@ def strategy(data_root: str):
         raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")


-def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
+def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> list[ImageTrainItem]:
     """
     :param data_root: Directory or JSON file.
     :param aspects: The list of aspect ratios to use
@@ -235,10 +203,9 @@ def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed) ->
         raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")

     items = resolver.image_train_items(path)
-    events = resolver.events
-    return items, events
+    return items

-def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
+def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]:
     """
     Resolve the training data from the value.
     :param value: The value to resolve, either a dict or a string.
@@ -255,12 +222,9 @@ def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=
             path = value.get('path', None)
             return resolve_root(path, aspects, flip_p, seed)
         case 'multi':
-            resolved_items = []
-            resolved_events = []
+            items = []
             for resolver in value.get('resolvers', []):
-                items, events = resolve(resolver, aspects, flip_p, seed)
-                resolved_items.extend(items)
-                resolved_events.extend(events)
-            return resolved_items, resolved_events
+                items.extend(resolve(resolver, aspects, flip_p, seed))
+            return items
         case _:
             raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
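Taken together, the resolver API after this commit returns only items, and per-image problems travel on the items themselves. A caller-side sketch (hypothetical variable names, assuming a directory data root):

    items = resolve(data_root, aspects, flip_p=0.0, seed=555)
    failed = [i for i in items if i.error is not None]
    # is_undersized is only set when the image metadata could be read, hence getattr.
    small = [i for i in items if getattr(i, "is_undersized", False)]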