json driven args, update runt batch compensation

Victor Hall 2023-01-03 14:27:26 -05:00
parent a7dcf43f7c
commit 65c5fd5ccb
4 changed files with 110 additions and 50 deletions


@ -92,7 +92,7 @@ class EveryDreamBatch(Dataset):
with open(f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt", "w", encoding='utf-8') as f:
for i in range(len(self.image_train_items)):
try:
f.write(f"step:{int(i / self.batch_size)}, wh:{self.image_train_items[i].target_wh}, r:{self.image_train_items[i].runt_size}, path:{self.image_train_items[i].pathname}\n")
f.write(f"step:{int(i / self.batch_size):05}, wh:{self.image_train_items[i].target_wh}, r:{self.image_train_items[i].runt_size}, path:{self.image_train_items[i].pathname}\n")
except Exception as e:
logging.error(f" * Error writing to batch schedule for file path: {self.image_train_items[i].pathname}")
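The `:05` format spec added in this hunk zero-pads the step number to five digits, which presumably keeps the schedule file aligned and lets it sort correctly as plain text. A minimal sketch of the effect follows; `schedule_label` and the batch size of 4 are hypothetical, chosen only for illustration, and are not part of the trainer.

```python
def schedule_label(index: int, batch_size: int) -> str:
    """Format the step for the item at `index`, zero-padded to 5 digits."""
    step = int(index / batch_size)  # same integer step computation as in the diff
    return f"step:{step:05}"

labels = [schedule_label(i, batch_size=4) for i in (8, 40)]
print(labels)  # ['step:00002', 'step:00010']

# Without padding, a plain text sort puts step 10 before step 2.
print(sorted(["step:2", "step:10"]))  # ['step:10', 'step:2']
print(sorted(labels))                 # ['step:00002', 'step:00010']
```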


@ -11,15 +11,17 @@ You will need to place all your images and captions into a folder. Inside that
When you train, you will use "--data_root" to point to the root folder of your data. All images in that folder and its subfolders will be used for training.
If you wish to boost training on a particular folder of images, put a "multiply.txt" in that folder with a whole number in it, such as 2. This will multiply the number of times images in that specific folder are used for training per epoch. This is useful if you have two characters you want to train, separated into separate folders, but one character has fewer images.
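As a rough illustration of the "multiply.txt" idea, the sketch below reads a whole number from that file and repeats a folder's images that many times within one epoch. The helper names `read_multiplier` and `repeat_for_epoch` are hypothetical, and the trainer's real loader may work differently.

```python
import os
import tempfile

def read_multiplier(folder: str) -> int:
    """Return the whole number in folder/multiply.txt, or 1 if the file is absent."""
    path = os.path.join(folder, "multiply.txt")
    if not os.path.isfile(path):
        return 1
    with open(path, "r", encoding="utf-8") as f:
        return int(f.read().strip())

def repeat_for_epoch(image_paths: list, multiplier: int) -> list:
    """Repeat each image path `multiplier` times within one epoch."""
    return [p for p in image_paths for _ in range(multiplier)]

# Demo with a temporary folder standing in for one character's image folder.
with tempfile.TemporaryDirectory() as folder:
    with open(os.path.join(folder, "multiply.txt"), "w", encoding="utf-8") as f:
        f.write("2\n")
    items = repeat_for_epoch(["a.jpg", "b.jpg"], read_multiplier(folder))

print(items)  # ['a.jpg', 'a.jpg', 'b.jpg', 'b.jpg']
```

With two characters in separate folders, a multiplier of 2 on the smaller folder would bring its per-epoch exposure closer to that of the larger one.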
# Data preparation
## Image size
The trainer will automatically fit your images to the best possible size. It is best to leave your images larger than you may think for typical Stable Diffusion training. Even 4K images will be handled fine so just don't sweat it if you have large images. The only downside is they take a bit more disk space.
The trainer will automatically fit your images to the best possible size. It is best to leave your images larger than you may think for typical Stable Diffusion training. Even 4K images will be handled fine so just don't sweat it if you have large images. The only downside is they take a bit more disk space. There is almost no performance impact for having higher resolution images.
Current recommendation is 1 megapixel (ex 1100x100, 1300x900, etc) or larger, but thinking ahead to future technology advancements you may wish to keep them at even larger resolutions. Again, don't worry about the trainer squeezing or cropping, it will handle it!
Current recommendation is 1 megapixel (ex 1024x1024, 1100x900, 1300x800, etc) or larger, but thinking ahead to future technology advancements you may wish to keep them at even larger resolutions. Again, don't worry about the trainer squeezing or cropping, it will handle it!
Aspect ratios up to 4:1 or 1:4 are supported. Again, just don't worry about this too much. The trainer will handle it.
Aspect ratios up to 4:1 or 1:4 are supported.
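If you want to sanity-check a dataset against these guidelines before training, something like the following could work. It is only a pre-flight check on width and height, not the trainer's own resizing logic, and the function names `megapixels` and `aspect_supported` are made up for this sketch.

```python
def megapixels(width: int, height: int) -> float:
    """Image area in megapixels (millions of pixels)."""
    return width * height / 1_000_000

def aspect_supported(width: int, height: int, max_ratio: float = 4.0) -> bool:
    """True if the long side is at most `max_ratio` times the short side."""
    long_side, short_side = max(width, height), min(width, height)
    return long_side / short_side <= max_ratio

# Sizes here are arbitrary examples: square, landscape, 4K, and a 5:1 panorama.
for w, h in [(1024, 1024), (1300, 800), (3840, 2160), (5000, 1000)]:
    print(w, h, round(megapixels(w, h), 2), aspect_supported(w, h))
```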
## Cropping
@ -29,14 +31,16 @@ You can crop your images in an image editor *if you need, in order to get good c
For most use cases, use a sane English sentence to describe the image. Try to put your character or main object name close to the start.
Those training anime models can use booru tags as well using other utilities to generate the captions.
### Styles
For style, consider adding a suffix on the caption that describes the style. Examples would be "by claude monet" or "in the style of gta box art" at the end of the caption. This will help the model learn to recall the style at inference time, so you can apply it to other subjects you did not train with the style. You may also consider "drawing of" or "painting of" at the start of the caption when appropriate.
Consider also including a style tag as above if you are training anything besides photos. For instance, if you are training a few characters from a video game you can consider "cloud strife holding a buster sword, screenshot from final fantasy for ps5" if you wish to capture the style of the game along with the characters.
Consider also including a style tag as above if you are training anything besides photos. For instance, if you are training a few characters from a video game you can consider "cloud strife holding a buster sword, screenshot from final fantasy for ps5" if you wish to capture the "style" of the game render along with the characters.
### Context
Include the surroundings and context in your captions. Ex. "cloud strife standing on a dirt path in midgar city slums district" Again, this will allow you to recall the "dirt path in midgar city slums district" style at inference time, and will even pick up on pieces of that like "midgar city" (if enough samples are present with similar words) as a style you can apply later!
Include the surroundings and context in your captions. Ex. "cloud strife standing on a dirt path in midgar city slums district" Again, this will allow you to recall the "dirt path in midgar city slums district" style at inference time, and will even pick up on pieces of that like "midgar city" (if enough samples are present with similar words) as a style or scenery you can apply later. This can extract additional value from your training besides just the character.
Also consider some basic mention of pose. Ex. "cloud strife sitting on a blue wooden bench in front of a concrete wall" or "barrett wallace holding his fist in front of his face with an angry look on his face, looking at the camera." Captions can capture value not only for the character's look, but also for the pose, the background scene, and the camera angle. You can be creative here, there is a lot of potential!

train.json (new file, 36 changed lines)

@ -0,0 +1,36 @@
{
"resume_ckpt": "sd_v1-5_vae",
"lr_scheduler": "cosine",
"lr_warmup_steps": null,
"lr_decay_steps": 0,
"log_step": 25,
"max_epochs": 50,
"ckpt_every_n_minutes": null,
"save_every_n_epochs": 20,
"lr": 4.5e-06,
"useadam8bit": true,
"project_name": "my_project",
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"disable_textenc_training": false,
"batch_size": 7,
"clip_grad_norm": null,
"grad_accum": 5,
"clip_skip": 0,
"data_root": "X:\\mytrainingdata\\my_project_stuff",
"wandb": false,
"save_optimizer": false,
"resolution": 512,
"amp": false,
"cond_dropout": 0.04,
"logdir": "logs",
"save_ckpt_dir": null,
"scale_lr": false,
"seed": 555,
"flip_p": 0.0,
"gpuid": 0,
"write_schedule": true,
"gradient_checkpointing": true,
"disable_xformers": false,
"lowvram": false
}
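A few details of this file are easy to misread, so here is a small sketch that decodes a handful of its keys with Python's `json` module. The reading of `batch_size * grad_accum` as images per optimizer step assumes standard gradient accumulation; the commit itself does not spell this out.

```python
import json

# A few keys copied from the train.json above.
config_text = r'''
{
  "batch_size": 7,
  "grad_accum": 5,
  "data_root": "X:\\mytrainingdata\\my_project_stuff",
  "clip_grad_norm": null
}
'''
config = json.loads(config_text)

# JSON escapes backslashes, so the decoded Windows path has single ones.
print(config["data_root"])  # X:\mytrainingdata\my_project_stuff

# JSON null becomes Python None, matching the default=None arguments.
print(config["clip_grad_norm"] is None)  # True

# Images contributing to each optimizer step under standard gradient accumulation.
images_per_optimizer_step = config["batch_size"] * config["grad_accum"]
print(images_per_optimizer_step)  # 35
```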

train.py (108 changed lines)

@ -110,10 +110,10 @@ def setup_local_logger(args):
json_config = json.dumps(vars(args), indent=2)
datetimestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logfilename = os.path.join(log_path, f"{args.project_name}-train{datetimestamp}.log")
with open(logfilename, "w") as f:
f.write(f"Training config:\n{json_config}\n")
with open(os.path.join(log_path, f"{args.project_name}-{datetimestamp}.json"), "w") as f:
f.write(f"{json_config}")
logfilename = os.path.join(log_path, f"{args.project_name}-{datetimestamp}.log")
logging.basicConfig(filename=logfilename,
level=logging.INFO,
format="%(asctime)s %(message)s",
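This hunk moves the config dump out of the log file and into its own timestamped `.json` file. A standalone sketch of that pattern is below; `write_config_snapshot` is a hypothetical helper name, and the `Namespace` values are placeholders rather than the trainer's full argument set.

```python
import argparse
import datetime
import json
import os
import tempfile

def write_config_snapshot(args: argparse.Namespace, log_path: str) -> str:
    """Write the parsed args as JSON next to the logs and return the file path."""
    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    path = os.path.join(log_path, f"{args.project_name}-{stamp}.json")
    with open(path, "w", encoding="utf-8") as f:
        f.write(json.dumps(vars(args), indent=2))
    return path

args = argparse.Namespace(project_name="my_project", batch_size=7, lr=4.5e-06)
with tempfile.TemporaryDirectory() as log_path:
    snapshot = write_config_snapshot(args, log_path)
    with open(snapshot, "rt", encoding="utf-8") as f:
        restored = json.load(f)

# The file contains only the config, so it loads back as the same dictionary.
print(restored == vars(args))  # True
```

Because the file holds nothing but the JSON object, it can presumably be reused later as a `--config` input, which a log file with a "Training config:" prefix could not.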
@ -600,7 +600,6 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
torch.cuda.empty_cache()
#loss = torch.tensor(0.0, device=device, dtype=torch.float32)
try:
@ -652,11 +651,18 @@ def main(args):
del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
if batch["runt_size"]> 0:
loss = loss / (batch["runt_size"] / args.batch_size)
if batch["runt_size"] > 0:
grad_scale = batch["runt_size"] / args.batch_size
with torch.no_grad(): # not required? just in case for now, needs more testing
for param in unet.parameters():
if param.grad is not None:
param.grad *= grad_scale
if text_encoder.training:
for param in text_encoder.parameters():
if param.grad is not None:
param.grad *= grad_scale
del target, model_pred
if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
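The arithmetic of the runt batch change can be shown without PyTorch. The old lines divided the loss by `runt_size / batch_size`, which makes it larger for a runt batch, while the new lines multiply existing gradients by that same fraction, which makes them smaller. The sketch below uses plain floats in place of gradient tensors and assumes `runt_size` is the number of genuine images in a padded final batch; that meaning is an assumption, and both function names are hypothetical.

```python
def runt_grad_scale(runt_size: int, batch_size: int) -> float:
    """Factor applied to gradients; 1.0 means a normal, full batch."""
    if runt_size > 0:
        return runt_size / batch_size
    return 1.0

def scale_grads(grads: list, scale: float) -> list:
    """Scale each gradient; None stands in for a parameter with no gradient,
    mirroring the `param.grad is not None` check in the diff."""
    return [g * scale if g is not None else None for g in grads]

scale = runt_grad_scale(runt_size=3, batch_size=6)
print(scale)                                  # 0.5
print(scale_grads([2.0, None, -4.0], scale))  # [1.0, None, -2.0]

# A full batch (runt_size == 0) leaves gradients untouched.
print(scale_grads([2.0], runt_grad_scale(0, 6)))  # [2.0]
```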
@ -687,6 +693,7 @@ def main(args):
#log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache()
if (global_step + 1) % args.sample_steps == 0:
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, vae=vae)
@ -702,7 +709,7 @@ def main(args):
del pipe
gc.collect()
torch.cuda.empty_cache()
torch.cuda.empty_cache()
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
@ -717,6 +724,7 @@ def main(args):
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
del loss, batch
# end of step
elapsed_epoch_time = (time.time() - epoch_start_time) / 60
@ -735,7 +743,7 @@ def main(args):
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
logging.info(f"Total training time took {total_elapsed_time:.2f} seconds, total steps: {global_step}")
logging.info(f"Total training time took {total_elapsed_time/60:.2f} minutes, total steps: {global_step}")
logging.info(f"Average epoch time: {np.mean([t['time'] for t in epoch_times]):.2f} minutes")
except Exception as ex:
@ -752,41 +760,53 @@ def main(args):
if __name__ == "__main__":
supported_resolutions = [448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
argparser.add_argument("--amp", action="store_true", default=False, help="use floating point 16 bit training, experimental, reduces quality")
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
args, _ = argparser.parse_known_args()
args = argparser.parse_args()
if args.config is not None:
print(f"Loading training config from {args.config}, all other command options will be ignored!")
with open(args.config, 'rt') as f:
t_args = argparse.Namespace()
t_args.__dict__.update(json.load(f))
args = argparser.parse_args(namespace=t_args)
else:
print("No config file specified, using command line args")
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--amp", action="store_true", default=False, help="use floating point 16 bit training, experimental, reduces quality")
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
args = argparser.parse_args()
main(args)
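The `--config` handling above follows a two-stage parsing pattern: first look only for `--config`, then either load the JSON file or fall back to the full command line. The sketch below is one way to write that pattern with the standard library; it is not the commit's exact code. In particular, it parses the JSON namespace through the full parser so that options missing from the file still receive defaults. `build_full_parser` and `load_args` are hypothetical names, and only two options are included.

```python
import argparse
import json
import os
import tempfile

def build_full_parser() -> argparse.ArgumentParser:
    # Only two of the real options, to keep the sketch short.
    parser = argparse.ArgumentParser(description="sketch of training options")
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=None)
    return parser

def load_args(argv: list) -> argparse.Namespace:
    # Stage 1: a small parser that only understands --config.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--config", type=str, default=None)
    pre_args, remaining = pre_parser.parse_known_args(argv)

    parser = build_full_parser()
    if pre_args.config is not None:
        with open(pre_args.config, "rt", encoding="utf-8") as f:
            from_json = argparse.Namespace(**json.load(f))
        # Parsing an empty argv keeps the JSON values and fills in defaults
        # for any option the file does not mention.
        return parser.parse_args([], namespace=from_json)
    # Stage 2: no config file, so parse the remaining command line normally.
    return parser.parse_args(remaining)

with tempfile.TemporaryDirectory() as tmp:
    config_path = os.path.join(tmp, "train.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump({"batch_size": 7}, f)
    from_file = load_args(["--config", config_path])

from_cli = load_args(["--batch_size", "3", "--lr", "1e-6"])

print(from_file.batch_size, from_file.lr)
print(from_cli.batch_size, from_cli.lr)
```

One caveat with any namespace-based approach: values loaded from JSON are attached to the namespace directly, so argparse does not apply `type` conversion or `choices` checks to them.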