improve plugins

Victor Hall 2023-07-04 17:29:22 -04:00
parent 6907c01b51
commit 42c417171d
3 changed files with 104 additions and 15 deletions

doc/PLUGINS.md (new file)

plugins/plugins.py
@@ -1,23 +1,91 @@
import argparse
import importlib
import logging
import time
import warnings

class BasePlugin:
    def on_epoch_start(self, **kwargs):
        pass
    def on_epoch_end(self, **kwargs):
        pass
    def on_training_start(self, **kwargs):
        pass
    def on_training_end(self, **kwargs):
        pass
    def on_step_start(self, **kwargs):
        pass
    def on_step_end(self, **kwargs):
        pass

class ExampleLoggingPlugin(BasePlugin):
    def on_epoch_start(self, **kwargs):
        logging.info(f"Epoch {kwargs['epoch']} starting")
    def on_epoch_end(self, **kwargs):
        logging.info(f"Epoch {kwargs['epoch']} finished")
-def load_plugin(plugin_name):
-    module = importlib.import_module(plugin_name)
+def load_plugin(plugin_path):
+    print(f" - Attempting to load plugin: {plugin_path}")
+    module_path = '.'.join(plugin_path.split('.')[:-1])
+    module = importlib.import_module(module_path)
+    plugin_name = plugin_path.split('.')[-1]
     plugin_class = getattr(module, plugin_name)
     if not issubclass(plugin_class, BasePlugin):
-        raise TypeError(f'{plugin_name} is not a valid plugin')
-    logging.info(f"Plugin {plugin_name} loaded")
+        raise TypeError(f'{plugin_path} is not a valid plugin')
+    logging.info(f" - Plugin {plugin_path} loaded to {plugin_class}")
     return plugin_class()
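
A quick illustration of the dotted-path convention load_plugin expects: everything before the last dot names the module, and the last component names the class. The file and class below are hypothetical, not part of the commit.

# Hypothetical user plugin, e.g. saved as plugins/my_plugin.py
import logging
from plugins.plugins import BasePlugin

class LossPrinterPlugin(BasePlugin):
    def on_epoch_end(self, **kwargs):
        # hooks receive keyword arguments such as epoch and global_step (see train.py below)
        logging.info(f"finished epoch {kwargs['epoch']} at global step {kwargs['global_step']}")

# loaded with: load_plugin("plugins.my_plugin.LossPrinterPlugin")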
class Timer:
    def __init__(self, warn_seconds, label='plugin'):
        self.warn_seconds = warn_seconds
        self.label = label

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        elapsed_time = time.time() - self.start
        if elapsed_time > self.warn_seconds:
            logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.warn_seconds} seconds')
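
Timer is a plain context manager: the wrapped block runs normally, and a warning is logged only when it takes longer than warn_seconds. A minimal usage sketch (the 0.1-second threshold is arbitrary):

import time

# the body exceeds warn_seconds, so a warning is emitted
with Timer(warn_seconds=0.1, label='slow_hook'):
    time.sleep(0.2)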
class PluginRunner:
    def __init__(self, plugins: list, epoch_warn_seconds=5, step_warn_seconds=0.5, training_warn_seconds=20):
        """
        plugins: list of plugins to run
        epoch_warn_seconds: warn if any epoch start/end call takes longer than this
        step_warn_seconds: warn if any step start/end call takes longer than this
        training_warn_seconds: warn if any training start/end call takes longer than this
        """
        self.plugins = plugins
        self.epoch_warn_seconds = epoch_warn_seconds
        self.step_warn_seconds = step_warn_seconds
        self.training_warn_seconds = training_warn_seconds

    def run_on_epoch_end(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.epoch_warn_seconds, label=f'{plugin.__class__.__name__}'):
                plugin.on_epoch_end(**kwargs)

    def run_on_epoch_start(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.epoch_warn_seconds, label=f'{plugin.__class__.__name__}'):
                plugin.on_epoch_start(**kwargs)

    def run_on_training_start(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.training_warn_seconds, label=f'{plugin.__class__.__name__}'):
                plugin.on_training_start(**kwargs)

    def run_on_training_end(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.training_warn_seconds, label=f'{plugin.__class__.__name__}'):
                plugin.on_training_end(**kwargs)

    def run_on_step_start(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'):
                plugin.on_step_start(**kwargs)

    def run_on_step_end(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'):
                plugin.on_step_end(**kwargs)
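
Putting the pieces together, a sketch of how a trainer could drive the runner; the values and event sequence here are illustrative, not the commit's own wiring:

# load each plugin by dotted path, then fan lifecycle events out through the runner
plugins = [load_plugin("plugins.plugins.ExampleLoggingPlugin")]
runner = PluginRunner(plugins=plugins)

runner.run_on_training_start(project_name="demo", log_folder="logs")
for epoch in range(2):
    runner.run_on_epoch_start(epoch=epoch, global_step=epoch * 100)
    runner.run_on_epoch_end(epoch=epoch, global_step=epoch * 100 + 99)
runner.run_on_training_end(project_name="demo", log_folder="logs")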

train.py
@@ -733,7 +733,6 @@ def main(args):
            del inference_pipe
            gc.collect()
            torch.cuda.empty_cache()

    def make_save_path(epoch, global_step, prepend=""):
        return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
@@ -753,13 +752,19 @@ def main(args):
        plugins = [load_plugin(name) for name in args.plugins]
    else:
        plugins = []

+    from plugins.plugins import PluginRunner
+    plugin_runner = PluginRunner(plugins=plugins)

    try:
        write_batch_schedule(args, log_folder, train_batch, epoch = 0)

        for epoch in range(args.max_epochs):
-            for plugin in plugins:
-                plugin.on_epoch_start(epoch, global_step)
+            plugin_runner.run_on_epoch_start(epoch=epoch,
+                                             global_step=global_step,
+                                             project_name=args.project_name,
+                                             log_folder=log_folder,
+                                             data_root=args.data_root)
            loss_epoch = []
            epoch_start_time = time.time()
@@ -776,6 +781,11 @@ def main(args):
            for step, batch in enumerate(train_dataloader):
                step_start_time = time.time()

+                plugin_runner.run_on_step_start(epoch=epoch,
+                                                global_step=global_step,
+                                                project_name=args.project_name,
+                                                log_folder=log_folder,
+                                                batch=batch)

                model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
@@ -840,6 +850,13 @@ def main(args):
                    save_path = make_save_path(epoch, global_step)
                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)

+                plugin_runner.run_on_step_end(epoch=epoch,
+                                              global_step=global_step,
+                                              project_name=args.project_name,
+                                              log_folder=log_folder,
+                                              data_root=args.data_root,
+                                              batch=batch)

                del batch
                global_step += 1
                # end of step
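
Because the step hooks receive the raw batch, a plugin can inspect training data as it flows through. A hypothetical example (the class name is invented; the "image" key matches the batch usage in the hunk above):

import logging
from plugins.plugins import BasePlugin

class BatchShapePlugin(BasePlugin):
    def on_step_end(self, **kwargs):
        # log the image tensor shape every 100 global steps
        if kwargs['global_step'] % 100 == 0:
            logging.info(f"step {kwargs['global_step']}: image batch shape {kwargs['batch']['image'].shape}")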
@@ -858,9 +875,13 @@ def main(args):
            loss_epoch = sum(loss_epoch) / len(loss_epoch)
            log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)

-            for plugin in plugins:
-                plugin.on_epoch_end(epoch, global_step)
-            gc.collect()
+            plugin_runner.run_on_epoch_end(epoch=epoch,
+                                           global_step=global_step,
+                                           project_name=args.project_name,
+                                           log_folder=log_folder,
+                                           data_root=args.data_root)
+            gc.collect()
            # end of epoch
        # end of training
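
The diff shows only that args.plugins is iterated as a list of dotted class paths (plugins = [load_plugin(name) for name in args.plugins]); how that list is actually supplied (CLI flag vs. config file) is outside this commit. A minimal stand-in for experimenting with the loading path:

# stand-in for the parsed args; the real trainer populates args.plugins elsewhere
from types import SimpleNamespace
from plugins.plugins import load_plugin

args = SimpleNamespace(plugins=["plugins.plugins.ExampleLoggingPlugin"])
plugins = [load_plugin(name) for name in args.plugins]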