shuffle tags arg

This commit is contained in:
Victor Hall 2023-01-06 19:12:52 -05:00
parent ced593d929
commit 98f9a7302d
4 changed files with 30 additions and 6 deletions

View File

@ -50,6 +50,7 @@ class EveryDreamBatch(Dataset):
log_folder=None,
retain_contrast=False,
write_schedule=False,
shuffle_tags=False,
):
self.data_root = data_root
self.batch_size = batch_size
@ -63,6 +64,8 @@ class EveryDreamBatch(Dataset):
self.max_token_length = self.tokenizer.model_max_length
self.retain_contrast = retain_contrast
self.write_schedule = write_schedule
self.shuffle_tags = shuffle_tags
self.seed = seed
if seed == -1:
seed = random.randint(0, 99999)
@ -131,6 +134,12 @@ class EveryDreamBatch(Dataset):
]
)
if self.shuffle_tags and "," in train_item['caption']:
tags = train_item["caption"].split(",")
random.Random(self.seed).shuffle(tags)
self.seed += 1
train_item["caption"] = ", ".join(tags)
example["image"] = image_transforms(train_item["image"])
if random.random() > self.conditional_dropout:
@ -145,6 +154,7 @@ class EveryDreamBatch(Dataset):
padding="max_length",
max_length=self.tokenizer.model_max_length,
).input_ids
example["tokens"] = torch.tensor(example["tokens"])
example["caption"] = train_item["caption"] # for sampling if needed
example["runt_size"] = train_item["runt_size"]

View File

@ -115,6 +115,14 @@ If you wish for your training images to be randomly flipped horizontally, use th
This is useful for styles or other training that is not asymmetrical. It is not suggested for training specific human faces as it may wash out facial features as real people typically have at least some asymmetric facial features. It may also cause problems if you are training fictional characters with asymmetrical outfits, such as washing out the asymmetries in the outfit. It is also not suggested if any of your captions included directions like "left" or "right". Default is 0.0 (no flipping)
# Shuffle tags
For those training booru tagged models, you can use this arg to randomly (but deterministicly unless you use `--seed -1`) all the CSV tags in your captions
--shuffle_tags ^
This simple chops the captions in to parts based on the commas and shuffles the order.
# Stuff you probably don't need to mess with, but well here it is:
## Clip skip

View File

@ -31,6 +31,7 @@
"save_optimizer": false,
"scale_lr": false,
"seed": 555,
"shuffle_tags": false,
"useadam8bit": true,
"wandb": false,
"write_schedule": false

View File

@ -208,6 +208,9 @@ def main(args):
if args.ed1_mode:
args.disable_xformers = True
if not args.shuffle_tags:
args.shuffle_tags = False
args.clip_skip = max(min(4, args.clip_skip), 0)
@ -448,6 +451,7 @@ def main(args):
seed = seed,
log_folder=log_folder,
write_schedule=args.write_schedule,
shuffle_tags=args.shuffle_tags,
)
torch.cuda.benchmark = False
@ -655,6 +659,12 @@ def main(args):
del target, model_pred
if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
loss.backward()
if batch["runt_size"] > 0:
grad_scale = batch["runt_size"] / args.batch_size
with torch.no_grad(): # not required? just in case for now, needs more testing
@ -666,12 +676,6 @@ def main(args):
if param.grad is not None:
param.grad *= grad_scale
if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
loss.backward()
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
optimizer.step()
optimizer.zero_grad(set_to_none=True)
@ -808,6 +812,7 @@ if __name__ == "__main__":
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")