Get rid of on_event callback

This commit is contained in:
Joel Holdbrooks 2023-01-22 23:58:25 -08:00
parent 9c6df69e4e
commit 4e6c5f4d00
2 changed files with 84 additions and 80 deletions

View File

@ -25,10 +25,10 @@ class UndersizedImageEvent(Event):
self.target_size = target_size
class DataResolver:
def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, on_event: OptionalCallable=None):
def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0):
self.aspects = aspects
self.flip_p = flip_p
self.on_event = on_event or (lambda data: None)
self.events = []
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
"""
@ -48,17 +48,25 @@ class DataResolver:
if width * height < target_wh[0] * target_wh[1]:
event = UndersizedImageEvent(image_path, (width, height), target_wh)
self.on_event(event)
self.events.append(event)
return target_wh
def image_train_item(self, image_path: str, caption: ImageCaption) -> ImageTrainItem:
#try:
target_wh = self.compute_target_width_height(image_path)
return ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=image_path, flip_p=self.flip_p)
# except Exception as e:
# logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
# logging.error(f" *** exception: {e}")
def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
try:
target_wh = self.compute_target_width_height(image_path)
return ImageTrainItem(
image=None,
caption=caption,
target_wh=target_wh,
pathname=image_path,
flip_p=self.flip_p,
multiplier=multiplier
)
# TODO: This should only handle Image errors.
except Exception as e:
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
logging.error(f" *** exception: {e}")
class JSONResolver(DataResolver):
@ -131,10 +139,30 @@ class DirectoryResolver(DataResolver):
DirectoryResolver.unzip_all(data_root)
image_paths = list(DirectoryResolver.recurse_data_root(data_root))
items = []
multipliers = {}
skip_folders = []
for pathname in tqdm.tqdm(image_paths):
caption = ImageCaption.from_file(pathname)
item = self.image_train_item(pathname, caption)
current_dir = os.path.dirname(pathname)
if current_dir not in multipliers:
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
if os.path.exists(multiply_txt_path):
try:
with open(multiply_txt_path, 'r') as f:
val = float(f.read().strip())
multipliers[current_dir] = val
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
except Exception as e:
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
skip_folders.append(current_dir)
multipliers[current_dir] = 1.0
else:
skip_folders.append(current_dir)
multipliers[current_dir] = 1.0
caption = ImageCaption.resolve(pathname)
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
if item:
items.append(item)
@ -182,68 +210,48 @@ def strategy(data_root: str):
raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")
def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, on_event: OptionalCallable=None) -> list[ImageTrainItem]:
def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
"""
:param path: Directory or JSON file.
:param aspects: The list of aspect ratios to use
:param flip_p: The probability of flipping the image
"""
if os.path.isfile(path) and path.endswith('.json'):
resolver = JSONResolver(aspects, flip_p, on_event)
resolver = JSONResolver(aspects, flip_p)
if os.path.isdir(path):
resolver = DirectoryResolver(aspects, flip_p, on_event)
resolver = DirectoryResolver(aspects, flip_p)
if not resolver:
raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
return resolver.image_train_items(path)
items = resolver.image_train_items(path)
events = resolver.events
return items, events
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, on_event: OptionalCallable=None) -> list[ImageTrainItem]:
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
"""
Resolve the training data from the value.
:param value: The value to resolve, either a dict or a string.
:param aspects: The list of aspect ratios to use
:param flip_p: The probability of flipping the image
:param on_event: The callback to call when an event occurs (e.g. undersized image detected)
"""
if isinstance(value, str):
return resolve_root(value, aspects, flip_p, on_event)
return resolve_root(value, aspects, flip_p)
if isinstance(value, dict):
resolver = value.get('resolver', None)
match resolver:
case 'directory' | 'json':
path = value.get('path', None)
return resolve_root(path, aspects, flip_p, on_event)
return resolve_root(path, aspects, flip_p)
case 'multi':
items = []
resolved_items = []
resolved_events = []
for resolver in value.get('resolvers', []):
items += resolve(resolver, aspects, flip_p, on_event)
return items
items, events = resolve(resolver, aspects, flip_p)
resolved_items.extend(items)
resolved_events.extend(events)
return resolved_items, resolved_events
case _:
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
# example = {
# 'resolver': 'directory',
# 'data_root': 'data',
# }
# example = {
# 'resolver': 'json',
# 'data_root': 'data.json',
# }
# example = {
# 'resolver': 'multi',
# 'resolvers': [
# {
# 'resolver': 'directory',
# 'data_root': 'data',
# }, {
# 'resolver': 'json',
# 'data_root': 'data.json',
# },
# ]
# }
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")

View File

@ -11,7 +11,6 @@ import data.resolver as resolver
DATA_PATH = os.path.abspath('./test/data')
JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json')
ASPECTS = aspects.get_aspect_buckets(512)
FLIP_P = 0.0
IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg')
CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt')
@ -46,32 +45,23 @@ class TestResolve(unittest.TestCase):
with open(JSON_ROOT_PATH, 'w') as f:
json.dump(json_data, f, indent=4)
@classmethod
def tearDownClass(cls):
for file in glob.glob(os.path.join(DATA_PATH, 'test*')):
os.remove(file)
def setUp(self) -> None:
self.events = []
self.on_event = lambda event: self.events.append(event.name)
return super().setUp()
def tearDown(self) -> None:
self.events = []
self.on_event = None
return super().tearDown()
def test_directory_resolve_with_str(self):
image_train_items = resolver.resolve(DATA_PATH, ASPECTS, FLIP_P, self.on_event)
image_paths = [item.pathname for item in image_train_items]
image_captions = [item.caption for item in image_train_items]
items, events = resolver.resolve(DATA_PATH, ASPECTS)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
self.assertEqual(len(image_train_items), 3)
self.assertEqual(len(items), 3)
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
self.assertEqual(self.events, ['undersized_image'])
events = list(map(lambda e: e.name, events))
self.assertEqual(events, ['undersized_image'])
def test_directory_resolve_with_dict(self):
data_root_spec = {
@ -79,26 +69,30 @@ class TestResolve(unittest.TestCase):
'path': DATA_PATH,
}
image_train_items = resolver.resolve(data_root_spec, ASPECTS, FLIP_P, self.on_event)
image_paths = [item.pathname for item in image_train_items]
image_captions = [item.caption for item in image_train_items]
items, events = resolver.resolve(data_root_spec, ASPECTS)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
self.assertEqual(len(image_train_items), 3)
self.assertEqual(len(items), 3)
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
self.assertEqual(self.events, ['undersized_image'])
events = list(map(lambda e: e.name, events))
self.assertEqual(events, ['undersized_image'])
def test_json_resolve_with_str(self):
image_train_items = resolver.resolve(JSON_ROOT_PATH, ASPECTS, FLIP_P, self.on_event)
image_paths = [item.pathname for item in image_train_items]
image_captions = [item.caption for item in image_train_items]
items, events = resolver.resolve(JSON_ROOT_PATH, ASPECTS)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
self.assertEqual(len(image_train_items), 3)
self.assertEqual(len(items), 3)
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])
self.assertEqual(self.events, ['undersized_image'])
events = list(map(lambda e: e.name, events))
self.assertEqual(events, ['undersized_image'])
def test_json_resolve_with_dict(self):
data_root_spec = {
@ -106,12 +100,14 @@ class TestResolve(unittest.TestCase):
'path': JSON_ROOT_PATH,
}
image_train_items = resolver.resolve(data_root_spec, ASPECTS, FLIP_P, self.on_event)
image_paths = [item.pathname for item in image_train_items]
image_captions = [item.caption for item in image_train_items]
items, events = resolver.resolve(data_root_spec, ASPECTS)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
self.assertEqual(len(image_train_items), 3)
self.assertEqual(len(items), 3)
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])
self.assertEqual(self.events, ['undersized_image'])
events = list(map(lambda e: e.name, events))
self.assertEqual(events, ['undersized_image'])