Fix cond_dropout and rating handling

This commit is contained in:
Augusto de la Torre 2023-03-13 00:36:59 +01:00
parent ce8f8751c9
commit 7e20a74586
2 changed files with 4 additions and 2 deletions

View File

@ -230,7 +230,7 @@ class Dataset:
caption = ImageCaption(
main_prompt=next(iter(sorted(config.main_prompts))),
rating=config.rating,
rating=config.rating or 1.0,
tags=tags,
tag_weights=tag_weights,
max_target_length=config.max_caption_length,

View File

@ -104,7 +104,7 @@ class EveryDreamBatch(Dataset):
example["image"] = image_transforms(train_item["image"])
if random.random() > (train_item.cond_dropout or self.conditional_dropout):
if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)):
example["tokens"] = self.tokenizer(example["caption"],
truncation=True,
padding="max_length",
@ -132,6 +132,8 @@ class EveryDreamBatch(Dataset):
example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
image_train_tmp.image = None # hack for now to avoid memory leak
example["caption"] = image_train_tmp.caption
if image_train_tmp.cond_dropout is not None:
example["cond_dropout"] = image_train_tmp.cond_dropout
example["runt_size"] = image_train_tmp.runt_size
return example