Merge branch 'main' into fix_simplify_freezing_text_encoder_layers
This commit is contained in:
commit
a916934bb8
|
@ -108,19 +108,35 @@ class DataLoaderMultiAspect():
|
|||
batch_size = self.batch_size
|
||||
grad_accum = self.grad_accum
|
||||
|
||||
for image_caption_pair in picked_images:
|
||||
image_caption_pair.runt_size = 0
|
||||
bucket_key = (image_caption_pair.batch_id,
|
||||
image_caption_pair.target_wh[0],
|
||||
image_caption_pair.target_wh[1])
|
||||
def add_image_to_appropriate_bucket(image: ImageTrainItem, batch_id_override: str=None):
|
||||
bucket_key = (image.batch_id if batch_id_override is None else batch_id_override,
|
||||
image.target_wh[0],
|
||||
image.target_wh[1])
|
||||
if bucket_key not in buckets:
|
||||
buckets[bucket_key] = []
|
||||
buckets[bucket_key].append(image_caption_pair)
|
||||
buckets[bucket_key].append(image)
|
||||
|
||||
# handle runts by randomly duplicating items
|
||||
for image_caption_pair in picked_images:
|
||||
image_caption_pair.runt_size = 0
|
||||
add_image_to_appropriate_bucket(image_caption_pair)
|
||||
|
||||
# handle named batch runts by demoting them to the DEFAULT_BATCH_ID
|
||||
for key, bucket_contents in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
|
||||
runt_count = len(bucket_contents) % batch_size
|
||||
if runt_count == 0:
|
||||
continue
|
||||
runts = bucket_contents[-runt_count:]
|
||||
del bucket_contents[-runt_count:]
|
||||
for r in runts:
|
||||
add_image_to_appropriate_bucket(r, batch_id_override=DEFAULT_BATCH_ID)
|
||||
if len(bucket_contents) == 0:
|
||||
del buckets[key]
|
||||
|
||||
# handle remaining runts by randomly duplicating items
|
||||
for bucket in buckets:
|
||||
truncate_count = len(buckets[bucket]) % batch_size
|
||||
if truncate_count > 0:
|
||||
assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
|
||||
runt_bucket = buckets[bucket][-truncate_count:]
|
||||
for item in runt_bucket:
|
||||
item.runt_size = truncate_count
|
||||
|
@ -132,9 +148,21 @@ class DataLoaderMultiAspect():
|
|||
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
||||
buckets[bucket].extend(runt_bucket)
|
||||
|
||||
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
|
||||
# at this point items have a partially deterministic order
|
||||
# (in particular: rarer aspect ratios are more likely to cluster at the end due to stochastic sampling)
|
||||
# so we shuffle them to mitigate this, using chunked_shuffle to keep batches with the same aspect ratio together
|
||||
items_by_batch_id = {k: chunked_shuffle(v, chunk_size=batch_size, randomizer=randomizer)
|
||||
for k,v in items_by_batch_id.items()}
|
||||
# paranoia: verify that this hasn't fucked up the aspect ratio batching
|
||||
for items in items_by_batch_id.values():
|
||||
batches = chunk(items, chunk_size=batch_size)
|
||||
for batch in batches:
|
||||
target_wh = batch[0].target_wh
|
||||
assert all(target_wh == i.target_wh for i in batch[1:]), "mixed aspect ratios in a batch - this shouldn't happen"
|
||||
|
||||
# handle batch_id
|
||||
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
|
||||
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
|
||||
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
|
||||
batch_size=batch_size,
|
||||
grad_accum=grad_accum)
|
||||
|
@ -225,6 +253,8 @@ def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List
|
|||
Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
|
||||
If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
|
||||
"""
|
||||
if len(l) == 0:
|
||||
return []
|
||||
|
||||
# chunk by effective batch size
|
||||
chunks = chunk(l, chunk_size)
|
||||
|
|
|
@ -97,9 +97,13 @@ class EveryDreamValidator:
|
|||
self.config.update({'manual_data_root': self.config['val_data_root']})
|
||||
|
||||
if self.config.get('val_split_mode') == 'manual':
|
||||
if 'manual_data_root' not in self.config:
|
||||
raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
|
||||
self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
|
||||
manual_data_root = self.config.get('manual_data_root')
|
||||
if manual_data_root is not None:
|
||||
self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
|
||||
else:
|
||||
if len(self.config['extra_manual_datasets']) == 0:
|
||||
raise ValueError("Error in validation config .json: 'manual' validation requested but no "
|
||||
"'manual_data_root' or 'extra_manual_datasets'")
|
||||
|
||||
if 'val_split_proportion' in self.config:
|
||||
logging.warning(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "
|
||||
|
@ -167,15 +171,22 @@ class EveryDreamValidator:
|
|||
def do_validation(self, global_step: int,
                  get_model_prediction_and_target_callable: Callable[
                      [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
    """
    Compute and log the mean validation loss of every validation dataset.

    For each entry in self.validation_datasets the mean loss is obtained from
    self._calculate_validation_loss, written to self.log_writer under the tag
    "loss/<dataset name>", and passed to the dataset's track_loss_trend.
    When more than one dataset is present, the unweighted average of the
    per-dataset mean losses is additionally logged as "loss/_all_val_combined".

    :param global_step: step value attached to every logged scalar.
    :param get_model_prediction_and_target_callable: forwarded unchanged to
        self._calculate_validation_loss.
    """
    mean_loss_accumulator = 0
    for dataset in self.validation_datasets:
        mean_loss = self._calculate_validation_loss(dataset.name,
                                                    dataset.dataloader,
                                                    get_model_prediction_and_target_callable)
        mean_loss_accumulator += mean_loss
        self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
                                   scalar_value=mean_loss,
                                   global_step=global_step)
        dataset.track_loss_trend(mean_loss)

    # log combined loss to loss/_all_val_combined
    # (skipped for zero or one dataset: it would be empty or duplicate the single series)
    if len(self.validation_datasets) > 1:
        total_mean_loss = mean_loss_accumulator / len(self.validation_datasets)
        self.log_writer.add_scalar(tag="loss/_all_val_combined",
                                   scalar_value=total_mean_loss,
                                   global_step=global_step)
|
||||
|
||||
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
|
||||
|
|
|
@ -38,5 +38,6 @@
|
|||
},
|
||||
"text_encoder_freezing": {
|
||||
"unfreeze_last_n_layers": null
|
||||
}
|
||||
},
|
||||
"apply_grad_scaler_step_tweaks": true
|
||||
}
|
||||
|
|
|
@ -56,6 +56,7 @@ class EveryDreamOptimizer():
|
|||
|
||||
self.grad_accum = args.grad_accum
|
||||
self.clip_grad_norm = args.clip_grad_norm
|
||||
self.apply_grad_scaler_step_tweaks = optimizer_config.get("apply_grad_scaler_step_tweaks", True)
|
||||
|
||||
self.text_encoder_params = self._apply_text_encoder_freeze(text_encoder)
|
||||
self.unet_params = unet.parameters()
|
||||
|
@ -103,7 +104,8 @@ class EveryDreamOptimizer():
|
|||
for scheduler in self.lr_schedulers:
|
||||
scheduler.step()
|
||||
|
||||
self._update_grad_scaler(global_step)
|
||||
if self.apply_grad_scaler_step_tweaks:
|
||||
self._update_grad_scaler(global_step)
|
||||
|
||||
def _zero_grad(self, set_to_none=False):
|
||||
for optimizer in self.optimizers:
|
||||
|
|
53
train.py
53
train.py
|
@ -390,7 +390,8 @@ def main(args):
|
|||
os.makedirs(log_folder)
|
||||
|
||||
@torch.no_grad()
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
|
||||
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
|
||||
"""
|
||||
Save the model to disk
|
||||
"""
|
||||
|
@ -412,21 +413,22 @@ def main(args):
|
|||
pipeline.save_pretrained(save_path)
|
||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
||||
|
||||
if save_ckpt_dir is not None:
|
||||
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
||||
else:
|
||||
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
||||
save_ckpt_dir = os.curdir
|
||||
if save_ckpt:
|
||||
if save_ckpt_dir is not None:
|
||||
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
||||
else:
|
||||
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
||||
save_ckpt_dir = os.curdir
|
||||
|
||||
half = not save_full_precision
|
||||
half = not save_full_precision
|
||||
|
||||
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
||||
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
|
||||
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
||||
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
|
||||
|
||||
if yaml_name and yaml_name != "v1-inference.yaml":
|
||||
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
|
||||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
if yaml_name and yaml_name != "v1-inference.yaml":
|
||||
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
|
||||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
|
||||
if save_optimizer_flag:
|
||||
logging.info(f" Saving optimizer state to {save_path}")
|
||||
|
@ -616,7 +618,7 @@ def main(args):
|
|||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer)
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
exit(_SIGTERM_EXIT_CODE)
|
||||
else:
|
||||
# non-main threads (i.e. dataloader workers) should exit cleanly
|
||||
|
@ -731,6 +733,9 @@ def main(args):
|
|||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def make_save_path(epoch, global_step, prepend=""):
    """
    Build the checkpoint save path for the given epoch and global step.

    The result is "<log_folder>/ckpts/<prepend><project_name>-ep<EE>-gs<SSSSS>",
    with the epoch zero-padded to 2 digits and the global step to 5 digits.
    log_folder and args are taken from the enclosing scope.
    """
    checkpoint_name = f"{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}"
    return os.path.join(f"{log_folder}/ckpts/{checkpoint_name}")
|
||||
|
||||
# Pre-train validation to establish a starting point on the loss graph
|
||||
if validator:
|
||||
validator.do_validation(global_step=0,
|
||||
|
@ -816,13 +821,13 @@ def main(args):
|
|||
if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
|
||||
last_epoch_saved_time = time.time()
|
||||
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
save_path = make_save_path(epoch, global_step)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
save_path = make_save_path(epoch, global_step)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
del batch
|
||||
global_step += 1
|
||||
|
@ -846,9 +851,9 @@ def main(args):
|
|||
# end of epoch
|
||||
|
||||
# end of training
|
||||
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
epoch = args.max_epochs
|
||||
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
total_elapsed_time = time.time() - training_start_time
|
||||
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
||||
|
@ -857,8 +862,8 @@ def main(args):
|
|||
|
||||
except Exception as ex:
|
||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
save_path = make_save_path(epoch, global_step, prepend="errored-")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
raise ex
|
||||
|
||||
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
||||
|
@ -906,6 +911,8 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
|
||||
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
|
||||
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
|
||||
argparser.add_argument("--no_prepend_last", action="store_true", help="Do not prepend 'last-' to the final checkpoint filename")
|
||||
argparser.add_argument("--no_save_ckpt", action="store_true", help="Save only diffusers files, no .ckpts" )
|
||||
argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
|
||||
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
|
||||
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
|
||||
|
|
|
@ -176,7 +176,7 @@ class SampleGenerator:
|
|||
wants_random_caption=p.get('random_caption', False)
|
||||
) for p in sample_requests_config]
|
||||
if len(self.sample_requests) == 0:
|
||||
self._make_random_caption_sample_requests()
|
||||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
|
||||
|
|
Loading…
Reference in New Issue