diff --git a/examples/README.md b/examples/README.md
index 407ddd43..d3d1c1c6 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,13 +1,13 @@
## Training examples
-### Flowers DDPM
+### Unconditional Flowers
The command to train a DDPM UNet model on the Oxford Flowers dataset:
```bash
python -m torch.distributed.launch \
--nproc_per_node 4 \
- train_ddpm.py \
+ train_unconditional.py \
--dataset="huggan/flowers-102-categories" \
--resolution=64 \
--output_path="flowers-ddpm" \
@@ -19,19 +19,19 @@ python -m torch.distributed.launch \
--mixed_precision=no
```
-A full ltraining run takes 2 hours on 4xV100 GPUs.
+A full training run takes 2 hours on 4xV100 GPUs.
-### Pokemon DDPM
+### Unconditional Pokemon
The command to train a DDPM UNet model on the Pokemon dataset:
```bash
python -m torch.distributed.launch \
--nproc_per_node 4 \
- train_ddpm.py \
+ train_unconditional.py \
--dataset="huggan/pokemon" \
--resolution=64 \
--output_path="pokemon-ddpm" \
@@ -43,6 +43,6 @@ python -m torch.distributed.launch \
--mixed_precision=no
```
-A full ltraining run takes 2 hours on 4xV100 GPUs.
+A full training run takes 2 hours on 4xV100 GPUs.
diff --git a/examples/train_ddpm.py b/examples/train_unconditional.py
similarity index 79%
rename from examples/train_ddpm.py
rename to examples/train_unconditional.py
index 6c7333a7..d8b5c0c3 100644
--- a/examples/train_ddpm.py
+++ b/examples/train_unconditional.py
@@ -19,6 +19,12 @@ from torchvision.transforms import (
)
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
+from diffusers.modeling_utils import unwrap_model
+from diffusers.hub_utils import init_git_repo, push_to_hub
+
+from diffusers.utils import logging
+
+logger = logging.get_logger(__name__)
def main(args):
@@ -64,6 +70,21 @@ def main(args):
model, optimizer, train_dataloader, lr_scheduler
)
+ if args.push_to_hub:
+ repo = init_git_repo(args, at_init=True)
+
+ # Train!
+ world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
+ total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
+ max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataloader.dataset)}")
+ logger.info(f" Num Epochs = {args.num_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_steps}")
+
for epoch in range(args.num_epochs):
model.train()
with tqdm(total=len(train_dataloader), unit="ba") as pbar:
@@ -105,11 +126,11 @@ def main(args):
if args.local_rank in [-1, 0]:
model.eval()
with torch.no_grad():
- if isinstance(model, torch.nn.parallel.DistributedDataParallel):
- pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
+ pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
+ if args.push_to_hub:
+ push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
- pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
- pipeline.save_pretrained(args.output_path)
+ pipeline.save_pretrained(args.output_path)
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
@@ -130,15 +151,16 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.")
- parser.add_argument("--local_rank", type=int)
+ parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--output_path", type=str, default="ddpm-model")
- parser.add_argument("--batch_size", type=int, default=16)
+ parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
+ parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument(
"--mixed_precision",
type=str,
diff --git a/src/diffusers/hub_utils.py b/src/diffusers/hub_utils.py
new file mode 100644
index 00000000..aa1700b7
--- /dev/null
+++ b/src/diffusers/hub_utils.py
@@ -0,0 +1,149 @@
+from typing import Optional
+from .utils import logging
+from huggingface_hub import HfFolder, Repository, whoami
+import yaml
+import os
+from pathlib import Path
+import shutil
+from diffusers import DiffusionPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+AUTOGENERATED_TRAINER_COMMENT = """
+
+"""
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
+def init_git_repo(args, at_init: bool = False):
+ """
+ Initializes a git repo in `args.hub_model_id`.
+ Args:
+ at_init (`bool`, *optional*, defaults to `False`):
+ Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
+ `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
+ out.
+ """
+ if args.local_rank not in [-1, 0]:
+ return
+ use_auth_token = True if args.hub_token is None else args.hub_token
+ if args.hub_model_id is None:
+ repo_name = Path(args.output_dir).absolute().name
+ else:
+ repo_name = args.hub_model_id
+ if "/" not in repo_name:
+ repo_name = get_full_repo_name(repo_name, token=args.hub_token)
+
+ try:
+ repo = Repository(
+ args.output_dir,
+ clone_from=repo_name,
+ use_auth_token=use_auth_token,
+ private=args.hub_private_repo,
+ )
+ except EnvironmentError:
+ if args.overwrite_output_dir and at_init:
+ # Try again after wiping output_dir
+ shutil.rmtree(args.output_dir)
+ repo = Repository(
+ args.output_dir,
+ clone_from=repo_name,
+ use_auth_token=use_auth_token,
+ )
+ else:
+ raise
+
+ repo.git_pull()
+
+ # By default, ignore the checkpoint folders
+ if (
+ not os.path.exists(os.path.join(args.output_dir, ".gitignore"))
+ and args.hub_strategy != "all_checkpoints"
+ ):
+ with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
+ writer.writelines(["checkpoint-*/"])
+
+ return repo
+
+
+def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
+ """
+ Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
+ Parameters:
+ commit_message (`str`, *optional*, defaults to `"End of training"`):
+ Message to commit while pushing.
+ blocking (`bool`, *optional*, defaults to `True`):
+ Whether the function should return only when the `git push` has finished.
+ kwargs:
+ Additional keyword arguments passed along to [`create_model_card`].
+ Returns:
+ The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
+ the commit and an object to track the progress of the commit if `blocking=True`
+ """
+
+ if args.hub_model_id is None:
+ model_name = Path(args.output_dir).name
+ else:
+ model_name = args.hub_model_id.split("/")[-1]
+
+ output_dir = args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ logger.info(f"Saving pipeline checkpoint to {output_dir}")
+ pipeline.save_pretrained(output_dir)
+
+ # Only push from one node.
+ if args.local_rank not in [-1, 0]:
+ return
+
+ # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
+ if blocking and len(repo.command_queue) > 0 and repo.command_queue[-1] is not None and not repo.command_queue[-1].is_done:
+ repo.command_queue[-1]._process.kill()
+
+ git_head_commit_url = repo.push_to_hub(
+ commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
+ )
+ # push separately the model card to be independent from the rest of the model
+ create_model_card(args, model_name=model_name)
+ try:
+ repo.push_to_hub(
+ commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
+ )
+ except EnvironmentError as exc:
+ logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
+
+ return git_head_commit_url
+
+
+def create_model_card(args, model_name):
+ if args.local_rank not in [-1, 0]:
+ return
+
+ # TODO: replace this placeholder model card generation
+ model_card = ""
+
+ metadata = {
+ "license": "apache-2.0",
+ "tags": ["pytorch", "diffusers"]
+ }
+ metadata = yaml.dump(metadata, sort_keys=False)
+ if len(metadata) > 0:
+ model_card = f"---\n{metadata}---\n"
+
+ model_card += AUTOGENERATED_TRAINER_COMMENT
+
+ model_card += f"\n# {model_name}\n\n"
+
+ with open(os.path.join(args.output_dir, "README.md"), "w") as f:
+ f.write(model_card)
diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 2dd1b998..1fdbd13e 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -572,3 +572,17 @@ class ModelMixin(torch.nn.Module):
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+
+def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
+ """
+ Recursively unwraps a model from potential containers (as used in distributed training).
+
+ Args:
+ model (`torch.nn.Module`): The model to unwrap.
+ """
+ # since there could be multiple levels of wrapping, unwrap recursively
+ if hasattr(model, "module"):
+ return unwrap_model(model.module)
+ else:
+ return model
\ No newline at end of file