error in dataloader plugin

This commit is contained in:
Victor Hall 2023-12-21 10:49:06 -05:00
parent 7a81182220
commit c7d3064029
1 changed files with 6 additions and 4 deletions

View File

@ -25,6 +25,8 @@ from torchvision import transforms
from transformers import CLIPTokenizer
import torch.nn.functional as F
from plugins.plugins import PluginRunner
class EveryDreamBatch(Dataset):
"""
data_loader: `DataLoaderMultiAspect` object
@ -42,7 +44,7 @@ class EveryDreamBatch(Dataset):
tokenizer=None,
shuffle_tags=False,
keep_tags=0,
plugin_runner=None,
plugin_runner:PluginRunner=None,
rated_dataset=False,
rated_dataset_dropout_target=0.5,
name='train'
@ -102,9 +104,9 @@ class EveryDreamBatch(Dataset):
else:
example["caption"] = train_item["caption"].get_caption()
example["image"] = self.plugin_runner.transform_pil_image(example["image"])
example["image"] = image_transforms(train_item["image"])
example["caption"] = self.plugin_runner.transform_caption(example["caption"])
example["image"] = self.plugin_runner.run_transform_pil_image(train_item["image"])
example["image"] = image_transforms(example["image"])
example["caption"] = self.plugin_runner.run_transform_caption(example["caption"])
if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)):
example["tokens"] = self.tokenizer(example["caption"],