Add --keep_tags to keep first N tags fixed on shuffle
This commit is contained in:
parent
6c8d15daab
commit
43984f2ad3
|
@ -41,6 +41,7 @@ class EveryDreamBatch(Dataset):
|
|||
seed=555,
|
||||
tokenizer=None,
|
||||
shuffle_tags=False,
|
||||
keep_tags=0,
|
||||
rated_dataset=False,
|
||||
rated_dataset_dropout_target=0.5,
|
||||
name='train'
|
||||
|
@ -54,6 +55,7 @@ class EveryDreamBatch(Dataset):
|
|||
self.tokenizer = tokenizer
|
||||
self.max_token_length = self.tokenizer.model_max_length
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.keep_tags = keep_tags
|
||||
self.seed = seed
|
||||
self.rated_dataset = rated_dataset
|
||||
self.rated_dataset_dropout_target = rated_dataset_dropout_target
|
||||
|
@ -94,7 +96,7 @@ class EveryDreamBatch(Dataset):
|
|||
)
|
||||
|
||||
if self.shuffle_tags or train_item["shuffle_tags"]:
|
||||
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
|
||||
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed, keep_tags=self.keep_tags)
|
||||
else:
|
||||
example["caption"] = train_item["caption"].get_caption()
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ class ImageCaption:
|
|||
def rating(self) -> float:
|
||||
return self.__rating
|
||||
|
||||
def get_shuffled_caption(self, seed: int) -> str:
|
||||
def get_shuffled_caption(self, seed: int, keep_tags: int) -> str:
|
||||
"""
|
||||
returns the caption a string with a random selection of the tags in random order
|
||||
:param seed used to initialize the randomizer
|
||||
|
@ -74,7 +74,7 @@ class ImageCaption:
|
|||
if self.__use_weights:
|
||||
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
|
||||
else:
|
||||
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
|
||||
tags_caption = self.__get_shuffled_tags(seed, self.__tags, keep_tags)
|
||||
|
||||
return self.__main_prompt + ", " + tags_caption
|
||||
return self.__main_prompt
|
||||
|
@ -111,8 +111,15 @@ class ImageCaption:
|
|||
return caption
|
||||
|
||||
@staticmethod
|
||||
def __get_shuffled_tags(seed: int, tags: list[str]) -> str:
|
||||
random.Random(seed).shuffle(tags)
|
||||
def __get_shuffled_tags(seed: int, tags: list[str], keep_tags: int) -> str:
|
||||
tags = tags.copy()
|
||||
|
||||
if len(tags) > keep_tags:
|
||||
fixed_tags = tags[:keep_tags]
|
||||
rest = tags[keep_tags:]
|
||||
random.Random(seed).shuffle(rest)
|
||||
tags = fixed_tags + rest
|
||||
|
||||
return ", ".join(tags)
|
||||
|
||||
class ImageTrainItem:
|
||||
|
|
5
train.py
5
train.py
|
@ -350,6 +350,9 @@ def setup_args(args):
|
|||
if not args.shuffle_tags:
|
||||
args.shuffle_tags = False
|
||||
|
||||
if not args.keep_tags:
|
||||
args.keep_tags = 0
|
||||
|
||||
args.clip_skip = max(min(4, args.clip_skip), 0)
|
||||
|
||||
if args.useadam8bit:
|
||||
|
@ -779,6 +782,7 @@ def main(args):
|
|||
tokenizer=tokenizer,
|
||||
seed = seed,
|
||||
shuffle_tags=args.shuffle_tags,
|
||||
keep_tags=args.keep_tags,
|
||||
rated_dataset=args.rated_dataset,
|
||||
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
|
||||
)
|
||||
|
@ -1326,6 +1330,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
|
||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||
argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, def: 0 (shuffle all)")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
|
||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
|
||||
|
|
Loading…
Reference in New Issue