shuffle tags arg
This commit is contained in:
parent
ced593d929
commit
98f9a7302d
|
@ -50,6 +50,7 @@ class EveryDreamBatch(Dataset):
|
|||
log_folder=None,
|
||||
retain_contrast=False,
|
||||
write_schedule=False,
|
||||
shuffle_tags=False,
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.batch_size = batch_size
|
||||
|
@ -63,6 +64,8 @@ class EveryDreamBatch(Dataset):
|
|||
self.max_token_length = self.tokenizer.model_max_length
|
||||
self.retain_contrast = retain_contrast
|
||||
self.write_schedule = write_schedule
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.seed = seed
|
||||
|
||||
if seed == -1:
|
||||
seed = random.randint(0, 99999)
|
||||
|
@ -131,6 +134,12 @@ class EveryDreamBatch(Dataset):
|
|||
]
|
||||
)
|
||||
|
||||
if self.shuffle_tags and "," in train_item['caption']:
|
||||
tags = train_item["caption"].split(",")
|
||||
random.Random(self.seed).shuffle(tags)
|
||||
self.seed += 1
|
||||
train_item["caption"] = ", ".join(tags)
|
||||
|
||||
example["image"] = image_transforms(train_item["image"])
|
||||
|
||||
if random.random() > self.conditional_dropout:
|
||||
|
@ -145,6 +154,7 @@ class EveryDreamBatch(Dataset):
|
|||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids
|
||||
|
||||
example["tokens"] = torch.tensor(example["tokens"])
|
||||
example["caption"] = train_item["caption"] # for sampling if needed
|
||||
example["runt_size"] = train_item["runt_size"]
|
||||
|
|
|
@ -115,6 +115,14 @@ If you wish for your training images to be randomly flipped horizontally, use th
|
|||
|
||||
This is useful for styles or other training that is not asymmetrical. It is not suggested for training specific human faces as it may wash out facial features as real people typically have at least some asymmetric facial features. It may also cause problems if you are training fictional characters with asymmetrical outfits, such as washing out the asymmetries in the outfit. It is also not suggested if any of your captions included directions like "left" or "right". Default is 0.0 (no flipping)
|
||||
|
||||
# Shuffle tags
|
||||
|
||||
For those training booru-tagged models, you can use this arg to randomly (but deterministically, unless you use `--seed -1`) shuffle all the CSV tags in your captions
|
||||
|
||||
--shuffle_tags ^
|
||||
|
||||
This simply chops the captions into parts based on the commas and shuffles the order.
|
||||
|
||||
# Stuff you probably don't need to mess with, but well here it is:
|
||||
|
||||
## Clip skip
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
"save_optimizer": false,
|
||||
"scale_lr": false,
|
||||
"seed": 555,
|
||||
"shuffle_tags": false,
|
||||
"useadam8bit": true,
|
||||
"wandb": false,
|
||||
"write_schedule": false
|
||||
|
|
17
train.py
17
train.py
|
@ -208,6 +208,9 @@ def main(args):
|
|||
|
||||
if args.ed1_mode:
|
||||
args.disable_xformers = True
|
||||
|
||||
if not args.shuffle_tags:
|
||||
args.shuffle_tags = False
|
||||
|
||||
args.clip_skip = max(min(4, args.clip_skip), 0)
|
||||
|
||||
|
@ -448,6 +451,7 @@ def main(args):
|
|||
seed = seed,
|
||||
log_folder=log_folder,
|
||||
write_schedule=args.write_schedule,
|
||||
shuffle_tags=args.shuffle_tags,
|
||||
)
|
||||
|
||||
torch.cuda.benchmark = False
|
||||
|
@ -655,6 +659,12 @@ def main(args):
|
|||
|
||||
del target, model_pred
|
||||
|
||||
if args.clip_grad_norm is not None:
|
||||
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
|
||||
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
|
||||
|
||||
loss.backward()
|
||||
|
||||
if batch["runt_size"] > 0:
|
||||
grad_scale = batch["runt_size"] / args.batch_size
|
||||
with torch.no_grad(): # not required? just in case for now, needs more testing
|
||||
|
@ -666,12 +676,6 @@ def main(args):
|
|||
if param.grad is not None:
|
||||
param.grad *= grad_scale
|
||||
|
||||
if args.clip_grad_norm is not None:
|
||||
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
|
||||
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
|
||||
|
||||
loss.backward()
|
||||
|
||||
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
@ -808,6 +812,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
|
||||
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
|
||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||
|
|
Loading…
Reference in New Issue