From 43984f2ad34df02b7dd122f697f8e231523eaf44 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gabriel=20Rold=C3=A1n?=
Date: Mon, 17 Jul 2023 01:33:52 -0300
Subject: [PATCH] Add --keep_tags to keep first N tags fixed on shuffle

When tag shuffling is active, the first N tags of a caption now keep
their original position and only the remaining tags are shuffled.  N
comes from the new --keep_tags argument (default 0, i.e. shuffle all
tags) and is passed from train.py through EveryDreamBatch to
ImageCaption.get_shuffled_caption().

keep_tags defaults to 0 in get_shuffled_caption() and
__get_shuffled_tags(), so callers that pass only a seed keep working.
Negative values are treated as 0.  The shuffle now operates on slices
of the tag list, so the list handed in is no longer reordered in place.
The weighted path (__get_weighted_shuffled_tags) is unchanged and does
not use keep_tags.
---
NOTE(review): because the tag list is no longer mutated in place,
repeated calls with the same seed return the same tag order.  Confirm
that this is intended, or that the caller varies the seed.

NOTE(review): context-line indentation in this patch was reconstructed;
verify with `git apply --check` against the target tree.

 data/every_dream.py      |  4 +++-
 data/image_train_item.py | 15 +++++++++++----
 train.py                 |  5 +++++
 3 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/data/every_dream.py b/data/every_dream.py
index 3cd84f9..56f3a3f 100644
--- a/data/every_dream.py
+++ b/data/every_dream.py
@@ -41,6 +41,7 @@ class EveryDreamBatch(Dataset):
                  seed=555,
                  tokenizer=None,
                  shuffle_tags=False,
+                 keep_tags=0,
                  rated_dataset=False,
                  rated_dataset_dropout_target=0.5,
                  name='train'
@@ -54,6 +55,7 @@ class EveryDreamBatch(Dataset):
         self.tokenizer = tokenizer
         self.max_token_length = self.tokenizer.model_max_length
         self.shuffle_tags = shuffle_tags
+        self.keep_tags = keep_tags
         self.seed = seed
         self.rated_dataset = rated_dataset
         self.rated_dataset_dropout_target = rated_dataset_dropout_target
@@ -94,7 +96,7 @@ class EveryDreamBatch(Dataset):
         )
 
         if self.shuffle_tags or train_item["shuffle_tags"]:
-            example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
+            example["caption"] = train_item["caption"].get_shuffled_caption(self.seed, keep_tags=self.keep_tags)
         else:
             example["caption"] = train_item["caption"].get_caption()
 
diff --git a/data/image_train_item.py b/data/image_train_item.py
index 0af9b4d..27ad6ef 100644
--- a/data/image_train_item.py
+++ b/data/image_train_item.py
@@ -56,7 +56,7 @@ class ImageCaption:
     def rating(self) -> float:
         return self.__rating
 
-    def get_shuffled_caption(self, seed: int) -> str:
+    def get_shuffled_caption(self, seed: int, keep_tags: int = 0) -> str:
         """
         returns the caption a string with a random selection of the tags in random order
         :param seed used to initialize the randomizer
@@ -74,7 +74,7 @@ class ImageCaption:
             if self.__use_weights:
                 tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
             else:
-                tags_caption = self.__get_shuffled_tags(seed, self.__tags)
+                tags_caption = self.__get_shuffled_tags(seed, self.__tags, keep_tags)
 
             return self.__main_prompt + ", " + tags_caption
         return self.__main_prompt
@@ -111,8 +111,15 @@ class ImageCaption:
         return caption
 
     @staticmethod
-    def __get_shuffled_tags(seed: int, tags: list[str]) -> str:
-        random.Random(seed).shuffle(tags)
+    def __get_shuffled_tags(seed: int, tags: list[str], keep_tags: int = 0) -> str:
+        # Keep the first keep_tags tags in place and shuffle only the rest.
+        # Slicing copies, so the list passed in is left untouched.
+        keep_tags = max(keep_tags, 0)  # a negative count means "keep none"
+        fixed_tags = tags[:keep_tags]
+        rest = tags[keep_tags:]
+        random.Random(seed).shuffle(rest)
+        tags = fixed_tags + rest
+
         return ", ".join(tags)
 
 class ImageTrainItem:
diff --git a/train.py b/train.py
index 58a2e3f..bc94961 100644
--- a/train.py
+++ b/train.py
@@ -350,6 +350,9 @@ def setup_args(args):
     if not args.shuffle_tags:
         args.shuffle_tags = False
 
+    if not args.keep_tags:
+        args.keep_tags = 0
+
     args.clip_skip = max(min(4, args.clip_skip), 0)
 
     if args.useadam8bit:
@@ -779,6 +782,7 @@ def main(args):
         tokenizer=tokenizer,
         seed = seed,
         shuffle_tags=args.shuffle_tags,
+        keep_tags=args.keep_tags,
         rated_dataset=args.rated_dataset,
         rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
     )
@@ -1326,6 +1330,7 @@ if __name__ == "__main__":
     argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
     argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
     argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
+    argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, def: 0 (shuffle all)")
     argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
     argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
     argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")