add dataloader plugin hooks for caption and pil image
This commit is contained in:
parent
1236329677
commit
dfcc9e7f41
|
@ -42,6 +42,7 @@ class EveryDreamBatch(Dataset):
|
|||
tokenizer=None,
|
||||
shuffle_tags=False,
|
||||
keep_tags=0,
|
||||
plugin_runner=None,
|
||||
rated_dataset=False,
|
||||
rated_dataset_dropout_target=0.5,
|
||||
name='train'
|
||||
|
@ -56,6 +57,7 @@ class EveryDreamBatch(Dataset):
|
|||
self.max_token_length = self.tokenizer.model_max_length
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.keep_tags = keep_tags
|
||||
self.plugin_runner = plugin_runner
|
||||
self.seed = seed
|
||||
self.rated_dataset = rated_dataset
|
||||
self.rated_dataset_dropout_target = rated_dataset_dropout_target
|
||||
|
@ -101,6 +103,8 @@ class EveryDreamBatch(Dataset):
|
|||
example["caption"] = train_item["caption"].get_caption()
|
||||
|
||||
example["image"] = image_transforms(train_item["image"])
|
||||
example["image"] = self.plugin_runner.transform_pil_image(example["image"])
|
||||
example["caption"] = self.plugin_runner.transform_caption(example["caption"])
|
||||
|
||||
if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)):
|
||||
example["tokens"] = self.tokenizer(example["caption"],
|
||||
|
|
|
@ -3,6 +3,7 @@ import importlib
|
|||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from PIL import Image
|
||||
|
||||
class BasePlugin:
|
||||
def on_epoch_start(self, **kwargs):
|
||||
|
@ -17,8 +18,10 @@ class BasePlugin:
|
|||
pass
|
||||
def on_step_end(self, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def transform_caption(self, caption:str):
|
||||
return caption
|
||||
def transform_pil_image(self, img:Image):
|
||||
return img
|
||||
|
||||
def load_plugin(plugin_path):
|
||||
print(f" - Attempting to load plugin: {plugin_path}")
|
||||
|
@ -89,3 +92,15 @@ class PluginRunner:
|
|||
for plugin in self.plugins:
|
||||
with Timer(warn_seconds=self.step_warn_seconds, label=f'{plugin.__class__.__name__}'):
|
||||
plugin.on_step_end(**kwargs)
|
||||
|
||||
def run_transform_caption(self, caption):
|
||||
with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_caption"):
|
||||
for plugin in self.plugins:
|
||||
caption = plugin.transform_caption(caption)
|
||||
return caption
|
||||
|
||||
def run_transform_pil_image(self, img):
|
||||
with Timer(warn_seconds=self.step_warn_seconds, label="plugin.transform_pil_image"):
|
||||
for plugin in self.plugins:
|
||||
img = plugin.transform_pil_image(img)
|
||||
return img
|
||||
|
|
21
train.py
21
train.py
|
@ -775,6 +775,16 @@ def main(args):
|
|||
|
||||
report_image_train_item_problems(log_folder, image_train_items, batch_size=args.batch_size)
|
||||
|
||||
from plugins.plugins import load_plugin
|
||||
if args.plugins is not None:
|
||||
plugins = [load_plugin(name) for name in args.plugins]
|
||||
else:
|
||||
logging.info("No plugins specified")
|
||||
plugins = []
|
||||
|
||||
from plugins.plugins import PluginRunner
|
||||
plugin_runner = PluginRunner(plugins=plugins)
|
||||
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items=image_train_items,
|
||||
seed=seed,
|
||||
|
@ -790,6 +800,7 @@ def main(args):
|
|||
seed = seed,
|
||||
shuffle_tags=args.shuffle_tags,
|
||||
keep_tags=args.keep_tags,
|
||||
plugin_runner=plugin_runner,
|
||||
rated_dataset=args.rated_dataset,
|
||||
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
|
||||
)
|
||||
|
@ -1082,16 +1093,6 @@ def main(args):
|
|||
_, batch = next(enumerate(train_dataloader))
|
||||
generate_samples(global_step=0, batch=batch)
|
||||
|
||||
from plugins.plugins import load_plugin
|
||||
if args.plugins is not None:
|
||||
plugins = [load_plugin(name) for name in args.plugins]
|
||||
else:
|
||||
logging.info("No plugins specified")
|
||||
plugins = []
|
||||
|
||||
from plugins.plugins import PluginRunner
|
||||
plugin_runner = PluginRunner(plugins=plugins)
|
||||
|
||||
def make_current_ed_state() -> EveryDreamTrainingState:
|
||||
return EveryDreamTrainingState(optimizer=ed_optimizer,
|
||||
train_batch=train_batch,
|
||||
|
|
Loading…
Reference in New Issue