Get rid of on_event callback
This commit is contained in:
parent
9c6df69e4e
commit
4e6c5f4d00
104
data/resolver.py
104
data/resolver.py
|
@ -25,10 +25,10 @@ class UndersizedImageEvent(Event):
|
|||
self.target_size = target_size
|
||||
|
||||
class DataResolver:
|
||||
def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, on_event: OptionalCallable=None):
|
||||
def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0):
|
||||
self.aspects = aspects
|
||||
self.flip_p = flip_p
|
||||
self.on_event = on_event or (lambda data: None)
|
||||
self.events = []
|
||||
|
||||
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
|
||||
"""
|
||||
|
@ -48,17 +48,25 @@ class DataResolver:
|
|||
|
||||
if width * height < target_wh[0] * target_wh[1]:
|
||||
event = UndersizedImageEvent(image_path, (width, height), target_wh)
|
||||
self.on_event(event)
|
||||
self.events.append(event)
|
||||
|
||||
return target_wh
|
||||
|
||||
def image_train_item(self, image_path: str, caption: ImageCaption) -> ImageTrainItem:
|
||||
#try:
|
||||
target_wh = self.compute_target_width_height(image_path)
|
||||
return ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=image_path, flip_p=self.flip_p)
|
||||
# except Exception as e:
|
||||
# logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
|
||||
# logging.error(f" *** exception: {e}")
|
||||
def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
|
||||
try:
|
||||
target_wh = self.compute_target_width_height(image_path)
|
||||
return ImageTrainItem(
|
||||
image=None,
|
||||
caption=caption,
|
||||
target_wh=target_wh,
|
||||
pathname=image_path,
|
||||
flip_p=self.flip_p,
|
||||
multiplier=multiplier
|
||||
)
|
||||
# TODO: This should only handle Image errors.
|
||||
except Exception as e:
|
||||
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
|
||||
logging.error(f" *** exception: {e}")
|
||||
|
||||
|
||||
class JSONResolver(DataResolver):
|
||||
|
@ -131,10 +139,30 @@ class DirectoryResolver(DataResolver):
|
|||
DirectoryResolver.unzip_all(data_root)
|
||||
image_paths = list(DirectoryResolver.recurse_data_root(data_root))
|
||||
items = []
|
||||
multipliers = {}
|
||||
skip_folders = []
|
||||
|
||||
for pathname in tqdm.tqdm(image_paths):
|
||||
caption = ImageCaption.from_file(pathname)
|
||||
item = self.image_train_item(pathname, caption)
|
||||
current_dir = os.path.dirname(pathname)
|
||||
|
||||
if current_dir not in multipliers:
|
||||
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
|
||||
if os.path.exists(multiply_txt_path):
|
||||
try:
|
||||
with open(multiply_txt_path, 'r') as f:
|
||||
val = float(f.read().strip())
|
||||
multipliers[current_dir] = val
|
||||
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
|
||||
except Exception as e:
|
||||
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
|
||||
skip_folders.append(current_dir)
|
||||
multipliers[current_dir] = 1.0
|
||||
else:
|
||||
skip_folders.append(current_dir)
|
||||
multipliers[current_dir] = 1.0
|
||||
|
||||
caption = ImageCaption.resolve(pathname)
|
||||
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
|
||||
|
||||
if item:
|
||||
items.append(item)
|
||||
|
@ -182,68 +210,48 @@ def strategy(data_root: str):
|
|||
raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")
|
||||
|
||||
|
||||
def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, on_event: OptionalCallable=None) -> list[ImageTrainItem]:
|
||||
def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
|
||||
"""
|
||||
:param data_root: Directory or JSON file.
|
||||
:param aspects: The list of aspect ratios to use
|
||||
:param flip_p: The probability of flipping the image
|
||||
"""
|
||||
if os.path.isfile(path) and path.endswith('.json'):
|
||||
resolver = JSONResolver(aspects, flip_p, on_event)
|
||||
resolver = JSONResolver(aspects, flip_p)
|
||||
|
||||
if os.path.isdir(path):
|
||||
resolver = DirectoryResolver(aspects, flip_p, on_event)
|
||||
resolver = DirectoryResolver(aspects, flip_p)
|
||||
|
||||
if not resolver:
|
||||
raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
|
||||
|
||||
return resolver.image_train_items(path)
|
||||
items = resolver.image_train_items(path)
|
||||
events = resolver.events
|
||||
return items, events
|
||||
|
||||
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, on_event: OptionalCallable=None) -> list[ImageTrainItem]:
|
||||
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
|
||||
"""
|
||||
Resolve the training data from the value.
|
||||
:param value: The value to resolve, either a dict or a string.
|
||||
:param aspects: The list of aspect ratios to use
|
||||
:param flip_p: The probability of flipping the image
|
||||
:param on_event: The callback to call when an event occurs (e.g. undersized image detected)
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
return resolve_root(value, aspects, flip_p, on_event)
|
||||
return resolve_root(value, aspects, flip_p)
|
||||
|
||||
if isinstance(value, dict):
|
||||
resolver = value.get('resolver', None)
|
||||
match resolver:
|
||||
case 'directory' | 'json':
|
||||
path = value.get('path', None)
|
||||
return resolve_root(path, aspects, flip_p, on_event)
|
||||
return resolve_root(path, aspects, flip_p)
|
||||
case 'multi':
|
||||
items = []
|
||||
resolved_items = []
|
||||
resolved_events = []
|
||||
for resolver in value.get('resolvers', []):
|
||||
items += resolve(resolver, aspects, flip_p, on_event)
|
||||
return items
|
||||
items, events = resolve(resolver, aspects, flip_p)
|
||||
resolved_items.extend(items)
|
||||
resolved_events.extend(events)
|
||||
return resolved_items, resolved_events
|
||||
case _:
|
||||
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
|
||||
|
||||
|
||||
# example = {
|
||||
# 'resolver': 'directory',
|
||||
# 'data_root': 'data',
|
||||
# }
|
||||
|
||||
# example = {
|
||||
# 'resolver': 'json',
|
||||
# 'data_root': 'data.json',
|
||||
# }
|
||||
|
||||
# example = {
|
||||
# 'resolver': 'multi',
|
||||
# 'resolvers': [
|
||||
# {
|
||||
# 'resolver': 'directory',
|
||||
# 'data_root': 'data',
|
||||
# }, {
|
||||
# 'resolver': 'json',
|
||||
# 'data_root': 'data.json',
|
||||
# },
|
||||
# ]
|
||||
# }
|
||||
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
|
|
@ -11,7 +11,6 @@ import data.resolver as resolver
|
|||
DATA_PATH = os.path.abspath('./test/data')
|
||||
JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json')
|
||||
ASPECTS = aspects.get_aspect_buckets(512)
|
||||
FLIP_P = 0.0
|
||||
|
||||
IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg')
|
||||
CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt')
|
||||
|
@ -46,32 +45,23 @@ class TestResolve(unittest.TestCase):
|
|||
with open(JSON_ROOT_PATH, 'w') as f:
|
||||
json.dump(json_data, f, indent=4)
|
||||
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
for file in glob.glob(os.path.join(DATA_PATH, 'test*')):
|
||||
os.remove(file)
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.events = []
|
||||
self.on_event = lambda event: self.events.append(event.name)
|
||||
return super().setUp()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.events = []
|
||||
self.on_event = None
|
||||
return super().tearDown()
|
||||
|
||||
def test_directory_resolve_with_str(self):
|
||||
image_train_items = resolver.resolve(DATA_PATH, ASPECTS, FLIP_P, self.on_event)
|
||||
image_paths = [item.pathname for item in image_train_items]
|
||||
image_captions = [item.caption for item in image_train_items]
|
||||
items, events = resolver.resolve(DATA_PATH, ASPECTS)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
||||
self.assertEqual(len(image_train_items), 3)
|
||||
self.assertEqual(len(items), 3)
|
||||
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
|
||||
self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
|
||||
self.assertEqual(self.events, ['undersized_image'])
|
||||
|
||||
events = list(map(lambda e: e.name, events))
|
||||
self.assertEqual(events, ['undersized_image'])
|
||||
|
||||
def test_directory_resolve_with_dict(self):
|
||||
data_root_spec = {
|
||||
|
@ -79,26 +69,30 @@ class TestResolve(unittest.TestCase):
|
|||
'path': DATA_PATH,
|
||||
}
|
||||
|
||||
image_train_items = resolver.resolve(data_root_spec, ASPECTS, FLIP_P, self.on_event)
|
||||
image_paths = [item.pathname for item in image_train_items]
|
||||
image_captions = [item.caption for item in image_train_items]
|
||||
items, events = resolver.resolve(data_root_spec, ASPECTS)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
||||
self.assertEqual(len(image_train_items), 3)
|
||||
self.assertEqual(len(items), 3)
|
||||
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
|
||||
self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
|
||||
self.assertEqual(self.events, ['undersized_image'])
|
||||
|
||||
events = list(map(lambda e: e.name, events))
|
||||
self.assertEqual(events, ['undersized_image'])
|
||||
|
||||
def test_json_resolve_with_str(self):
|
||||
image_train_items = resolver.resolve(JSON_ROOT_PATH, ASPECTS, FLIP_P, self.on_event)
|
||||
image_paths = [item.pathname for item in image_train_items]
|
||||
image_captions = [item.caption for item in image_train_items]
|
||||
items, events = resolver.resolve(JSON_ROOT_PATH, ASPECTS)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
||||
self.assertEqual(len(image_train_items), 3)
|
||||
self.assertEqual(len(items), 3)
|
||||
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
|
||||
self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])
|
||||
self.assertEqual(self.events, ['undersized_image'])
|
||||
|
||||
events = list(map(lambda e: e.name, events))
|
||||
self.assertEqual(events, ['undersized_image'])
|
||||
|
||||
def test_json_resolve_with_dict(self):
|
||||
data_root_spec = {
|
||||
|
@ -106,12 +100,14 @@ class TestResolve(unittest.TestCase):
|
|||
'path': JSON_ROOT_PATH,
|
||||
}
|
||||
|
||||
image_train_items = resolver.resolve(data_root_spec, ASPECTS, FLIP_P, self.on_event)
|
||||
image_paths = [item.pathname for item in image_train_items]
|
||||
image_captions = [item.caption for item in image_train_items]
|
||||
items, events = resolver.resolve(data_root_spec, ASPECTS)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
||||
self.assertEqual(len(image_train_items), 3)
|
||||
self.assertEqual(len(items), 3)
|
||||
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
|
||||
self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])
|
||||
self.assertEqual(self.events, ['undersized_image'])
|
||||
|
||||
events = list(map(lambda e: e.name, events))
|
||||
self.assertEqual(events, ['undersized_image'])
|
Loading…
Reference in New Issue