Merge branch 'main' into fix_simplify_freezing_text_encoder_layers

This commit is contained in:
Victor Hall 2023-06-18 00:53:51 -04:00 committed by GitHub
commit a916934bb8
No known key found for this signature in database
6 changed files with 89 additions and 38 deletions

View File

@ -108,19 +108,35 @@ class DataLoaderMultiAspect():
batch_size = self.batch_size
grad_accum = self.grad_accum
for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0
bucket_key = (image_caption_pair.batch_id,
def add_image_to_appropriate_bucket(image: ImageTrainItem, batch_id_override: str=None):
bucket_key = (image.batch_id if batch_id_override is None else batch_id_override,
if bucket_key not in buckets:
buckets[bucket_key] = []
# handle runts by randomly duplicating items
for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0
# handle named batch runts by demoting them to the DEFAULT_BATCH_ID
for key, bucket_contents in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
runt_count = len(bucket_contents) % batch_size
if runt_count == 0:
runts = bucket_contents[-runt_count:]
del bucket_contents[-runt_count:]
for r in runts:
add_image_to_appropriate_bucket(r, batch_id_override=DEFAULT_BATCH_ID)
if len(bucket_contents) == 0:
del buckets[key]
# handle remaining runts by randomly duplicating items
for bucket in buckets:
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
runt_bucket = buckets[bucket][-truncate_count:]
for item in runt_bucket:
item.runt_size = truncate_count
@ -132,9 +148,21 @@ class DataLoaderMultiAspect():
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
# at this point items have a partially deterministic order
# (in particular: rarer aspect ratios are more likely to cluster at the end due to stochastic sampling)
# so we shuffle them to mitigate this, using chunked_shuffle to keep batches with the same aspect ratio together
items_by_batch_id = {k: chunked_shuffle(v, chunk_size=batch_size, randomizer=randomizer)
for k,v in items_by_batch_id.items()}
# paranoia: verify that the chunked shuffle hasn't broken the aspect ratio batching
for items in items_by_batch_id.values():
batches = chunk(items, chunk_size=batch_size)
for batch in batches:
target_wh = batch[0].target_wh
assert all(target_wh == i.target_wh for i in batch[1:]), "mixed aspect ratios in a batch - this shouldn't happen"
# handle batch_id
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
@ -225,6 +253,8 @@ def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List
Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
if len(l) == 0:
return []
# chunk by effective batch size
chunks = chunk(l, chunk_size)

View File

@ -97,9 +97,13 @@ class EveryDreamValidator:
self.config.update({'manual_data_root': self.config['val_data_root']})
if self.config.get('val_split_mode') == 'manual':
if 'manual_data_root' not in self.config:
raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
manual_data_root = self.config.get('manual_data_root')
if manual_data_root is not None:
self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
if len(self.config['extra_manual_datasets']) == 0:
raise ValueError("Error in validation config .json: 'manual' validation requested but no "
"'manual_data_root' or 'extra_manual_datasets'")
if 'val_split_proportion' in self.config:
logging.warning(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "
@ -167,15 +171,22 @@ class EveryDreamValidator:
def do_validation(self, global_step: int,
get_model_prediction_and_target_callable: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
mean_loss_accumulator = 0
for i, dataset in enumerate(self.validation_datasets):
mean_loss = self._calculate_validation_loss(,
mean_loss_accumulator += mean_loss
# log the combined loss to loss/_all_val_combined
if len(self.validation_datasets) > 1:
total_mean_loss = mean_loss_accumulator / len(self.validation_datasets)
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:

View File

@ -38,5 +38,6 @@
"text_encoder_freezing": {
"unfreeze_last_n_layers": null
"apply_grad_scaler_step_tweaks": true

View File

@ -56,6 +56,7 @@ class EveryDreamOptimizer():
self.grad_accum = args.grad_accum
self.clip_grad_norm = args.clip_grad_norm
self.apply_grad_scaler_step_tweaks = optimizer_config.get("apply_grad_scaler_step_tweaks", True)
self.text_encoder_params = self._apply_text_encoder_freeze(text_encoder)
self.unet_params = unet.parameters()
@ -103,7 +104,8 @@ class EveryDreamOptimizer():
for scheduler in self.lr_schedulers:
if self.apply_grad_scaler_step_tweaks:
def _zero_grad(self, set_to_none=False):
for optimizer in self.optimizers:

View File

@ -390,7 +390,8 @@ def main(args):
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
Save the model to disk
@ -412,21 +413,22 @@ def main(args):
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
if save_ckpt:
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
half = not save_full_precision
half = not save_full_precision" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
if save_optimizer_flag:" Saving optimizer state to {save_path}")
@ -616,7 +618,7 @@ def main(args):
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer)
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
# non-main threads (i.e. dataloader workers) should exit cleanly
@ -731,6 +733,9 @@ def main(args):
def make_save_path(epoch, global_step, prepend=""):
return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
# Pre-train validation to establish a starting point on the loss graph
if validator:
@ -816,13 +821,13 @@ def main(args):
if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
last_epoch_saved_time = time.time()"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
del batch
global_step += 1
@ -846,9 +851,9 @@ def main(args):
# end of epoch
# end of training
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
epoch = args.max_epochs
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
total_elapsed_time = time.time() - training_start_time"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@ -857,8 +862,8 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
save_path = make_save_path(epoch, global_step, prepend="errored-")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
raise ex"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
@ -906,6 +911,8 @@ if __name__ == "__main__":
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--no_prepend_last", action="store_true", help="Do not prepend 'last-' to the final checkpoint filename")
argparser.add_argument("--no_save_ckpt", action="store_true", help="Save only diffusers files, no .ckpts" )
argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)

View File

@ -176,7 +176,7 @@ class SampleGenerator:
wants_random_caption=p.get('random_caption', False)
) for p in sample_requests_config]
if len(self.sample_requests) == 0:
self.sample_requests = self._make_random_caption_sample_requests()
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):