diff --git a/data/every_dream.py b/data/every_dream.py index e886baa..f761e3c 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -25,6 +25,8 @@ from torchvision import transforms from transformers import CLIPTokenizer import torch.nn.functional as F +from plugins.plugins import PluginRunner + class EveryDreamBatch(Dataset): """ data_loader: `DataLoaderMultiAspect` object @@ -42,7 +44,7 @@ class EveryDreamBatch(Dataset): tokenizer=None, shuffle_tags=False, keep_tags=0, - plugin_runner=None, + plugin_runner:PluginRunner=None, rated_dataset=False, rated_dataset_dropout_target=0.5, name='train' @@ -102,9 +104,9 @@ class EveryDreamBatch(Dataset): else: example["caption"] = train_item["caption"].get_caption() - example["image"] = self.plugin_runner.transform_pil_image(example["image"]) - example["image"] = image_transforms(train_item["image"]) - example["caption"] = self.plugin_runner.transform_caption(example["caption"]) + example["image"] = self.plugin_runner.run_transform_pil_image(train_item["image"]) + example["image"] = image_transforms(example["image"]) + example["caption"] = self.plugin_runner.run_transform_caption(example["caption"]) if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)): example["tokens"] = self.tokenizer(example["caption"],