improve plugins
This commit is contained in:
parent 6907c01b51
commit 42c417171d

plugins/plugins.py
@@ -1,23 +1,91 @@
import argparse
import importlib
import logging
import time
import warnings


class BasePlugin:
    def on_epoch_start(self, **kwargs):
        pass

    def on_epoch_end(self, **kwargs):
        pass

    def on_training_start(self, **kwargs):
        pass

    def on_training_end(self, **kwargs):
        pass

    def on_step_start(self, **kwargs):
        pass

    def on_step_end(self, **kwargs):
        pass


class ExampleLoggingPlugin(BasePlugin):
    def on_epoch_start(self, **kwargs):
        logging.info(f"Epoch {kwargs['epoch']} starting")

    def on_epoch_end(self, **kwargs):
        logging.info(f"Epoch {kwargs['epoch']} finished")

-def load_plugin(plugin_name):
-    module = importlib.import_module(plugin_name)
+def load_plugin(plugin_path):
+    print(f" - Attempting to load plugin: {plugin_path}")
+    module_path = '.'.join(plugin_path.split('.')[:-1])
+    module = importlib.import_module(module_path)
+    plugin_name = plugin_path.split('.')[-1]

    plugin_class = getattr(module, plugin_name)

    if not issubclass(plugin_class, BasePlugin):
-        raise TypeError(f'{plugin_name} is not a valid plugin')
-    logging.info(f"Plugin {plugin_name} loaded")
+        raise TypeError(f'{plugin_path} is not a valid plugin')
+    logging.info(f" - Plugin {plugin_path} loaded to {plugin_class}")
    return plugin_class()

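For illustration, a minimal sketch of a user-defined plugin and how load_plugin would pick it up; the module name my_plugins and the class name MyPlugin are hypothetical. Note the dotted path must contain at least one dot, since everything before the last dot is treated as the module to import.

# my_plugins.py (hypothetical module)
import logging

from plugins.plugins import BasePlugin

class MyPlugin(BasePlugin):
    def on_epoch_start(self, **kwargs):
        # kwargs mirror what train.py passes: epoch, global_step,
        # project_name, log_folder, data_root
        logging.info(f"epoch {kwargs['epoch']} starting, logs in {kwargs['log_folder']}")

# elsewhere:
# plugin = load_plugin("my_plugins.MyPlugin")  # imports my_plugins, returns a MyPlugin instance
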
class Timer:
    def __init__(self, warn_seconds, label='plugin'):
        self.warn_seconds = warn_seconds
        self.label = label

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, type, value, traceback):
        elapsed_time = time.time() - self.start
        if elapsed_time > self.warn_seconds:
            # note: this referenced self.limit, which __init__ never sets;
            # self.warn_seconds is the attribute that holds the threshold
            logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.warn_seconds} seconds')

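A minimal usage sketch of the Timer context manager (the threshold and sleep values are illustrative; time is already imported at the top of this module):

with Timer(warn_seconds=0.5, label='slow_work'):
    time.sleep(1)  # ~1s > 0.5s, so __exit__ logs a warning
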
class PluginRunner:
    def __init__(self, plugins: list, epoch_warn_seconds=5, step_warn_seconds=0.5, training_warn_seconds=20):
        """
        plugins: list of plugins to run
        epoch_warn_seconds: warn if any epoch start/end call takes longer than this
        step_warn_seconds: warn if any step start/end call takes longer than this
        training_warn_seconds: warn if any training start/end call takes longer than this
        """
        self.plugins = plugins
        self.epoch_warn_seconds = epoch_warn_seconds
        self.step_warn_seconds = step_warn_seconds
        self.training_warn_seconds = training_warn_seconds

    def run_on_epoch_end(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.epoch_warn_seconds, label=plugin.__class__.__name__):
                plugin.on_epoch_end(**kwargs)

    def run_on_epoch_start(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.epoch_warn_seconds, label=plugin.__class__.__name__):
                plugin.on_epoch_start(**kwargs)

    def run_on_training_start(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.training_warn_seconds, label=plugin.__class__.__name__):
                plugin.on_training_start(**kwargs)

    def run_on_training_end(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.training_warn_seconds, label=plugin.__class__.__name__):
                plugin.on_training_end(**kwargs)

    def run_on_step_start(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.step_warn_seconds, label=plugin.__class__.__name__):
                plugin.on_step_start(**kwargs)

    def run_on_step_end(self, **kwargs):
        for plugin in self.plugins:
            with Timer(warn_seconds=self.step_warn_seconds, label=plugin.__class__.__name__):
                plugin.on_step_end(**kwargs)

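A short sketch of PluginRunner driving the bundled ExampleLoggingPlugin; the kwargs mirror those passed by train.py below, and the values here are illustrative:

runner = PluginRunner(plugins=[ExampleLoggingPlugin()])
runner.run_on_epoch_start(epoch=0,
                          global_step=0,
                          project_name='demo',
                          log_folder='/tmp/logs',
                          data_root='/tmp/data')
# logs "Epoch 0 starting"; warns if the call exceeds epoch_warn_seconds (default 5)
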
train.py (33 changed lines)

@@ -733,7 +733,6 @@ def main(args):
    del inference_pipe
    gc.collect()
    torch.cuda.empty_cache()

    def make_save_path(epoch, global_step, prepend=""):
        return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")

@@ -753,13 +752,19 @@ def main(args):
        plugins = [load_plugin(name) for name in args.plugins]
    else:
        plugins = []

+    from plugins.plugins import PluginRunner
+    plugin_runner = PluginRunner(plugins=plugins)

    try:
        write_batch_schedule(args, log_folder, train_batch, epoch=0)

        for epoch in range(args.max_epochs):
-            for plugin in plugins:
-                plugin.on_epoch_start(epoch, global_step)
+            plugin_runner.run_on_epoch_start(epoch=epoch,
+                                             global_step=global_step,
+                                             project_name=args.project_name,
+                                             log_folder=log_folder,
+                                             data_root=args.data_root)

            loss_epoch = []
            epoch_start_time = time.time()

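This hunk consumes args.plugins but its argparse declaration is not part of the diff; a hedged sketch of what that flag might look like (the flag name, nargs, and default are assumptions):

import argparse

parser = argparse.ArgumentParser()
# e.g. python train.py --plugins my_plugins.MyPlugin another.module.StatsPlugin
parser.add_argument("--plugins", nargs="+", default=None,
                    help="dotted paths of BasePlugin subclasses to load")
args = parser.parse_args()  # args.plugins is then a list of dotted paths, or None
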
@@ -776,6 +781,11 @@ def main(args):
            for step, batch in enumerate(train_dataloader):
                step_start_time = time.time()
+                plugin_runner.run_on_step_start(epoch=epoch,
+                                                global_step=global_step,
+                                                project_name=args.project_name,
+                                                log_folder=log_folder,
+                                                batch=batch)

                model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)

@@ -840,6 +850,13 @@ def main(args):
                    save_path = make_save_path(epoch, global_step)
                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)

+                plugin_runner.run_on_step_end(epoch=epoch,
+                                              global_step=global_step,
+                                              project_name=args.project_name,
+                                              log_folder=log_folder,
+                                              data_root=args.data_root,
+                                              batch=batch)

                del batch
                global_step += 1
                # end of step

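Because on_step_start/on_step_end receive the live batch, a plugin can inspect training data as it flows; a hypothetical sketch (batch["image"] follows the usage visible in train.py above):

import logging

from plugins.plugins import BasePlugin

class BatchShapeLogger(BasePlugin):
    def on_step_end(self, **kwargs):
        images = kwargs["batch"]["image"]
        logging.info(f"step {kwargs['global_step']}: {len(images)} images in batch")
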
@@ -858,9 +875,13 @@ def main(args):
            loss_epoch = sum(loss_epoch) / len(loss_epoch)
            log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)

-            for plugin in plugins:
-                plugin.on_epoch_end(epoch, global_step)
-            gc.collect()
+            plugin_runner.run_on_epoch_end(epoch=epoch,
+                                           global_step=global_step,
+                                           project_name=args.project_name,
+                                           log_folder=log_folder,
+                                           data_root=args.data_root)

+            gc.collect()
            # end of epoch

        # end of training