107 lines
4.0 KiB
Python
107 lines
4.0 KiB
Python
import argparse
|
|
import importlib
|
|
import logging
|
|
import time
|
|
import warnings
|
|
from PIL import Image
|
|
|
|
class BasePlugin:
|
|
def on_epoch_start(self, **kwargs):
|
|
pass
|
|
def on_epoch_end(self, **kwargs):
|
|
pass
|
|
def on_training_start(self, **kwargs):
|
|
pass
|
|
def on_training_end(self, **kwargs):
|
|
pass
|
|
def on_step_start(self, **kwargs):
|
|
pass
|
|
def on_step_end(self, **kwargs):
|
|
pass
|
|
def transform_caption(self, caption:str):
|
|
return caption
|
|
def transform_pil_image(self, img:Image):
|
|
return img
|
|
|
|
def load_plugin(plugin_path):
|
|
print(f" - Attempting to load plugin: {plugin_path}")
|
|
module_path = '.'.join(plugin_path.split('.')[:-1])
|
|
module = importlib.import_module(module_path)
|
|
plugin_name = plugin_path.split('.')[-1]
|
|
|
|
plugin_class = getattr(module, plugin_name)
|
|
|
|
if not issubclass(plugin_class, BasePlugin):
|
|
raise TypeError(f'{plugin_path} is not a valid plugin')
|
|
logging.info(f" - Plugin {plugin_path} loaded to {plugin_class}")
|
|
return plugin_class()
|
|
|
|
class Timer:
|
|
def __init__(self, warn_seconds, label='plugin'):
|
|
self.warn_seconds = warn_seconds
|
|
self.label = label
|
|
|
|
def __enter__(self):
|
|
self.start = time.time()
|
|
|
|
def __exit__(self, type, value, traceback):
|
|
elapsed_time = time.time() - self.start
|
|
if elapsed_time > self.warn_seconds:
|
|
logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.warn_seconds} seconds')
|
|
|
|
|
|
class PluginRunner:
|
|
def __init__(self, plugins:list=[], epoch_warn_seconds=5, step_warn_seconds=0.5, training_warn_seconds=20):
|
|
"""
|
|
plugins: list of plugins to run
|
|
epoch_warn_seconds: warn if any epoch start/end call takes longer than this
|
|
step_warn_seconds: warn if any step start/end call takes longer than this
|
|
training_warn_seconds: warn if any training start/end call take longer than this
|
|
"""
|
|
self.plugins = plugins
|
|
self.epoch_warn_seconds = epoch_warn_seconds
|
|
self.step_warn_seconds = step_warn_seconds
|
|
self.training_warn_seconds = training_warn_seconds
|
|
|
|
def run_on_epoch_end(self, **kwargs):
|
|
for plugin in self.plugins:
|
|
with Timer(warn_seconds=self.epoch_warn_seconds, label=f'{plugin.__class__.__name__}'):
|
|
plugin.on_epoch_end(**kwargs)
|
|
|
|
def run_on_epoch_start(self, **kwargs):
|
|
for plugin in self.plugins:
|
|
with Timer(warn_seconds=self.epoch_warn_seconds, label=f'{plugin.__class__.__name__}'):
|
|
plugin.on_epoch_start(**kwargs)
|
|
|
|
def run_on_training_start(self, **kwargs):
|
|
for plugin in self.plugins:
|
|
with Timer(warn_seconds=self.training_warn_seconds, label=f'{plugin.__class__.__name__}'):
|
|
plugin.on_training_start(**kwargs)
|
|
|
|
def run_on_training_end(self, **kwargs):
|
|
for plugin in self.plugins:
|
|
with Timer(warn_seconds=self.training_warn_seconds, label=f'{plugin.__class__.__name__}'):
|
|
plugin.on_training_end(**kwargs)
|
|
|
|
def run_on_step_start(self, **kwargs):
|
|
for plugin in self.plugins:
|
|
with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'):
|
|
plugin.on_step_start(**kwargs)
|
|
|
|
def run_on_step_end(self, **kwargs):
|
|
for plugin in self.plugins:
|
|
with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'):
|
|
plugin.on_step_end(**kwargs)
|
|
|
|
def run_transform_caption(self, caption):
|
|
with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_caption"):
|
|
for plugin in self.plugins:
|
|
caption = plugin.transform_caption(caption)
|
|
return caption
|
|
|
|
def run_transform_pil_image(self, img):
|
|
with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_pil_image"):
|
|
for plugin in self.plugins:
|
|
img = plugin.transform_pil_image(img)
|
|
return img
|