From a2117cb79724490057b4e9e8bbb4369ee8e4914c Mon Sep 17 00:00:00 2001
From: anton-l
Date: Tue, 21 Jun 2022 10:38:34 +0200
Subject: [PATCH] add push_to_hub

---
 examples/README.md                            |  12 +-
 .../{train_ddpm.py => train_unconditional.py} |  34 +++-
 src/diffusers/hub_utils.py                    | 149 ++++++++++++++++++
 src/diffusers/modeling_utils.py               |  14 ++
 4 files changed, 197 insertions(+), 12 deletions(-)
 rename examples/{train_ddpm.py => train_unconditional.py} (79%)
 create mode 100644 src/diffusers/hub_utils.py

diff --git a/examples/README.md b/examples/README.md
index 407ddd43..d3d1c1c6 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,13 +1,13 @@
 ## Training examples
 
-### Flowers DDPM
+### Unconditional Flowers
 
 The command to train a DDPM UNet model on the Oxford Flowers dataset:
 
 ```bash
 python -m torch.distributed.launch \
     --nproc_per_node 4 \
-    train_ddpm.py \
+    train_unconditional.py \
     --dataset="huggan/flowers-102-categories" \
     --resolution=64 \
     --output_path="flowers-ddpm" \
@@ -19,19 +19,19 @@ python -m torch.distributed.launch \
     --mixed_precision=no
 ```
 
-A full ltraining run takes 2 hours on 4xV100 GPUs.
+A full training run takes 2 hours on 4xV100 GPUs.
 
-### Pokemon DDPM
+### Unconditional Pokemon
 
 The command to train a DDPM UNet model on the Pokemon dataset:
 
 ```bash
 python -m torch.distributed.launch \
     --nproc_per_node 4 \
-    train_ddpm.py \
+    train_unconditional.py \
     --dataset="huggan/pokemon" \
     --resolution=64 \
     --output_path="pokemon-ddpm" \
@@ -43,6 +43,6 @@ python -m torch.distributed.launch \
     --mixed_precision=no
 ```
 
-A full ltraining run takes 2 hours on 4xV100 GPUs.
+A full training run takes 2 hours on 4xV100 GPUs.
 
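For reference, the effective batch size of the launch commands above is the per-device batch size times the gradient accumulation steps times the number of processes. A quick sanity check of that arithmetic, assuming the script defaults introduced in this patch (the README commands don't pass `--batch_size` or `--gradient_accumulation_steps` explicitly):

```python
# Effective batch size for the 4-GPU launches above. The per-device values
# are assumptions taken from the script defaults below, since the README
# commands rely on them rather than passing the flags explicitly.
batch_size = 4                   # new --batch_size default in this patch
gradient_accumulation_steps = 1  # --gradient_accumulation_steps default
world_size = 4                   # --nproc_per_node 4

total_train_batch_size = batch_size * gradient_accumulation_steps * world_size
assert total_train_batch_size == 16
```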
diff --git a/examples/train_ddpm.py b/examples/train_unconditional.py
similarity index 79%
rename from examples/train_ddpm.py
rename to examples/train_unconditional.py
index 6c7333a7..d8b5c0c3 100644
--- a/examples/train_ddpm.py
+++ b/examples/train_unconditional.py
@@ -19,6 +19,12 @@ from torchvision.transforms import (
 )
 from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup
+from diffusers.modeling_utils import unwrap_model
+from diffusers.hub_utils import init_git_repo, push_to_hub
+
+from diffusers.utils import logging
+
+logger = logging.get_logger(__name__)
 
 
 def main(args):
@@ -64,6 +70,21 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
+    if args.push_to_hub:
+        repo = init_git_repo(args, at_init=True)
+
+    # Train!
+    world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
+    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
+    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
+    logger.info(f"  Num Epochs = {args.num_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {max_steps}")
+
     for epoch in range(args.num_epochs):
         model.train()
         with tqdm(total=len(train_dataloader), unit="ba") as pbar:
@@ -105,11 +126,11 @@ def main(args):
         if args.local_rank in [-1, 0]:
             model.eval()
             with torch.no_grad():
-                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                    pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
+                pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
+                if args.push_to_hub:
+                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
                 else:
-                    pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
-                    pipeline.save_pretrained(args.output_path)
+                    pipeline.save_pretrained(args.output_path)
 
             generator = torch.manual_seed(0)
             # run pipeline in inference (sample random noise and denoise)
@@ -130,15 +151,16 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument("--local_rank", type=int)
+    parser.add_argument("--local_rank", type=int, default=-1)
     parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
     parser.add_argument("--resolution", type=int, default=64)
     parser.add_argument("--output_path", type=str, default="ddpm-model")
-    parser.add_argument("--batch_size", type=int, default=16)
+    parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--num_epochs", type=int, default=100)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
+    parser.add_argument("--push_to_hub", action="store_true")
     parser.add_argument(
         "--mixed_precision",
         type=str,
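The script gates all Hub interaction by process rank: `--local_rank` now defaults to `-1` (single-process training), `torch.distributed.launch` overrides it per process, and only rank `-1` or `0` initializes the repo and pushes. A minimal, standalone sketch of that gating pattern:

```python
# Minimal sketch of the rank gating used by train_unconditional.py:
# torch.distributed.launch passes --local_rank to each process; -1 means
# single-process training.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
args, _ = parser.parse_known_args()

if args.local_rank in [-1, 0]:
    # main process: init_git_repo / push_to_hub would run here
    print("main process: pushing to the Hub")
else:
    # non-main ranks skip Hub interaction (the hub helpers also return early)
    print(f"rank {args.local_rank}: skipping Hub interaction")
```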
+ """ + if args.local_rank not in [-1, 0]: + return + use_auth_token = True if args.hub_token is None else args.hub_token + if args.hub_model_id is None: + repo_name = Path(args.output_dir).absolute().name + else: + repo_name = args.hub_model_id + if "/" not in repo_name: + repo_name = get_full_repo_name(repo_name, token=args.hub_token) + + try: + repo = Repository( + args.output_dir, + clone_from=repo_name, + use_auth_token=use_auth_token, + private=args.hub_private_repo, + ) + except EnvironmentError: + if args.overwrite_output_dir and at_init: + # Try again after wiping output_dir + shutil.rmtree(args.output_dir) + repo = Repository( + args.output_dir, + clone_from=repo_name, + use_auth_token=use_auth_token, + ) + else: + raise + + repo.git_pull() + + # By default, ignore the checkpoint folders + if ( + not os.path.exists(os.path.join(args.output_dir, ".gitignore")) + and args.hub_strategy != "all_checkpoints" + ): + with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: + writer.writelines(["checkpoint-*/"]) + + return repo + + +def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. + Parameters: + commit_message (`str`, *optional*, defaults to `"End of training"`): + Message to commit while pushing. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has finished. + kwargs: + Additional keyword arguments passed along to [`create_model_card`]. + Returns: + The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of + the commit and an object to track the progress of the commit if `blocking=True` + """ + + if args.hub_model_id is None: + model_name = Path(args.output_dir).name + else: + model_name = args.hub_model_id.split("/")[-1] + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving pipeline checkpoint to {output_dir}") + pipeline.save_pretrained(output_dir) + + # Only push from one node. + if args.local_rank not in [-1, 0]: + return + + # Cancel any async push in progress if blocking=True. The commits will all be pushed together. + if blocking and len(repo.command_queue) > 0 and repo.command_queue[-1] is not None and not repo.command_queue[-1].is_done: + repo.command_queue[-1]._process.kill() + + git_head_commit_url = repo.push_to_hub( + commit_message=commit_message, blocking=blocking, auto_lfs_prune=True + ) + # push separately the model card to be independent from the rest of the model + create_model_card(args, model_name=model_name) + try: + repo.push_to_hub( + commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True + ) + except EnvironmentError as exc: + logger.error(f"Error pushing update to the model card. 
diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 2dd1b998..1fdbd13e 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -572,3 +572,17 @@ class ModelMixin(torch.nn.Module):
             return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
             return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+
+def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
+    """
+    Recursively unwraps a model from potential containers (as used in distributed training).
+
+    Args:
+        model (`torch.nn.Module`): The model to unwrap.
+    """
+    # since there could be multiple levels of wrapping, unwrap recursively
+    if hasattr(model, "module"):
+        return unwrap_model(model.module)
+    else:
+        return model
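For completeness, a quick check of `unwrap_model` on nested containers; `DataParallel` exposes the wrapped network as `.module`, which is exactly the attribute the helper recurses on:

```python
# unwrap_model should see through arbitrarily nested `.module` wrappers.
import torch

from diffusers.modeling_utils import unwrap_model

net = torch.nn.Linear(4, 4)
wrapped = torch.nn.DataParallel(torch.nn.DataParallel(net))
assert unwrap_model(wrapped) is net
```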