From c82664b3f324b535516a4ea83b4b761bcc18a451 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 2 Mar 2023 00:13:43 +0100 Subject: [PATCH 01/13] add text encoder LR setting to optimizer.json --- optimizer.json | 6 ++++-- train.py | 40 ++++++++++++++++++++++++++++------------ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/optimizer.json b/optimizer.json index 03c64e9..a840b38 100644 --- a/optimizer.json +++ b/optimizer.json @@ -5,11 +5,13 @@ "lr": "learning rate, if null wil use CLI or main JSON config value", "betas": "exponential decay rates for the moment estimates", "epsilon": "value added to denominator for numerical stability, unused for lion", - "weight_decay": "weight decay (L2 penalty)" + "weight_decay": "weight decay (L2 penalty)", + "text_encoder_lr_scale": "if set, scale the text encoder's LR by this much relative to the unet LR" }, "optimizer": "adamw8bit", "lr": 1e-6, "betas": [0.9, 0.999], "epsilon": 1e-8, - "weight_decay": 0.010 + "weight_decay": 0.010, + "text_encoder_lr_scale": 1.0 } diff --git a/train.py b/train.py index a861771..8289f10 100644 --- a/train.py +++ b/train.py @@ -477,16 +477,6 @@ def main(args): else: text_encoder = text_encoder.to(device, dtype=torch.float32) - if args.disable_textenc_training: - logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}") - params_to_train = itertools.chain(unet.parameters()) - elif args.disable_unet_training: - logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}") - params_to_train = itertools.chain(text_encoder.parameters()) - else: - logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}") - params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters()) - optimizer_config = None optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json" if os.path.exists(os.path.join(os.curdir, optimizer_config_path)): @@ -514,6 +504,7 @@ def main(args): default_lr = 1e-6 curr_lr = args.lr + text_encoder_lr_scale = 1.0 if optimizer_config is not None: betas = optimizer_config["betas"] @@ -524,12 +515,30 @@ def main(args): if args.lr is not None: curr_lr = args.lr logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}") + + text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale) + if text_encoder_lr_scale != 1.0: + print(f" * Using text encoder LR scale {text_encoder_lr_scale}") + logging.info(f" * Loaded optimizer args from {optimizer_config_path} *") if curr_lr is None: curr_lr = default_lr logging.warning(f"No LR setting found, defaulting to {default_lr}") + text_encoder_lr = curr_lr * text_encoder_lr_scale + + if args.disable_textenc_training: + logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}") + params_to_train = itertools.chain(unet.parameters()) + elif args.disable_unet_training: + logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}") + params_to_train = itertools.chain(text_encoder.parameters()) + else: + logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}") + params_to_train = [{'params': unet.parameters()}, + {'params': text_encoder.parameters(), 'lr': text_encoder_lr}] + if optimizer_name: if optimizer_name == "lion": from lion_pytorch import Lion @@ -802,8 +811,15 @@ def main(args): curr_lr = lr_scheduler.get_last_lr()[0] loss_local = sum(loss_log_step) / len(loss_log_step) loss_log_step = [] - logs = {"loss/log_step": 
loss_local, "lr": curr_lr, "img/s": images_per_sec} - log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step) + logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec} + if args.disable_textenc_training or args.disable_unet_training: + log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step = global_step) + else: + curr_text_encoder_lr = lr_scheduler.get_last_lr()[1] + log_writer.add_scalars(main_tag="hyperparamater/lr", tag_scalar_dict={ + 'unet': curr_lr, + 'text encoder': curr_text_encoder_lr + }, global_step = global_step) log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step) sum_img = sum(images_per_sec_log_step) avg = sum_img / len(images_per_sec_log_step) From 8100e4215955174a52489db958c33b9f7ad6818c Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 2 Mar 2023 13:03:50 +0100 Subject: [PATCH 02/13] fix issues and improve sample generator --- train.py | 57 ++++++++++++++++++++++----------------- utils/sample_generator.py | 22 ++++++++++----- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/train.py b/train.py index 8289f10..6c5d6b6 100644 --- a/train.py +++ b/train.py @@ -125,12 +125,12 @@ def setup_local_logger(args): return datetimestamp -def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr): +def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, unet_lr, text_encoder_lr): """ logs the optimizer settings """ logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}") - logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}") + logging.info(f"{Fore.CYAN} unet lr: {unet_lr}, text encoder lr: {text_encoder_lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}") def save_optimizer(optimizer: torch.optim.Optimizer, path: str): """ @@ -526,7 +526,7 @@ def main(args): curr_lr = default_lr logging.warning(f"No LR setting found, defaulting to {default_lr}") - text_encoder_lr = curr_lr * text_encoder_lr_scale + curr_text_encoder_lr = curr_lr * text_encoder_lr_scale if args.disable_textenc_training: logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}") @@ -537,7 +537,7 @@ def main(args): else: logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}") params_to_train = [{'params': unet.parameters()}, - {'params': text_encoder.parameters(), 'lr': text_encoder_lr}] + {'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}] if optimizer_name: if optimizer_name == "lion": @@ -565,7 +565,7 @@ def main(args): amsgrad=False, ) - log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr) + log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr) image_train_items = resolve_image_train_items(args, log_folder) @@ -618,6 +618,7 @@ def main(args): default_resolution=args.resolution, default_seed=args.seed, config_file_path=args.sample_prompts, batch_size=max(1,args.batch_size//2), + default_sample_steps=args.sample_steps, use_xformers=is_xformers_available() and not args.disable_xformers) """ @@ -751,11 +752,32 @@ def main(args): return model_pred, target + def generate_samples(global_step: int, batch): + with isolate_rng(): + sample_generator.reload_config() + sample_generator.update_random_captions(batch["captions"]) + inference_pipe = 
sample_generator.create_inference_pipe(unet=unet, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + diffusers_scheduler_config=reference_scheduler.config + ).to(device) + sample_generator.generate_samples(inference_pipe, global_step) + + del inference_pipe + gc.collect() + torch.cuda.empty_cache() + # Pre-train validation to establish a starting point on the loss graph if validator: validator.do_validation_if_appropriate(epoch=0, global_step=0, get_model_prediction_and_target_callable=get_model_prediction_and_target) + # the sample generator might be configured to generate samples before step 0 + if sample_generator.generate_pretrain_samples: + _, batch = next(enumerate(train_dataloader)) + generate_samples(global_step=0, batch=batch) + try: write_batch_schedule(args, log_folder, train_batch, epoch = 0) @@ -813,13 +835,11 @@ def main(args): loss_log_step = [] logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec} if args.disable_textenc_training or args.disable_unet_training: - log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step = global_step) + log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step) else: + log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step) curr_text_encoder_lr = lr_scheduler.get_last_lr()[1] - log_writer.add_scalars(main_tag="hyperparamater/lr", tag_scalar_dict={ - 'unet': curr_lr, - 'text encoder': curr_text_encoder_lr - }, global_step = global_step) + log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step) log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step) sum_img = sum(images_per_sec_log_step) avg = sum_img / len(images_per_sec_log_step) @@ -830,21 +850,8 @@ def main(args): append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs) torch.cuda.empty_cache() - if (global_step + 1) % args.sample_steps == 0: - with isolate_rng(): - sample_generator.reload_config() - sample_generator.update_random_captions(batch["captions"]) - inference_pipe = sample_generator.create_inference_pipe(unet=unet, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=vae, - diffusers_scheduler_config=reference_scheduler.config - ).to(device) - sample_generator.generate_samples(inference_pipe, global_step) - - del inference_pipe - gc.collect() - torch.cuda.empty_cache() + if (global_step + 1) % sample_generator.sample_steps == 0: + generate_samples(global_step=global_step, batch=batch) min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60 diff --git a/utils/sample_generator.py b/utils/sample_generator.py index a599e20..049a628 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -73,6 +73,7 @@ class SampleGenerator: config_file_path: str, batch_size: int, default_seed: int, + default_sample_steps: int, use_xformers: bool): self.log_folder = log_folder self.log_writer = log_writer @@ -80,10 +81,13 @@ class SampleGenerator: self.config_file_path = config_file_path self.use_xformers = use_xformers self.show_progress_bars = False + self.generate_pretrain_samples = False self.default_resolution = default_resolution self.default_seed = default_seed + self.sample_steps = default_sample_steps + self.sample_requests = None self.reload_config() print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', 
{self.num_inference_steps} steps") if not os.path.exists(f"{log_folder}/samples/"): @@ -102,8 +106,12 @@ class SampleGenerator: logging.warning( f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}") logging.warning( - f" Using random caption samples until the problem is fixed. If you edit {self.config_file_path} to fix the problem, it will be automatically reloaded next time samples are due to be generated.") - self.sample_requests = self._make_random_caption_sample_requests() + f" Edit {self.config_file_path} to fix the problem. It will be automatically reloaded next time samples are due to be generated." + ) + if self.sample_requests == None: + logging.warning( + f" Will generate samples from random training image captions until the problem is fixed.") + self.sample_requests = self._make_random_caption_sample_requests() def update_random_captions(self, possible_captions: list[str]): random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption] @@ -139,9 +147,11 @@ class SampleGenerator: self.scheduler = config.get('scheduler', self.scheduler) self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps) self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars) - sample_requests_json = config.get('samples', None) - if sample_requests_json is None: - self.sample_requests = [] + self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples) + self.sample_steps = config.get('sample_steps', self.sample_steps) + sample_requests_config = config.get('samples', None) + if sample_requests_config is None: + self.sample_requests = self._make_random_caption_sample_requests() else: default_seed = config.get('seed', self.default_seed) default_size = (self.default_resolution, self.default_resolution) @@ -150,7 +160,7 @@ class SampleGenerator: seed=p.get('seed', default_seed), size=tuple(p.get('size', default_size)), wants_random_caption=p.get('random_caption', False) - ) for p in sample_requests_json] + ) for p in sample_requests_config] if len(self.sample_requests) == 0: self._make_random_caption_sample_requests() From 17c61db5ca09842119b30ee273b59c4205a09efe Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 2 Mar 2023 14:50:17 +0100 Subject: [PATCH 03/13] wip aspect ratio samples --- utils/sample_generator.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/utils/sample_generator.py b/utils/sample_generator.py index 049a628..2583b7a 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -155,10 +155,22 @@ class SampleGenerator: else: default_seed = config.get('seed', self.default_seed) default_size = (self.default_resolution, self.default_resolution) + #def make_size_from_aspect_ratio(aspect_ratio): + # if aspect_ratio is None: + # return None + # target_pixel_count = self.default_resolution * self.default_resolution + # w_ratio = aspect_ratio + # h_ratio = 1/w_ratio + # pixels_per_ratio_unit = target_pixel_count/(w_ratio + h_ratio) + # w = round(w_ratio*pixels_per_ratio_unit / 64) * 64 + # h = round(h_ratio*pixels_per_ratio_unit / 64) * 64 + # return [w,h] + + self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''), negative_prompt=p.get('negative_prompt', ''), seed=p.get('seed', default_seed), - size=tuple(p.get('size', default_size)), + size=p.get('size', default_size), wants_random_caption=p.get('random_caption', False) ) for p in 
sample_requests_config] if len(self.sample_requests) == 0: From 61558be2ae737eb4d93738dced08a9bf205d2e9c Mon Sep 17 00:00:00 2001 From: damian Date: Thu, 2 Mar 2023 18:29:28 +0100 Subject: [PATCH 04/13] logging and progress bar improvements --- data/every_dream_validation.py | 2 +- train.py | 9 +++++++-- utils/sample_generator.py | 19 ++++++++++++------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index 95f5afd..2793ae5 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -105,7 +105,7 @@ class EveryDreamValidator: [Any, Any], tuple[torch.Tensor, torch.Tensor]]): with torch.no_grad(), isolate_rng(): loss_validation_epoch = [] - steps_pbar = tqdm(range(len(dataloader)), position=1) + steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False) steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}") for step, batch in enumerate(dataloader): diff --git a/train.py b/train.py index 6c5d6b6..d5eeff0 100644 --- a/train.py +++ b/train.py @@ -691,7 +691,7 @@ def main(args): ) logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)") - epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True) + epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True) epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}") epoch_times = [] @@ -754,7 +754,12 @@ def main(args): def generate_samples(global_step: int, batch): with isolate_rng(): + prev_sample_steps = sample_generator.sample_steps sample_generator.reload_config() + if prev_sample_steps != sample_generator.sample_steps: + next_sample_step = math.ceil((global_step + 1) / sample_generator.sample_steps) * sample_generator.sample_steps + print(f" * SampleGenerator config changed, now generating images samples every " + + f"{sample_generator.sample_steps} training steps (next={next_sample_step})") sample_generator.update_random_captions(batch["captions"]) inference_pipe = sample_generator.create_inference_pipe(unet=unet, text_encoder=text_encoder, @@ -787,7 +792,7 @@ def main(args): images_per_sec_log_step = [] epoch_len = math.ceil(len(train_batch) / args.batch_size) - steps_pbar = tqdm(range(epoch_len), position=1) + steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True) steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}") for step, batch in enumerate(train_dataloader): diff --git a/utils/sample_generator.py b/utils/sample_generator.py index 049a628..c9c6ee6 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -12,6 +12,7 @@ from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistep from torch.cuda.amp import autocast from torch.utils.tensorboard import SummaryWriter from torchvision import transforms +from tqdm.auto import tqdm def clean_filename(filename): @@ -89,7 +90,7 @@ class SampleGenerator: self.sample_requests = None self.reload_config() - print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps") + print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps") if not os.path.exists(f"{log_folder}/samples/"): os.makedirs(f"{log_folder}/samples/") @@ -169,9 +170,7 @@ class SampleGenerator: """ 
generates samples at different cfg scales and saves them to disk """ - logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}") - - pipe.set_progress_bar_config(disable=(not self.show_progress_bars)) + disable_progress_bars = not self.show_progress_bars try: font = ImageFont.truetype(font="arial.ttf", size=20) @@ -183,10 +182,13 @@ class SampleGenerator: batch: list[SampleRequest] def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool: return a.size == b.size - for batch in chunk_list(self.sample_requests, self.batch_size, - compatibility_test=sample_compatibility_test): - #print("batch: ", batch) + batches = list(chunk_list(self.sample_requests, self.batch_size, + compatibility_test=sample_compatibility_test)) + pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False, + desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}") + for batch in batches: prompts = [p.prompt for p in batch] + pbar.set_postfix(postfix={'prompts': prompts}) negative_prompts = [p.negative_prompt for p in batch] seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30)) for p in batch] @@ -196,6 +198,8 @@ class SampleGenerator: batch_images = [] for cfg in self.cfgs: + pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False, + desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}") images = pipe(prompt=prompts, negative_prompt=negative_prompts, num_inference_steps=self.num_inference_steps, @@ -257,6 +261,7 @@ class SampleGenerator: del tfimage del batch_images + pbar.update(1) @torch.no_grad() def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict): From fe0083877ff330cf3445cbe69d8ebd3a1031ac66 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 2 Mar 2023 22:16:21 +0100 Subject: [PATCH 05/13] add aspect_ratio arg to sample generation --- utils/sample_generator.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/utils/sample_generator.py b/utils/sample_generator.py index b5a6510..d4da54e 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -53,6 +53,15 @@ def chunk_list(l: list, batch_size: int, yield b[i:i + batch_size] +def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]: + sizes = [] + target_pixel_count = default_resolution * default_resolution + for w in range(256, 1024, 64): + for h in range(256, 1024, 64): + if abs((w * h) - target_pixel_count) <= 128 * 64: + sizes.append((w, h)) + best_size = min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1])))) + return best_size class SampleGenerator: @@ -155,23 +164,11 @@ class SampleGenerator: self.sample_requests = self._make_random_caption_sample_requests() else: default_seed = config.get('seed', self.default_seed) - default_size = (self.default_resolution, self.default_resolution) - #def make_size_from_aspect_ratio(aspect_ratio): - # if aspect_ratio is None: - # return None - # target_pixel_count = self.default_resolution * self.default_resolution - # w_ratio = aspect_ratio - # h_ratio = 1/w_ratio - # pixels_per_ratio_unit = target_pixel_count/(w_ratio + h_ratio) - # w = round(w_ratio*pixels_per_ratio_unit / 64) * 64 - # h = round(h_ratio*pixels_per_ratio_unit / 64) * 64 - # return [w,h] - - self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''), negative_prompt=p.get('negative_prompt', ''), seed=p.get('seed', default_seed), - 
size=p.get('size', default_size), + size=tuple(p.get('size', None) or + get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)), wants_random_caption=p.get('random_caption', False) ) for p in sample_requests_config] if len(self.sample_requests) == 0: @@ -200,7 +197,6 @@ class SampleGenerator: desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}") for batch in batches: prompts = [p.prompt for p in batch] - pbar.set_postfix(postfix={'prompts': prompts}) negative_prompts = [p.negative_prompt for p in batch] seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30)) for p in batch] From ae281976caa52ba30679eaa943fcc733db77c6de Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 2 Mar 2023 22:36:00 +0100 Subject: [PATCH 06/13] documentation of new text encoder LR and aspect_ratio settings --- doc/LOGGING.md | 5 +++-- doc/OPTIMIZER.md | 2 ++ optimizer.json | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/doc/LOGGING.md b/doc/LOGGING.md index 441b673..9c76e71 100644 --- a/doc/LOGGING.md +++ b/doc/LOGGING.md @@ -35,7 +35,8 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w }, { "prompt": "a photograph of ted bennet riding a bicycle", - "seed": -1 + "seed": -1, + "aspect_ratio": 1.77778 }, { "random_caption": true, @@ -47,7 +48,7 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. Finally, you can set `show_progress_bars` to `true` if you want to see progress bars during the sample generation process. -Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64). Use `"random_caption": true` to pick a random caption from the training set each time. +Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64) or `aspect_ratio` (eg 1.77778 for 16:9). Use `"random_caption": true` to pick a random caption from the training set each time. ## LR diff --git a/doc/OPTIMIZER.md b/doc/OPTIMIZER.md index dc2b598..8c13a78 100644 --- a/doc/OPTIMIZER.md +++ b/doc/OPTIMIZER.md @@ -34,6 +34,8 @@ Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the LR can be set in `optimizer.json` and excluded from the main CLI arg or train.json but if you use the main CLI arg or set it in the main train.json it will override the setting. This was done to make sure existing behavior will not break. To set LR in the `optimizer.json` make sure to delete `"lr": 1.3e-6` in your main train.json and exclude the CLI arg. +The text encoder LR can run at a different value to the U-net LR. This may help prevent over-fitting, especially if you're training from SD2 checkpoints. To set the text encoder LR, add a value for `text_encoder_lr_scale` to `optimizer.json`. For example, to have the text encoder LR to 50% of the U-net LR, add `"text_encoder_lr_scale": 0.5` to `optimizer.json`. The default value is `1.0`, meaning the text encoder and U-net are trained with the same LR. 
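To make the effect of `text_encoder_lr_scale` concrete: the scale is applied through PyTorch's per-parameter-group learning rates rather than a second optimizer, matching the param-group layout added to train.py in PATCH 01 above. The sketch below is illustrative only; the two `Linear` modules are hypothetical stand-ins for the real U-net and text encoder, and the base LR and scale values are arbitrary.

```python
import torch

# Hypothetical stand-ins; only the param-group layout matters for this sketch.
unet = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)

base_lr = 2e-6
text_encoder_lr_scale = 0.5                     # value read from optimizer.json
text_encoder_lr = base_lr * text_encoder_lr_scale

# Same layout as the train.py hunk: the first group falls back to the
# optimizer-wide lr, the second overrides it for the text encoder.
params_to_train = [
    {"params": unet.parameters()},
    {"params": text_encoder.parameters(), "lr": text_encoder_lr},
]
optimizer = torch.optim.AdamW(params_to_train, lr=base_lr, betas=(0.9, 0.999),
                              eps=1e-8, weight_decay=0.010)

print([group["lr"] for group in optimizer.param_groups])  # [2e-06, 1e-06]
```

A scheduler built on top of this optimizer reports one rate per param group via `get_last_lr()`, which is what the TensorBoard logging change in PATCH 01 relies on when it reads `get_last_lr()[1]` for the text encoder.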
+ Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider those experimental to tweak. I cannot provide advice on what might be useful to tweak here. Note `lion` does not use epsilon. \ No newline at end of file diff --git a/optimizer.json b/optimizer.json index a840b38..574aa31 100644 --- a/optimizer.json +++ b/optimizer.json @@ -6,7 +6,7 @@ "betas": "exponential decay rates for the moment estimates", "epsilon": "value added to denominator for numerical stability, unused for lion", "weight_decay": "weight decay (L2 penalty)", - "text_encoder_lr_scale": "if set, scale the text encoder's LR by this much relative to the unet LR" + "text_encoder_lr_scale": "scale the text encoder LR relative to the Unet LR. for example, if `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`." }, "optimizer": "adamw8bit", "lr": 1e-6, From 97a8a4977363981600e86bbb12c4918d4d3b6168 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 2 Mar 2023 22:36:32 +0100 Subject: [PATCH 07/13] don't log separate text encoder LR if it's the same as unet LR --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index d5eeff0..f26b400 100644 --- a/train.py +++ b/train.py @@ -839,7 +839,7 @@ def main(args): loss_local = sum(loss_log_step) / len(loss_log_step) loss_log_step = [] logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec} - if args.disable_textenc_training or args.disable_unet_training: + if args.disable_textenc_training or args.disable_unet_training or text_encoder_lr_scale == 1: log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step) else: log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step) From bb1fbb329aa1aae8519268df3778dc541e11cb0c Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 2 Mar 2023 22:45:22 +0100 Subject: [PATCH 08/13] tweak docs --- doc/OPTIMIZER.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/OPTIMIZER.md b/doc/OPTIMIZER.md index 8c13a78..3db3f45 100644 --- a/doc/OPTIMIZER.md +++ b/doc/OPTIMIZER.md @@ -34,7 +34,7 @@ Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the LR can be set in `optimizer.json` and excluded from the main CLI arg or train.json but if you use the main CLI arg or set it in the main train.json it will override the setting. This was done to make sure existing behavior will not break. To set LR in the `optimizer.json` make sure to delete `"lr": 1.3e-6` in your main train.json and exclude the CLI arg. -The text encoder LR can run at a different value to the U-net LR. This may help prevent over-fitting, especially if you're training from SD2 checkpoints. To set the text encoder LR, add a value for `text_encoder_lr_scale` to `optimizer.json`. For example, to have the text encoder LR to 50% of the U-net LR, add `"text_encoder_lr_scale": 0.5` to `optimizer.json`. The default value is `1.0`, meaning the text encoder and U-net are trained with the same LR. +The text encoder LR can run at a different value to the Unet LR. This may help prevent over-fitting, especially if you're training from SD2 checkpoints. To set the text encoder LR, add a value for `text_encoder_lr_scale` to `optimizer.json`. For example, to train the text encoder with an LR that is half that of the Unet, add `"text_encoder_lr_scale": 0.5` to `optimizer.json`. 
The default value is `1.0`, meaning the text encoder and Unet are trained with the same LR. Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider those experimental to tweak. I cannot provide advice on what might be useful to tweak here. From e2fd45737d2cc0fabbad93686eb6dd06beda63e4 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 2 Mar 2023 22:52:26 +0100 Subject: [PATCH 09/13] overwrite args.seed with the actual seed if -1 is passed (so it appears in tensorboard) also improve logging when unet training is disabled --- data/resolver.py | 7 ++----- train.py | 7 ++++++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/data/resolver.py b/data/resolver.py index e66a3b8..b31043f 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -1,7 +1,6 @@ import json import logging import os -import random import typing import zipfile import argparse @@ -18,8 +17,7 @@ class DataResolver: """ self.aspects = args.aspects self.flip_p = args.flip_p - self.seed = args.seed - + def image_train_items(self, data_root: str) -> list[ImageTrainItem]: """ Get the list of `ImageTrainItem` for the given data root. @@ -116,8 +114,7 @@ class DirectoryResolver(DataResolver): image_paths = list(DirectoryResolver.recurse_data_root(data_root)) items = [] multipliers = {} - randomizer = random.Random(self.seed) - + for pathname in tqdm.tqdm(image_paths): current_dir = os.path.dirname(pathname) diff --git a/train.py b/train.py index f26b400..ec9b148 100644 --- a/train.py +++ b/train.py @@ -363,7 +363,9 @@ def main(args): else: from tqdm.auto import tqdm - seed = args.seed if args.seed != -1 else random.randint(0, 2**30) + if args.seed == -1: + args.seed = random.randint(0, 2**30) + seed = args.seed logging.info(f" Seed: {seed}") set_seed(seed) if torch.cuda.is_available(): @@ -533,6 +535,9 @@ def main(args): params_to_train = itertools.chain(unet.parameters()) elif args.disable_unet_training: logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}") + if text_encoder_lr_scale != 1: + logging.warning(f"{Fore.YELLOW} * Ignoring text_encoder_lr_scale {text_encoder_lr_scale} and using the " + f"Unet LR {curr_lr} for the text encoder instead.") params_to_train = itertools.chain(text_encoder.parameters()) else: logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}") From 15187ae2e208f807768cffb539272acd7f91046a Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 2 Mar 2023 22:53:59 +0100 Subject: [PATCH 10/13] fix log leaking color --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index ec9b148..8b6b2f9 100644 --- a/train.py +++ b/train.py @@ -537,7 +537,7 @@ def main(args): logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}") if text_encoder_lr_scale != 1: logging.warning(f"{Fore.YELLOW} * Ignoring text_encoder_lr_scale {text_encoder_lr_scale} and using the " - f"Unet LR {curr_lr} for the text encoder instead.") + f"Unet LR {curr_lr} for the text encoder instead *{Style.RESET_ALL}") params_to_train = itertools.chain(text_encoder.parameters()) else: logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}") From c4d37862ba9be6c56efcbf200a2d3f8365a6d246 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 2 Mar 2023 23:12:47 +0100 Subject: [PATCH 11/13] document new sample generator params --- doc/LOGGING.md | 6 +++++- utils/sample_generator.py | 2 +- 2 files changed, 6 
insertions(+), 2 deletions(-) diff --git a/doc/LOGGING.md b/doc/LOGGING.md index 9c76e71..ae358a9 100644 --- a/doc/LOGGING.md +++ b/doc/LOGGING.md @@ -28,6 +28,8 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w "scheduler": "dpm++", "num_inference_steps": 15, "show_progress_bars": true, + "generate_samples_every_n_steps": 200, + "generate_pretrain_samples": true, "samples": [ { "prompt": "ted bennet and a man sitting on a sofa with a kitchen in the background", @@ -46,7 +48,9 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w } ``` -At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. Finally, you can set `show_progress_bars` to `true` if you want to see progress bars during the sample generation process. +At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. If you want to see sample progress bars you can set `show_progress_bars` to `true`. To generate a batch of samples before training begins, set `generate_pretrain_samples` to true. + +Finally, you can override the `sample_steps` set in the main configuration .json file (or CLI) by setting `generate_samples_every_n_steps`. This value is read every time samples are updated, so if you initially pass `--sample_steps 200` and then later on you edit your `sample_prompts.json` file to add `"generate_samples_every_n_steps": 100`, after the next set of samples is generated you will start seeing new sets of image samples every 100 steps instead of only every 200 steps. Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64) or `aspect_ratio` (eg 1.77778 for 16:9). Use `"random_caption": true` to pick a random caption from the training set each time. 
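The `aspect_ratio` option is resolved to a concrete width and height by the `get_best_size_for_aspect_ratio` helper added to utils/sample_generator.py in PATCH 05 above. Below is a condensed, self-contained copy of that selection logic so the mapping can be checked in isolation; the printed size is what the current implementation should pick, not a guaranteed contract.

```python
def get_best_size_for_aspect_ratio(aspect_ratio: float, default_resolution: int) -> tuple[int, int]:
    # Keep every 64-aligned size whose pixel count stays close to
    # default_resolution**2, then pick the one whose w/h best matches the ratio.
    target_pixel_count = default_resolution * default_resolution
    sizes = [(w, h)
             for w in range(256, 1024, 64)
             for h in range(256, 1024, 64)
             if abs(w * h - target_pixel_count) <= 128 * 64]
    return min(sizes, key=lambda s: abs(1 - aspect_ratio / (s[0] / s[1])))

# "aspect_ratio": 1.77778 (16:9) at the default 512 resolution resolves to
# roughly 704x384 with this search.
print(get_best_size_for_aspect_ratio(1.77778, 512))
```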
diff --git a/utils/sample_generator.py b/utils/sample_generator.py index d4da54e..01d8f00 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -158,7 +158,7 @@ class SampleGenerator: self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps) self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars) self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples) - self.sample_steps = config.get('sample_steps', self.sample_steps) + self.sample_steps = config.get('generate_samples_every_n_steps', self.sample_steps) sample_requests_config = config.get('samples', None) if sample_requests_config is None: self.sample_requests = self._make_random_caption_sample_requests() From 644c9e6b2a8b3a8e72dd9e9f4e2e52ae1d6c7a92 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 3 Mar 2023 09:52:44 +0100 Subject: [PATCH 12/13] log to logging.info instead of stdout --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 8b6b2f9..4d06ce1 100644 --- a/train.py +++ b/train.py @@ -520,7 +520,7 @@ def main(args): text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale) if text_encoder_lr_scale != 1.0: - print(f" * Using text encoder LR scale {text_encoder_lr_scale}") + logging.info(f" * Using text encoder LR scale {text_encoder_lr_scale}") logging.info(f" * Loaded optimizer args from {optimizer_config_path} *") From 29c93eca03ae5d07179c8370b6249100a7fec8c1 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 3 Mar 2023 10:50:48 +0100 Subject: [PATCH 13/13] if progress bars are disabled, log a short message instead --- utils/sample_generator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/utils/sample_generator.py b/utils/sample_generator.py index 01d8f00..f053dfa 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -186,6 +186,9 @@ class SampleGenerator: except: font = ImageFont.load_default() + if not self.show_progress_bars: + print(f" * Generating samples at gs:{global_step} for {len(self.sample_requests)} prompts") + sample_index = 0 with autocast(): batch: list[SampleRequest]
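One last illustration of the sample scheduling described in doc/LOGGING.md above: samples are generated whenever `(global_step + 1)` is a multiple of the generator's `sample_steps`, and after a config reload the training loop recomputes the next sample step with the same ceiling arithmetic added in PATCH 04. A minimal sketch of that arithmetic, with the step counts in the comment used only as example values:

```python
import math

def next_sample_step(global_step: int, sample_steps: int) -> int:
    # Samples fire when (global_step + 1) % sample_steps == 0, so the next one
    # lands on the first multiple of sample_steps at or after global_step + 1.
    return math.ceil((global_step + 1) / sample_steps) * sample_steps

# e.g. generate_samples_every_n_steps edited from 200 down to 100 while the
# run sits at global_step 437: the next samples are due at step 500.
print(next_sample_step(437, 100))  # 500
```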