Merge branch 'main' into fix_simplify_freezing_text_encoder_layers

Victor Hall 2023-06-18 00:53:51 -04:00 committed by GitHub
commit a916934bb8
6 changed files with 89 additions and 38 deletions

View File

@@ -108,19 +108,35 @@ class DataLoaderMultiAspect():
batch_size = self.batch_size
grad_accum = self.grad_accum
for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0
bucket_key = (image_caption_pair.batch_id,
image_caption_pair.target_wh[0],
image_caption_pair.target_wh[1])
def add_image_to_appropriate_bucket(image: ImageTrainItem, batch_id_override: str=None):
bucket_key = (image.batch_id if batch_id_override is None else batch_id_override,
image.target_wh[0],
image.target_wh[1])
if bucket_key not in buckets:
buckets[bucket_key] = []
buckets[bucket_key].append(image_caption_pair)
buckets[bucket_key].append(image)
# handle runts by randomly duplicating items
for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0
add_image_to_appropriate_bucket(image_caption_pair)
# handle named batch runts by demoting them to the DEFAULT_BATCH_ID
for key, bucket_contents in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
runt_count = len(bucket_contents) % batch_size
if runt_count == 0:
continue
runts = bucket_contents[-runt_count:]
del bucket_contents[-runt_count:]
for r in runts:
add_image_to_appropriate_bucket(r, batch_id_override=DEFAULT_BATCH_ID)
if len(bucket_contents) == 0:
del buckets[key]
# handle remaining runts by randomly duplicating items
for bucket in buckets:
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
runt_bucket = buckets[bucket][-truncate_count:]
for item in runt_bucket:
item.runt_size = truncate_count
@@ -132,9 +148,21 @@ class DataLoaderMultiAspect():
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
buckets[bucket].extend(runt_bucket)
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
# at this point items have a partially deterministic order
# (in particular: rarer aspect ratios are more likely to cluster at the end due to stochastic sampling)
# so we shuffle them to mitigate this, using chunked_shuffle to keep batches with the same aspect ratio together
items_by_batch_id = {k: chunked_shuffle(v, chunk_size=batch_size, randomizer=randomizer)
for k,v in items_by_batch_id.items()}
# paranoia: verify that this hasn't fucked up the aspect ratio batching
for items in items_by_batch_id.values():
batches = chunk(items, chunk_size=batch_size)
for batch in batches:
target_wh = batch[0].target_wh
assert all(target_wh == i.target_wh for i in batch[1:]), "mixed aspect ratios in a batch - this shouldn't happen"
# handle batch_id
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
batch_size=batch_size,
grad_accum=grad_accum)
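As context for the bucketing changes above, here is a minimal self-contained sketch of the new flow: bucket every image with one shared helper, demote runts (groups smaller than batch_size) from named batches into the default batch, then pad any remaining runts by duplication. The function name, the dict-based items, and the DEFAULT_BATCH_ID value are illustrative stand-ins rather than the repo's ImageTrainItem API, and the runt_size bookkeeping is omitted.

import random
from typing import Dict, List, Tuple

DEFAULT_BATCH_ID = "default"  # stand-in for the repo's DEFAULT_BATCH_ID constant

def bucket_and_fix_runts(items: List[dict], batch_size: int, rng: random.Random) -> Dict[Tuple, List[dict]]:
    # each item is a dict with 'batch_id' and 'target_wh' keys (stand-in for ImageTrainItem)
    buckets: Dict[Tuple, List[dict]] = {}

    def add_to_bucket(item: dict, batch_id_override: str = None):
        key = (item["batch_id"] if batch_id_override is None else batch_id_override,
               item["target_wh"][0], item["target_wh"][1])
        buckets.setdefault(key, []).append(item)

    for item in items:
        add_to_bucket(item)

    # demote named-batch runts into the default batch
    for key in [k for k in buckets if k[0] != DEFAULT_BATCH_ID]:
        contents = buckets[key]
        runt_count = len(contents) % batch_size
        if runt_count:
            runts = contents[-runt_count:]
            del contents[-runt_count:]
            for r in runts:
                add_to_bucket(r, batch_id_override=DEFAULT_BATCH_ID)
        if not contents:
            del buckets[key]

    # any runts left now sit in default-batch buckets; pad them to a full batch by duplication
    for contents in buckets.values():
        runt_count = len(contents) % batch_size
        if runt_count:
            tail = contents[-runt_count:]
            while len(tail) < batch_size:
                tail.append(rng.choice(tail))
            contents[-runt_count:] = tail
    return buckets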
@@ -225,6 +253,8 @@ def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List
Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
"""
if len(l) == 0:
return []
# chunk by effective batch size
chunks = chunk(l, chunk_size)
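The chunked_shuffle helper extended here is what keeps same-aspect-ratio batches intact while breaking up the ordering bias noted in the comments above. Below is a small re-implementation matching its docstring, plus a worked call; it illustrates the described behaviour and is not necessarily the repo's exact implementation.

import random
from typing import List

def chunk(l: List, chunk_size: int) -> List[List]:
    # split l into consecutive pieces of chunk_size; the last piece may be shorter
    return [l[i:i + chunk_size] for i in range(0, len(l), chunk_size)]

def chunked_shuffle_sketch(l: List, chunk_size: int, randomizer: random.Random) -> List:
    # shuffle whole chunks, keep items within each chunk in order,
    # and keep an incomplete final chunk pinned at the end
    if len(l) == 0:
        return []
    chunks = chunk(l, chunk_size)
    tail = [chunks.pop()] if len(chunks[-1]) < chunk_size else []
    randomizer.shuffle(chunks)
    return [item for c in chunks + tail for item in c]

# example with chunk_size=2: pairs stay together, the lone trailing item stays last
# chunked_shuffle_sketch([1, 2, 3, 4, 5], 2, random.Random(0)) -> e.g. [3, 4, 1, 2, 5]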

View File

@@ -97,9 +97,13 @@ class EveryDreamValidator:
self.config.update({'manual_data_root': self.config['val_data_root']})
if self.config.get('val_split_mode') == 'manual':
if 'manual_data_root' not in self.config:
raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
manual_data_root = self.config.get('manual_data_root')
if manual_data_root is not None:
self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
else:
if len(self.config['extra_manual_datasets']) == 0:
raise ValueError("Error in validation config .json: 'manual' validation requested but no "
"'manual_data_root' or 'extra_manual_datasets'")
if 'val_split_proportion' in self.config:
logging.warning(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "
@@ -167,15 +171,22 @@ class EveryDreamValidator:
def do_validation(self, global_step: int,
get_model_prediction_and_target_callable: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
mean_loss_accumulator = 0
for i, dataset in enumerate(self.validation_datasets):
mean_loss = self._calculate_validation_loss(dataset.name,
dataset.dataloader,
get_model_prediction_and_target_callable)
mean_loss_accumulator += mean_loss
self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
scalar_value=mean_loss,
global_step=global_step)
dataset.track_loss_trend(mean_loss)
# log combined loss to loss/_all_val_combined
if len(self.validation_datasets) > 1:
total_mean_loss = mean_loss_accumulator / len(self.validation_datasets)
self.log_writer.add_scalar(tag=f"loss/_all_val_combined",
scalar_value=total_mean_loss,
global_step=global_step)
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
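A condensed sketch of the new combined-loss logging: each validation dataset still gets its own loss/<name> scalar, and when more than one dataset is configured an unweighted average of the per-dataset means is also written to loss/_all_val_combined. The helper name and the dict argument are illustrative, not the class's actual interface.

from torch.utils.tensorboard import SummaryWriter

def log_validation_losses(log_writer: SummaryWriter, per_dataset_loss: dict, global_step: int) -> None:
    # per_dataset_loss maps dataset name -> mean validation loss for this step
    for name, mean_loss in per_dataset_loss.items():
        log_writer.add_scalar(tag=f"loss/{name}", scalar_value=mean_loss, global_step=global_step)
    # with multiple validation sets, also log a simple unweighted average across datasets
    if len(per_dataset_loss) > 1:
        combined = sum(per_dataset_loss.values()) / len(per_dataset_loss)
        log_writer.add_scalar(tag="loss/_all_val_combined", scalar_value=combined, global_step=global_step)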

View File

@@ -38,5 +38,6 @@
},
"text_encoder_freezing": {
"unfreeze_last_n_layers": null
}
},
"apply_grad_scaler_step_tweaks": true
}
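The only functional addition to this config file is the top-level apply_grad_scaler_step_tweaks key, which the optimizer reads with a default of True (see the EveryDreamOptimizer change below). A small sketch of reading such a config follows; the key names come from the diff, while the file contents and loading code are illustrative.

import json

example_config = json.loads("""
{
    "text_encoder_freezing": { "unfreeze_last_n_layers": null },
    "apply_grad_scaler_step_tweaks": false
}
""")

# an absent key falls back to True, so existing configs keep the previous behaviour
apply_tweaks = example_config.get("apply_grad_scaler_step_tweaks", True)
print(apply_tweaks)  # False for the example above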

View File

@@ -56,6 +56,7 @@ class EveryDreamOptimizer():
self.grad_accum = args.grad_accum
self.clip_grad_norm = args.clip_grad_norm
self.apply_grad_scaler_step_tweaks = optimizer_config.get("apply_grad_scaler_step_tweaks", True)
self.text_encoder_params = self._apply_text_encoder_freeze(text_encoder)
self.unet_params = unet.parameters()
@@ -103,7 +104,8 @@ class EveryDreamOptimizer():
for scheduler in self.lr_schedulers:
scheduler.step()
self._update_grad_scaler(global_step)
if self.apply_grad_scaler_step_tweaks:
self._update_grad_scaler(global_step)
def _zero_grad(self, set_to_none=False):
for optimizer in self.optimizers:

View File

@@ -390,7 +390,8 @@ def main(args):
os.makedirs(log_folder)
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
"""
Save the model to disk
"""
@@ -412,21 +413,22 @@ def main(args):
pipeline.save_pretrained(save_path)
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
if save_ckpt:
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
half = not save_full_precision
half = not save_full_precision
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
if save_optimizer_flag:
logging.info(f" Saving optimizer state to {save_path}")
@@ -616,7 +618,7 @@ def main(args):
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer)
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
exit(_SIGTERM_EXIT_CODE)
else:
# non-main threads (i.e. dataloader workers) should exit cleanly
@@ -731,6 +733,9 @@ def main(args):
gc.collect()
torch.cuda.empty_cache()
def make_save_path(epoch, global_step, prepend=""):
return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
# Pre-train validation to establish a starting point on the loss graph
if validator:
validator.do_validation(global_step=0,
@@ -816,13 +821,13 @@ def main(args):
if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
del batch
global_step += 1
@@ -846,9 +851,9 @@ def main(args):
# end of epoch
# end of training
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
epoch = args.max_epochs
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -857,8 +862,8 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
save_path = make_save_path(epoch, global_step, prepend="errored-")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
raise ex
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
@@ -906,6 +911,8 @@ if __name__ == "__main__":
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--no_prepend_last", action="store_true", help="Do not prepend 'last-' to the final checkpoint filename")
argparser.add_argument("--no_save_ckpt", action="store_true", help="Save only diffusers files, no .ckpts" )
argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)

View File

@@ -176,7 +176,7 @@ class SampleGenerator:
wants_random_caption=p.get('random_caption', False)
) for p in sample_requests_config]
if len(self.sample_requests) == 0:
self._make_random_caption_sample_requests()
self.sample_requests = self._make_random_caption_sample_requests()
@torch.no_grad()
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
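The last change is a one-line fix: the helper's return value is now assigned instead of discarded, so self.sample_requests is actually populated when the config supplies no sample requests. A generic illustration of the pattern; the class and method names below are placeholders, not the repo's.

class SamplerConfigSketch:
    def __init__(self, configured_requests: list):
        self.sample_requests = configured_requests
        if len(self.sample_requests) == 0:
            # before: the helper was called and its returned list thrown away
            # after: keep the returned default requests
            self.sample_requests = self._default_requests()

    def _default_requests(self) -> list:
        return [{"prompt": "", "random_caption": True}]  # placeholder default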