Add --keep_tags to keep first N tags fixed on shuffle

This commit is contained in:
Gabriel Roldán 2023-07-17 01:33:52 -03:00 committed by Gabriel Roldan
parent 6c8d15daab
commit 43984f2ad3
No known key found for this signature in database
GPG Key ID: 6FAD6D4A395EB862
3 changed files with 19 additions and 5 deletions

View File

@ -41,6 +41,7 @@ class EveryDreamBatch(Dataset):
seed=555, seed=555,
tokenizer=None, tokenizer=None,
shuffle_tags=False, shuffle_tags=False,
keep_tags=0,
rated_dataset=False, rated_dataset=False,
rated_dataset_dropout_target=0.5, rated_dataset_dropout_target=0.5,
name='train' name='train'
@ -54,6 +55,7 @@ class EveryDreamBatch(Dataset):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.max_token_length = self.tokenizer.model_max_length self.max_token_length = self.tokenizer.model_max_length
self.shuffle_tags = shuffle_tags self.shuffle_tags = shuffle_tags
self.keep_tags = keep_tags
self.seed = seed self.seed = seed
self.rated_dataset = rated_dataset self.rated_dataset = rated_dataset
self.rated_dataset_dropout_target = rated_dataset_dropout_target self.rated_dataset_dropout_target = rated_dataset_dropout_target
@ -94,7 +96,7 @@ class EveryDreamBatch(Dataset):
) )
if self.shuffle_tags or train_item["shuffle_tags"]: if self.shuffle_tags or train_item["shuffle_tags"]:
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed) example["caption"] = train_item["caption"].get_shuffled_caption(self.seed, keep_tags=self.keep_tags)
else: else:
example["caption"] = train_item["caption"].get_caption() example["caption"] = train_item["caption"].get_caption()

View File

@ -56,7 +56,7 @@ class ImageCaption:
def rating(self) -> float: def rating(self) -> float:
return self.__rating return self.__rating
def get_shuffled_caption(self, seed: int) -> str: def get_shuffled_caption(self, seed: int, keep_tags: int) -> str:
""" """
returns the caption a string with a random selection of the tags in random order returns the caption a string with a random selection of the tags in random order
:param seed used to initialize the randomizer :param seed used to initialize the randomizer
@ -74,7 +74,7 @@ class ImageCaption:
if self.__use_weights: if self.__use_weights:
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length) tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
else: else:
tags_caption = self.__get_shuffled_tags(seed, self.__tags) tags_caption = self.__get_shuffled_tags(seed, self.__tags, keep_tags)
return self.__main_prompt + ", " + tags_caption return self.__main_prompt + ", " + tags_caption
return self.__main_prompt return self.__main_prompt
@ -111,8 +111,15 @@ class ImageCaption:
return caption return caption
@staticmethod @staticmethod
def __get_shuffled_tags(seed: int, tags: list[str]) -> str: def __get_shuffled_tags(seed: int, tags: list[str], keep_tags: int) -> str:
random.Random(seed).shuffle(tags) tags = tags.copy()
if len(tags) > keep_tags:
fixed_tags = tags[:keep_tags]
rest = tags[keep_tags:]
random.Random(seed).shuffle(rest)
tags = fixed_tags + rest
return ", ".join(tags) return ", ".join(tags)
class ImageTrainItem: class ImageTrainItem:

View File

@ -350,6 +350,9 @@ def setup_args(args):
if not args.shuffle_tags: if not args.shuffle_tags:
args.shuffle_tags = False args.shuffle_tags = False
if not args.keep_tags:
args.keep_tags = 0
args.clip_skip = max(min(4, args.clip_skip), 0) args.clip_skip = max(min(4, args.clip_skip), 0)
if args.useadam8bit: if args.useadam8bit:
@ -779,6 +782,7 @@ def main(args):
tokenizer=tokenizer, tokenizer=tokenizer,
seed = seed, seed = seed,
shuffle_tags=args.shuffle_tags, shuffle_tags=args.shuffle_tags,
keep_tags=args.keep_tags,
rated_dataset=args.rated_dataset, rated_dataset=args.rated_dataset,
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0)) rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
) )
@ -1326,6 +1330,7 @@ if __name__ == "__main__":
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later") argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random") argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets") argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, def: 0 (shuffle all)")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead") argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY") argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.") argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")