error in dataloader plugin
This commit is contained in:
parent
7a81182220
commit
c7d3064029
|
@ -25,6 +25,8 @@ from torchvision import transforms
|
|||
from transformers import CLIPTokenizer
|
||||
import torch.nn.functional as F
|
||||
|
||||
from plugins.plugins import PluginRunner
|
||||
|
||||
class EveryDreamBatch(Dataset):
|
||||
"""
|
||||
data_loader: `DataLoaderMultiAspect` object
|
||||
|
@ -42,7 +44,7 @@ class EveryDreamBatch(Dataset):
|
|||
tokenizer=None,
|
||||
shuffle_tags=False,
|
||||
keep_tags=0,
|
||||
plugin_runner=None,
|
||||
plugin_runner:PluginRunner=None,
|
||||
rated_dataset=False,
|
||||
rated_dataset_dropout_target=0.5,
|
||||
name='train'
|
||||
|
@ -102,9 +104,9 @@ class EveryDreamBatch(Dataset):
|
|||
else:
|
||||
example["caption"] = train_item["caption"].get_caption()
|
||||
|
||||
example["image"] = self.plugin_runner.transform_pil_image(example["image"])
|
||||
example["image"] = image_transforms(train_item["image"])
|
||||
example["caption"] = self.plugin_runner.transform_caption(example["caption"])
|
||||
example["image"] = self.plugin_runner.run_transform_pil_image(train_item["image"])
|
||||
example["image"] = image_transforms(example["image"])
|
||||
example["caption"] = self.plugin_runner.run_transform_caption(example["caption"])
|
||||
|
||||
if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)):
|
||||
example["tokens"] = self.tokenizer(example["caption"],
|
||||
|
|
Loading…
Reference in New Issue