From dfcc9e7f41b6420c94256b9fd057e6e9a01ed802 Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Wed, 20 Dec 2023 14:55:50 -0500 Subject: [PATCH] add dataloader plugin hooks for caption and pil image --- data/every_dream.py | 4 ++++ plugins/plugins.py | 19 +++++++++++++++++-- train.py | 21 +++++++++++---------- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/data/every_dream.py b/data/every_dream.py index d8cf63d..5b78cdb 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -42,6 +42,7 @@ class EveryDreamBatch(Dataset): tokenizer=None, shuffle_tags=False, keep_tags=0, + plugin_runner=None, rated_dataset=False, rated_dataset_dropout_target=0.5, name='train' @@ -56,6 +57,7 @@ class EveryDreamBatch(Dataset): self.max_token_length = self.tokenizer.model_max_length self.shuffle_tags = shuffle_tags self.keep_tags = keep_tags + self.plugin_runner = plugin_runner self.seed = seed self.rated_dataset = rated_dataset self.rated_dataset_dropout_target = rated_dataset_dropout_target @@ -101,6 +103,8 @@ class EveryDreamBatch(Dataset): example["caption"] = train_item["caption"].get_caption() example["image"] = image_transforms(train_item["image"]) + example["image"] = self.plugin_runner.run_transform_pil_image(example["image"]) + example["caption"] = self.plugin_runner.run_transform_caption(example["caption"]) if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)): example["tokens"] = self.tokenizer(example["caption"], diff --git a/plugins/plugins.py b/plugins/plugins.py index dc65ff2..46a9f1c 100644 --- a/plugins/plugins.py +++ b/plugins/plugins.py @@ -3,6 +3,7 @@ import importlib import logging import time import warnings +from PIL import Image class BasePlugin: def on_epoch_start(self, **kwargs): pass @@ -17,8 +18,10 @@ class BasePlugin: pass def on_step_end(self, **kwargs): pass - - + def transform_caption(self, caption:str): + return caption + def transform_pil_image(self, img:Image): + return img def load_plugin(plugin_path): print(f" - 
Attempting to load plugin: {plugin_path}") @@ -89,3 +92,15 @@ class PluginRunner: for plugin in self.plugins: with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'): plugin.on_step_end(**kwargs) + + def run_transform_caption(self, caption): + with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_caption"): + for plugin in self.plugins: + caption = plugin.transform_caption(caption) + return caption + + def run_transform_pil_image(self, img): + with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_pil_image"): + for plugin in self.plugins: + img = plugin.transform_pil_image(img) + return img diff --git a/train.py b/train.py index 2858dc4..749c5c7 100644 --- a/train.py +++ b/train.py @@ -775,6 +775,16 @@ def main(args): report_image_train_item_problems(log_folder, image_train_items, batch_size=args.batch_size) + from plugins.plugins import load_plugin + if args.plugins is not None: + plugins = [load_plugin(name) for name in args.plugins] + else: + logging.info("No plugins specified") + plugins = [] + + from plugins.plugins import PluginRunner + plugin_runner = PluginRunner(plugins=plugins) + data_loader = DataLoaderMultiAspect( image_train_items=image_train_items, seed=seed, @@ -790,6 +800,7 @@ def main(args): seed = seed, shuffle_tags=args.shuffle_tags, keep_tags=args.keep_tags, + plugin_runner=plugin_runner, rated_dataset=args.rated_dataset, rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0)) ) @@ -1082,16 +1093,6 @@ def main(args): _, batch = next(enumerate(train_dataloader)) generate_samples(global_step=0, batch=batch) - from plugins.plugins import load_plugin - if args.plugins is not None: - plugins = [load_plugin(name) for name in args.plugins] - else: - logging.info("No plugins specified") - plugins = [] - - from plugins.plugins import PluginRunner - plugin_runner = PluginRunner(plugins=plugins) - def make_current_ed_state() -> EveryDreamTrainingState: return 
EveryDreamTrainingState(optimizer=ed_optimizer, train_batch=train_batch,