From 3c921dbaa236221ca7cadb555ffc06dd40fd9327 Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Sun, 8 Jan 2023 18:52:39 -0500 Subject: [PATCH] chaining and more lowers resolutions --- README.md | 4 +- chain.bat | 3 + chain0.json | 38 ++++++++++++ chain1.json | 38 ++++++++++++ chain2.json | 38 ++++++++++++ data/aspects.py | 53 +++++++++++----- train.py | 159 +++++++++++++++++++++++++++++++++++------------- 7 files changed, 276 insertions(+), 57 deletions(-) create mode 100644 chain.bat create mode 100644 chain0.json create mode 100644 chain1.json create mode 100644 chain2.json diff --git a/README.md b/README.md index ba5ef63..b2d6018 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Welcome to v2.0 of EveryDream trainer! Now with more diffusers and even more fea Please join us on Discord! https://discord.gg/uheqxU6sXN -If you find this tool useful, please consider subscribing to the project on Patreon or buy me a Ko-fi. +If you find this tool useful, please consider subscribing to the project on [Patreon](https://www.patreon.com/everydream) or a one-time donation at [Ko-fi](https://ko-fi.com/everydream). ## Video tutorials @@ -32,3 +32,5 @@ Behind the scenes look at how the trainer handles multiaspect and crop jitter [Logging](doc/LOGGING.md) [Advanced Tweaking](doc/ATWEAKING.md) + +[Chaining training sessions](doc/CHAINING.md) diff --git a/chain.bat b/chain.bat new file mode 100644 index 0000000..a4ff929 --- /dev/null +++ b/chain.bat @@ -0,0 +1,3 @@ +python train.py --config chain0.json +python train.py --config chain1.json +python train.py --config chain2.json \ No newline at end of file diff --git a/chain0.json b/chain0.json new file mode 100644 index 0000000..8402db7 --- /dev/null +++ b/chain0.json @@ -0,0 +1,38 @@ +{ + "amp": false, + "batch_size": 12, + "ckpt_every_n_minutes": null, + "clip_grad_norm": null, + "clip_skip": 0, + "cond_dropout": 0.00, + "data_root": "R:\\everydream-trainer\\training_samples\\ff7r", + "disable_textenc_training": false, + "disable_xformers": true, + "flip_p": 0.0, + "ed1_mode": true, + "gpuid": 0, + "gradient_checkpointing": true, + "grad_accum": 1, + "logdir": "logs", + "log_step": 25, + "lowvram": false, + "lr": 2.5e-6, + "lr_decay_steps": 0, + "lr_scheduler": "constant", + "lr_warmup_steps": null, + "max_epochs": 15, + "project_name": "myproj_ch0", + "resolution": 384, + "resume_ckpt": "sd_v1-5_vae", + "sample_prompts": "sample_prompts.txt", + "sample_steps": 300, + "save_ckpt_dir": null, + "save_every_n_epochs": 99, + "save_optimizer": false, + "scale_lr": false, + "seed": -1, + "shuffle_tags": false, + "useadam8bit": true, + "wandb": false, + "write_schedule": true +} \ No newline at end of file diff --git a/chain1.json b/chain1.json new file mode 100644 index 0000000..7378f6e --- /dev/null +++ b/chain1.json @@ -0,0 +1,38 @@ +{ + "amp": false, + "batch_size": 7, + "ckpt_every_n_minutes": null, + "clip_grad_norm": null, + "clip_skip": 0, + "cond_dropout": 0.05, + "data_root": "R:\\everydream-trainer\\training_samples\\ff7r", + "disable_textenc_training": false, + "disable_xformers": true, + "flip_p": 0.0, + "ed1_mode": true, + "gpuid": 0, + "gradient_checkpointing": true, + "grad_accum": 1, + "logdir": "logs", + "log_step": 25, + "lowvram": false, + "lr": 1.0e-6, + "lr_decay_steps": 0, + "lr_scheduler": "constant", + "lr_warmup_steps": null, + "max_epochs": 10, + "project_name": "myproj_ch0", + "resolution": 512, + "resume_ckpt": "findlast", + "sample_prompts": "sample_prompts.txt", + "sample_steps": 300, + "save_ckpt_dir": null, + 
"save_every_n_epochs": 5, + "save_optimizer": false, + "scale_lr": false, + "seed": -1, + "shuffle_tags": false, + "useadam8bit": true, + "wandb": false, + "write_schedule": true +} \ No newline at end of file diff --git a/chain2.json b/chain2.json new file mode 100644 index 0000000..9948c5d --- /dev/null +++ b/chain2.json @@ -0,0 +1,38 @@ +{ + "amp": false, + "batch_size": 2, + "ckpt_every_n_minutes": null, + "clip_grad_norm": null, + "clip_skip": 0, + "cond_dropout": 0.08, + "data_root": "R:\\everydream-trainer\\training_samples\\ff7r", + "disable_textenc_training": true, + "disable_xformers": true, + "flip_p": 0.0, + "ed1_mode": true, + "gpuid": 0, + "gradient_checkpointing": true, + "grad_accum": 5, + "logdir": "logs", + "log_step": 25, + "lowvram": false, + "lr": 1.5e-6, + "lr_decay_steps": 0, + "lr_scheduler": "constant", + "lr_warmup_steps": null, + "max_epochs": 10, + "project_name": "myproj_ch0", + "resolution": 640, + "resume_ckpt": "findlast", + "sample_prompts": "sample_prompts.txt", + "sample_steps": 300, + "save_ckpt_dir": null, + "save_every_n_epochs": 5, + "save_optimizer": false, + "scale_lr": false, + "seed": -1, + "shuffle_tags": false, + "useadam8bit": true, + "wandb": false, + "write_schedule": true +} \ No newline at end of file diff --git a/data/aspects.py b/data/aspects.py index 1a30697..52a2ce3 100644 --- a/data/aspects.py +++ b/data/aspects.py @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ -ASPECTS11 = [[1152,1152], # 1327104 1:1 +ASPECTS_1152 = [[1152,1152], # 1327104 1:1 #[1216,1088],[1088,1216], # 1323008 1.118:1 [1280,1024],[1024,1280], # 1310720 1.25:1 [1344,960],[960,1344], # 1290240 1.4:1 @@ -25,7 +25,7 @@ ASPECTS11 = [[1152,1152], # 1327104 1:1 [2304,576],[576,2304], # 1327104 4:1 ] -ASPECTS10 = [[1088,1088], # 1183744 1:1 +ASPECTS_1088 = [[1088,1088], # 1183744 1:1 [1152,1024],[1024,1152], # 1167360 1.267:1 [1216,896],[896,1216], # 1146880 1.429:1 [1408,832],[832,1408], # 1171456 1.692:1 @@ -36,7 +36,7 @@ ASPECTS10 = [[1088,1088], # 1183744 1:1 [2304,512],[512,2304], # 1179648 4.5:1 ] -ASPECTS9 = [[1024,1024], # 1048576 1:1 +ASPECTS_1024 = [[1024,1024], # 1048576 1:1 #[1088,960],[960,1088], # 1044480 1.125:1 [1152,896],[896,1152], # 1032192 1.286:1 [1216,832],[832,1216], # 1011712 1.462:1 @@ -47,7 +47,7 @@ ASPECTS9 = [[1024,1024], # 1048576 1:1 [2048,512],[512,2048], # 1048576 4:1 ] -ASPECTS8 = [[960,960], # 921600 1:1 +ASPECTS_960 = [[960,960], # 921600 1:1 [1024,896],[896,1024], # 917504 1.143:1 [1088,832],[832,1088], # 905216 1.308:1 [1152,768],[768,1152], # 884736 1.5:1 @@ -59,7 +59,7 @@ ASPECTS8 = [[960,960], # 921600 1:1 [2048,448],[448,2048], # 917504 4.714:1 ] -ASPECTS7 = [[896,896], # 802816 1:1 +ASPECTS_896 = [[896,896], # 802816 1:1 [960,832],[832,960], # 798720 1.153:1 [1024,768],[768,1024], # 786432 1.333:1 [1088,704],[704,1088], # 765952 1.545:1 @@ -69,7 +69,7 @@ ASPECTS7 = [[896,896], # 802816 1:1 [1792,448],[448,1792], # 802816 4:1 ] -ASPECTS6 = [[832,832], # 692224 1:1 +ASPECTS_832 = [[832,832], # 692224 1:1 [896,768],[768,896], # 688128 1.167:1 [960,704],[704,960], # 675840 1.364:1 #[960,640],[640,960], # 614400 1.5:1 @@ -82,7 +82,7 @@ ASPECTS6 = [[832,832], # 692224 1:1 [1600,384],[384,1600], # 614400 4.167:1 ] -ASPECTS5 = [[768,768], # 589824 1:1 +ASPECTS_768 = [[768,768], # 589824 1:1 [832,704],[704,832], # 585728 1.181:1 [896,640],[640,896], # 573440 1.4:1 [960,576],[576,960], # 552960 
1.6:1 @@ -96,7 +96,7 @@ ASPECTS5 = [[768,768], # 589824 1:1 [1472,320],[320,1472], # 470400 4.6:1 ] -ASPECTS4 = [[704,704], # 501,376 1:1 +ASPECTS_704 = [[704,704], # 501,376 1:1 [768,640],[640,768], # 491,520 1.2:1 [832,576],[576,832], # 458,752 1.444:1 #[896,512],[512,896], # 458,752 1.75:1 @@ -109,7 +109,7 @@ ASPECTS4 = [[704,704], # 501,376 1:1 [1280,320],[320,1280], # 409,600 4:1 ] -ASPECTS3 = [[640,640], # 409600 1:1 +ASPECTS_640 = [[640,640], # 409600 1:1 [704,576],[576,704], # 405504 1.25:1 [768,512],[512,768], # 393216 1.5:1 [832,448],[448,832], # 372736 1.857:1 @@ -119,7 +119,7 @@ ASPECTS3 = [[640,640], # 409600 1:1 [1280,320],[320,1280], # 409600 4:1 ] -ASPECTS2 = [[576,576], # 331776 1:1 +ASPECTS_576 = [[576,576], # 331776 1:1 [640,512],[512,640], # 327680 1.25:1 #[640,448],[448,640], # 286720 1.4286:1 [704,448],[448,704], # 314928 1.5625:1 @@ -130,7 +130,7 @@ ASPECTS2 = [[576,576], # 331776 1:1 #[1280,256],[256,1280], # 327680 5:1 ] -ASPECTS = [[512,512], # 262144 1:1 +ASPECTS_512 = [[512,512], # 262144 1:1 [576,448],[448,576], # 258048 1.29:1 [640,384],[384,640], # 245760 1.667:1 [768,320],[320,768], # 245760 2.4:1 @@ -140,14 +140,25 @@ ASPECTS = [[512,512], # 262144 1:1 [1024,256],[256,1024], # 245760 4:1 ] -ASPECTS0 = [[448,448], # 200704 1:1 +ASPECTS_448 = [[448,448], # 200704 1:1 [512,384],[384,512], # 196608 1.333:1 [640,320],[320,640], # 204800 2:1 [768,256],[256,768], # 196608 3:1 ] +ASPECTS_384 = [[384,384], # 147456 1:1 + [448,320],[320,448], # 143360 1.4:1 + [512,256],[256,512], # 131072 2:1 + [704,192],[192,704], # 135168 3.667:1 +] + +ASPECTS_256 = [[256,256], # 65536 1:1 + [384,192],[192,384], # 73728 2:1 + [512,128],[128,512], # 65536 4:1 +] + def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False): - if resolution < 512: + if resolution < 256: raise ValueError("Resolution must be at least 512") try: rounded_resolution = int(resolution / 64) * 64 @@ -164,4 +175,18 @@ def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False): raise e def __get_all_aspects(): - return [ASPECTS0, ASPECTS, ASPECTS2, ASPECTS3, ASPECTS4, ASPECTS5, ASPECTS6, ASPECTS7, ASPECTS8, ASPECTS9, ASPECTS10, ASPECTS11] \ No newline at end of file + return [ASPECTS_256, + ASPECTS_384, + ASPECTS_448, + ASPECTS_512, + ASPECTS_576, + ASPECTS_640, + ASPECTS_704, + ASPECTS_768, + ASPECTS_832, + ASPECTS_896, + ASPECTS_960, + ASPECTS_1024, + ASPECTS_1088, + ASPECTS_1152 + ] \ No newline at end of file diff --git a/train.py b/train.py index 9f0e53e..234aa8c 100644 --- a/train.py +++ b/train.py @@ -105,6 +105,7 @@ def setup_local_logger(args): configures logger with file and console logging, logs args, and returns the datestamp """ log_path = args.logdir + if not os.path.exists(log_path): os.makedirs(log_path) @@ -115,6 +116,7 @@ def setup_local_logger(args): f.write(f"{json_config}") logfilename = os.path.join(log_path, f"{args.project_name}-{datetimestamp}.log") + print(f" logging to {logfilename}") logging.basicConfig(filename=logfilename, level=logging.INFO, format="%(asctime)s %(message)s", @@ -122,6 +124,7 @@ def setup_local_logger(args): ) logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + return datetimestamp def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon): @@ -190,30 +193,51 @@ def set_args_12gb(args): logging.info(" Overiding adam8bit to True") args.useadam8bit = True -def main(args): +def find_last_checkpoint(logdir): """ - Main entry point + Finds the last checkpoint in the logdir, recursively """ - log_time = 
setup_local_logger(args) - + last_ckpt = None + last_date = None + + for root, dirs, files in os.walk(logdir): + for file in files: + if os.path.basename(file) == "model_index.json": + curr_date = os.path.getmtime(os.path.join(root,file)) + + if last_date is None or curr_date > last_date: + last_date = curr_date + last_ckpt = root + + assert last_ckpt, f"Could not find last checkpoint in logdir: {logdir}" + assert "errored" not in last_ckpt, f"Found last checkpoint: {last_ckpt}, but it was errored, cancelling" + + print(f" {Fore.LIGHTCYAN_EX}Found last checkpoint: {last_ckpt}, resuming{Style.RESET_ALL}") + + return last_ckpt + +def setup_args(args): + """ + Sets defaults for missing args (possible if missing from json config) + Forces some args to be set based on others for compatibility reasons + """ + if args.resume_ckpt == "findlast": + logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}") + # find the last checkpoint in the logdir + args.resume_ckpt = find_last_checkpoint(args.logdir) + + if args.ed1_mode and not args.disable_xformers: + args.disable_xformers = True + logging.info(" ED1 mode: Overiding disable_xformers to True") + if args.lowvram: set_args_12gb(args) - seed = args.seed if args.seed != -1 else random.randint(0, 2**30) - set_seed(seed) - gpu = GPU() - device = torch.device(f"cuda:{args.gpuid}") - - torch.backends.cudnn.benchmark = False - - if args.ed1_mode: - args.disable_xformers = True - if not args.shuffle_tags: args.shuffle_tags = False args.clip_skip = max(min(4, args.clip_skip), 0) - + if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None: logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}") args.ckpt_every_n_minutes = 20 @@ -231,16 +255,32 @@ def main(args): if args.cond_dropout > 0.26: logging.warning(f"{Fore.LIGHTYELLOW_EX}** cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}") - total_batch_size = args.batch_size * args.grad_accum - if args.grad_accum > 1: logging.info(f"{Fore.CYAN} Batch size: {args.batch_size}, grad accum: {args.grad_accum}, 'effective' batch size: {args.batch_size * args.grad_accum}{Style.RESET_ALL}") + total_batch_size = args.batch_size * args.grad_accum + if args.scale_lr is not None and args.scale_lr: tmp_lr = args.lr args.lr = args.lr * (total_batch_size**0.55) logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}") + return args + +def main(args): + """ + Main entry point + """ + log_time = setup_local_logger(args) + args = setup_args(args) + + seed = args.seed if args.seed != -1 else random.randint(0, 2**30) + set_seed(seed) + gpu = GPU() + device = torch.device(f"cuda:{args.gpuid}") + + torch.backends.cudnn.benchmark = True + log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}") logging.info(f"Logging to {log_folder}") if not os.path.exists(log_folder): @@ -409,9 +449,12 @@ def main(args): default_lr = 3e-6 curr_lr = args.lr if args.lr is not None else default_lr + # vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16) + # unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16) + # text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16) vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16) - unet = unet.to(device, dtype=torch.float32 if not args.amp else 
torch.float16) - text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16) + unet = unet.to(device, dtype=torch.float32) + text_encoder = text_encoder.to(device, dtype=torch.float32) if args.disable_textenc_training: logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}") @@ -537,15 +580,7 @@ def main(args): logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes") logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs") - # scaler = torch.cuda.amp.GradScaler( - # #enabled=False, - # enabled=True if args.amp else False, - # init_scale=2**1, - # growth_factor=1.000001, - # backoff_factor=0.9999999, - # growth_interval=50, - # ) - #logging.info(f" Grad scaler enabled: {scaler.is_enabled()}") + def collate_fn(batch): """ @@ -607,8 +642,22 @@ def main(args): #loss = torch.tensor(0.0, device=device, dtype=torch.float32) - try: + if args.amp: + #scaler = torch.cuda.amp.GradScaler() + scaler = torch.cuda.amp.GradScaler( + #enabled=False, + enabled=True, + init_scale=1024.0, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=50, + ) + logging.info(f" Grad scaler enabled: {scaler.is_enabled()}") + + loss_log_step = [] + try: for epoch in range(args.max_epochs): + loss_epoch = [] epoch_start_time = time.time() steps_pbar.reset() images_per_sec_log_step = [] @@ -619,8 +668,8 @@ def main(args): with torch.no_grad(): #with autocast(): pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device) - with autocast(enabled=args.amp): - latents = vae.encode(pixel_values, return_dict=False) + #with autocast(enabled=args.amp): + latents = vae.encode(pixel_values, return_dict=False) del pixel_values latents = latents[0].sample() * 0.18215 @@ -650,8 +699,8 @@ def main(args): raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") del noise, latents, cuda_caption - with autocast(enabled=args.amp): - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + #with autocast(enabled=args.amp): + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample del timesteps, encoder_hidden_states, noisy_latents #with autocast(enabled=args.amp): @@ -663,7 +712,10 @@ def main(args): torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm) torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm) - loss.backward() + if args.amp: + scaler.scale(loss).backward() + else: + loss.backward() if batch["runt_size"] > 0: grad_scale = batch["runt_size"] / args.batch_size @@ -677,28 +729,37 @@ def main(args): param.grad *= grad_scale if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1): - optimizer.step() + if args.amp: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() optimizer.zero_grad(set_to_none=True) lr_scheduler.step() - steps_pbar.set_postfix({"gs": global_step}) + loss_step = loss.detach().item() + + steps_pbar.set_postfix({"loss/step": loss_step},{"gs": global_step}) steps_pbar.update(1) - global_step += 1 images_per_sec = args.batch_size / (time.time() - step_start_time) images_per_sec_log_step.append(images_per_sec) + loss_log_step.append(loss_step) + loss_epoch.append(loss_step) + if (global_step + 1) % args.log_step == 0: curr_lr = lr_scheduler.get_last_lr()[0] - loss_local = loss.detach().item() - logs = {"loss/step": loss_local, "lr": curr_lr, "img/s": images_per_sec} - log_writer.add_scalar(tag="loss/step", 
scalar_value=loss_local, global_step=global_step) + loss_local = sum(loss_log_step) / len(loss_log_step) + loss_log_step = [] + logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec} log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step) sum_img = sum(images_per_sec_log_step) avg = sum_img / len(images_per_sec_log_step) images_per_sec_log_step = [] - #log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step) + if args.amp: + log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step) log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step) append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs) torch.cuda.empty_cache() @@ -732,7 +793,8 @@ def main(args): save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir) - del loss, batch + del batch + global_step += 1 # end of step elapsed_epoch_time = (time.time() - epoch_start_time) / 60 @@ -742,6 +804,9 @@ def main(args): epoch_pbar.update(1) if epoch < args.max_epochs - 1: train_batch.shuffle(epoch_n=epoch+1) + + loss_local = sum(loss_epoch) / len(loss_epoch) + log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step) # end of epoch # end of training @@ -765,6 +830,14 @@ def main(args): logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}") +def update_old_args(t_args): + """ + Update old args to new args to deal with json config loading and missing args for compatibility + """ + if not hasattr(t_args, "shuffle_tags"): + print(f" Config json is missing 'shuffle_tags'") + t_args.__dict__["shuffle_tags"] = False + if __name__ == "__main__": supported_resolutions = [448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152] argparser = argparse.ArgumentParser(description="EveryDream2 Training options") @@ -776,6 +849,8 @@ if __name__ == "__main__": with open(args.config, 'rt') as f: t_args = argparse.Namespace() t_args.__dict__.update(json.load(f)) + update_old_args(t_args) # update args to support older configs + print(t_args.__dict__) args = argparser.parse_args(namespace=t_args) else: print("No config file specified, using command line args")
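
The chain.bat added above simply runs train.py three times, once per staged config, and relies on "resume_ckpt": "findlast" in chain1.json and chain2.json to pick up the newest checkpoint under the log directory. As a hypothetical cross-platform counterpart (chain.py is not part of this patch; only the chainN.json names and the train.py --config flag come from it), the same chain can be driven from Python; unlike the .bat, this sketch stops on a non-zero exit code so a failed stage does not silently feed into the next one:

    # chain.py -- hypothetical cross-platform counterpart to chain.bat (not included in this patch)
    # Runs the staged configs in order; chain1.json and chain2.json resume via "resume_ckpt": "findlast".
    import subprocess
    import sys

    CONFIGS = ["chain0.json", "chain1.json", "chain2.json"]

    for config in CONFIGS:
        print(f"=== starting stage {config} ===")
        result = subprocess.run([sys.executable, "train.py", "--config", config])
        if result.returncode != 0:
            sys.exit(f"stage {config} exited with code {result.returncode}, stopping the chain")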
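
The data/aspects.py hunks add ASPECTS_384 and ASPECTS_256 and lower the resolution guard from 512 to 256, although the ValueError text still reads "Resolution must be at least 512". A quick sanity check of the new buckets could look like the sketch below; this assumes it is run from the repository root so data/aspects.py is importable, and that get_aspect_buckets selects the bucket list whose square entry matches the rounded resolution (the selection logic sits outside the visible hunk):

    # check_buckets.py -- hypothetical sanity check for the new ASPECTS_256 / ASPECTS_384 lists (not part of this patch)
    from data.aspects import get_aspect_buckets

    for res in (256, 384, 448):
        buckets = get_aspect_buckets(res)
        print(f"{res}: {buckets}")  # e.g. 384 should include [384, 384], [448, 320], [512, 256], [704, 192]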
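
The train.py hunks enable a torch.cuda.amp.GradScaler when args.amp is set (scaler.scale(loss).backward(), scaler.step(), scaler.update()) while the autocast contexts around the VAE encode and UNet forward are left commented out. For reference only, the usual PyTorch recipe pairs the two; the standalone sketch below is an illustration of that pairing under the assumption of an available CUDA device, not the trainer's code:

    # amp_sketch.py -- minimal autocast + GradScaler pairing for reference (not part of this patch)
    import torch

    model = torch.nn.Linear(8, 8).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(10):
        optimizer.zero_grad(set_to_none=True)
        x = torch.randn(4, 8, device="cuda")
        with torch.cuda.amp.autocast():    # run the forward pass in mixed precision
            loss = model(x).float().pow(2).mean()
        scaler.scale(loss).backward()      # scale the loss so fp16 gradients do not underflow
        scaler.step(optimizer)             # unscales gradients; skips the step on inf/NaN
        scaler.update()                    # adjust the scale factor for the next iteration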