add dataloader plugin hooks for caption and pil image

This commit is contained in:
Victor Hall 2023-12-20 14:55:50 -05:00
parent 1236329677
commit dfcc9e7f41
3 changed files with 32 additions and 12 deletions

View File

@ -42,6 +42,7 @@ class EveryDreamBatch(Dataset):
tokenizer=None,
shuffle_tags=False,
keep_tags=0,
plugin_runner=None,
rated_dataset=False,
rated_dataset_dropout_target=0.5,
name='train'
@ -56,6 +57,7 @@ class EveryDreamBatch(Dataset):
self.max_token_length = self.tokenizer.model_max_length
self.shuffle_tags = shuffle_tags
self.keep_tags = keep_tags
self.plugin_runner = plugin_runner
self.seed = seed
self.rated_dataset = rated_dataset
self.rated_dataset_dropout_target = rated_dataset_dropout_target
@ -101,6 +103,8 @@ class EveryDreamBatch(Dataset):
example["caption"] = train_item["caption"].get_caption()
example["image"] = image_transforms(train_item["image"])
example["image"] = self.plugin_runner.transform_pil_image(example["image"])
example["caption"] = self.plugin_runner.transform_caption(example["caption"])
if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)):
example["tokens"] = self.tokenizer(example["caption"],

View File

@ -3,6 +3,7 @@ import importlib
import logging
import time
import warnings
from PIL import Image
class BasePlugin:
def on_epoch_start(self, **kwargs):
@ -17,8 +18,10 @@ class BasePlugin:
pass
def on_step_end(self, **kwargs):
    """Hook invoked at the end of a step; the base implementation deliberately does nothing.

    Subclasses override this to react to step completion. Arbitrary keyword
    arguments are accepted so the runner can pass whatever context it has.
    """
def transform_caption(self, caption:str):
    """Default caption hook: hand the caption back unchanged.

    Subclasses override this to rewrite or augment the caption string
    before it is tokenized by the dataloader.
    """
    result = caption
    return result
def transform_pil_image(self, img:Image):
    """Default image hook: hand the image back unchanged.

    Subclasses override this to alter the PIL image before the dataloader's
    tensor transforms run.
    """
    result = img
    return result
def load_plugin(plugin_path):
print(f" - Attempting to load plugin: {plugin_path}")
@ -89,3 +92,15 @@ class PluginRunner:
for plugin in self.plugins:
with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'):
plugin.on_step_end(**kwargs)
def run_transform_caption(self, caption):
    """Pipe the caption through every plugin's transform_caption hook, in order.

    Each plugin receives the caption produced by the previous one, so hooks
    compose in plugin registration order. The whole chain runs under a single
    Timer so slow plugins are reported against step_warn_seconds.

    :param caption: caption string to transform
    :return: the caption after all plugins have had a chance to modify it
    """
    with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_caption"):
        for plugin in self.plugins:
            caption = plugin.transform_caption(caption)
        return caption

def transform_caption(self, caption):
    """Backward-compatible alias for run_transform_caption.

    The dataset calls plugin_runner.transform_caption(...), while the runner
    method is named run_transform_caption; without this alias that call raises
    AttributeError at runtime. Delegates unchanged.
    """
    return self.run_transform_caption(caption)
def run_transform_pil_image(self, img):
    """Pipe the PIL image through every plugin's transform_pil_image hook, in order.

    Each plugin receives the image produced by the previous one, so hooks
    compose in plugin registration order. The whole chain runs under a single
    Timer so slow plugins are reported against step_warn_seconds.

    :param img: PIL image to transform
    :return: the image after all plugins have had a chance to modify it
    """
    with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_pil_image"):
        for plugin in self.plugins:
            img = plugin.transform_pil_image(img)
        return img

def transform_pil_image(self, img):
    """Backward-compatible alias for run_transform_pil_image.

    The dataset calls plugin_runner.transform_pil_image(...), while the runner
    method is named run_transform_pil_image; without this alias that call raises
    AttributeError at runtime. Delegates unchanged.
    """
    return self.run_transform_pil_image(img)

View File

@ -775,6 +775,16 @@ def main(args):
report_image_train_item_problems(log_folder, image_train_items, batch_size=args.batch_size)
from plugins.plugins import load_plugin
if args.plugins is not None:
plugins = [load_plugin(name) for name in args.plugins]
else:
logging.info("No plugins specified")
plugins = []
from plugins.plugins import PluginRunner
plugin_runner = PluginRunner(plugins=plugins)
data_loader = DataLoaderMultiAspect(
image_train_items=image_train_items,
seed=seed,
@ -790,6 +800,7 @@ def main(args):
seed = seed,
shuffle_tags=args.shuffle_tags,
keep_tags=args.keep_tags,
plugin_runner=plugin_runner,
rated_dataset=args.rated_dataset,
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
)
@ -1082,16 +1093,6 @@ def main(args):
_, batch = next(enumerate(train_dataloader))
generate_samples(global_step=0, batch=batch)
from plugins.plugins import load_plugin
if args.plugins is not None:
plugins = [load_plugin(name) for name in args.plugins]
else:
logging.info("No plugins specified")
plugins = []
from plugins.plugins import PluginRunner
plugin_runner = PluginRunner(plugins=plugins)
def make_current_ed_state() -> EveryDreamTrainingState:
return EveryDreamTrainingState(optimizer=ed_optimizer,
train_batch=train_batch,