Add --keep_tags to keep first N tags fixed on shuffle
This commit is contained in:
parent
6c8d15daab
commit
43984f2ad3
|
@ -41,6 +41,7 @@ class EveryDreamBatch(Dataset):
|
||||||
seed=555,
|
seed=555,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
shuffle_tags=False,
|
shuffle_tags=False,
|
||||||
|
keep_tags=0,
|
||||||
rated_dataset=False,
|
rated_dataset=False,
|
||||||
rated_dataset_dropout_target=0.5,
|
rated_dataset_dropout_target=0.5,
|
||||||
name='train'
|
name='train'
|
||||||
|
@ -54,6 +55,7 @@ class EveryDreamBatch(Dataset):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.max_token_length = self.tokenizer.model_max_length
|
self.max_token_length = self.tokenizer.model_max_length
|
||||||
self.shuffle_tags = shuffle_tags
|
self.shuffle_tags = shuffle_tags
|
||||||
|
self.keep_tags = keep_tags
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.rated_dataset = rated_dataset
|
self.rated_dataset = rated_dataset
|
||||||
self.rated_dataset_dropout_target = rated_dataset_dropout_target
|
self.rated_dataset_dropout_target = rated_dataset_dropout_target
|
||||||
|
@ -94,7 +96,7 @@ class EveryDreamBatch(Dataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.shuffle_tags or train_item["shuffle_tags"]:
|
if self.shuffle_tags or train_item["shuffle_tags"]:
|
||||||
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
|
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed, keep_tags=self.keep_tags)
|
||||||
else:
|
else:
|
||||||
example["caption"] = train_item["caption"].get_caption()
|
example["caption"] = train_item["caption"].get_caption()
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,7 @@ class ImageCaption:
|
||||||
def rating(self) -> float:
|
def rating(self) -> float:
|
||||||
return self.__rating
|
return self.__rating
|
||||||
|
|
||||||
def get_shuffled_caption(self, seed: int) -> str:
|
def get_shuffled_caption(self, seed: int, keep_tags: int) -> str:
|
||||||
"""
|
"""
|
||||||
returns the caption as a string with a random selection of the tags in random order
|
returns the caption as a string with a random selection of the tags in random order
|
||||||
:param seed used to initialize the randomizer
|
:param seed used to initialize the randomizer
|
||||||
|
@ -74,7 +74,7 @@ class ImageCaption:
|
||||||
if self.__use_weights:
|
if self.__use_weights:
|
||||||
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
|
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
|
||||||
else:
|
else:
|
||||||
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
|
tags_caption = self.__get_shuffled_tags(seed, self.__tags, keep_tags)
|
||||||
|
|
||||||
return self.__main_prompt + ", " + tags_caption
|
return self.__main_prompt + ", " + tags_caption
|
||||||
return self.__main_prompt
|
return self.__main_prompt
|
||||||
|
@ -111,8 +111,15 @@ class ImageCaption:
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __get_shuffled_tags(seed: int, tags: list[str]) -> str:
|
def __get_shuffled_tags(seed: int, tags: list[str], keep_tags: int) -> str:
|
||||||
random.Random(seed).shuffle(tags)
|
tags = tags.copy()
|
||||||
|
|
||||||
|
if len(tags) > keep_tags:
|
||||||
|
fixed_tags = tags[:keep_tags]
|
||||||
|
rest = tags[keep_tags:]
|
||||||
|
random.Random(seed).shuffle(rest)
|
||||||
|
tags = fixed_tags + rest
|
||||||
|
|
||||||
return ", ".join(tags)
|
return ", ".join(tags)
|
||||||
|
|
||||||
class ImageTrainItem:
|
class ImageTrainItem:
|
||||||
|
|
5
train.py
5
train.py
|
@ -350,6 +350,9 @@ def setup_args(args):
|
||||||
if not args.shuffle_tags:
|
if not args.shuffle_tags:
|
||||||
args.shuffle_tags = False
|
args.shuffle_tags = False
|
||||||
|
|
||||||
|
if not args.keep_tags:
|
||||||
|
args.keep_tags = 0
|
||||||
|
|
||||||
args.clip_skip = max(min(4, args.clip_skip), 0)
|
args.clip_skip = max(min(4, args.clip_skip), 0)
|
||||||
|
|
||||||
if args.useadam8bit:
|
if args.useadam8bit:
|
||||||
|
@ -779,6 +782,7 @@ def main(args):
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
seed = seed,
|
seed = seed,
|
||||||
shuffle_tags=args.shuffle_tags,
|
shuffle_tags=args.shuffle_tags,
|
||||||
|
keep_tags=args.keep_tags,
|
||||||
rated_dataset=args.rated_dataset,
|
rated_dataset=args.rated_dataset,
|
||||||
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
|
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
|
||||||
)
|
)
|
||||||
|
@ -1326,6 +1330,7 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
|
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
|
||||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||||
|
argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, def: 0 (shuffle all)")
|
||||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
|
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
|
||||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||||
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
|
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
|
||||||
|
|
Loading…
Reference in New Issue