add push_to_hub

This commit is contained in:
anton-l 2022-06-21 10:38:34 +02:00
parent 8c1f51978c
commit a2117cb797
4 changed files with 197 additions and 12 deletions

View File

@ -1,13 +1,13 @@
## Training examples
### Flowers DDPM
### Unconditional Flowers
The command to train a DDPM UNet model on the Oxford Flowers dataset:
```bash
python -m torch.distributed.launch \
--nproc_per_node 4 \
train_ddpm.py \
train_unconditional.py \
--dataset="huggan/flowers-102-categories" \
--resolution=64 \
--output_path="flowers-ddpm" \
@ -19,19 +19,19 @@ python -m torch.distributed.launch \
--mixed_precision=no
```
A full ltraining run takes 2 hours on 4xV100 GPUs.
A full training run takes 2 hours on 4xV100 GPUs.
<img src="https://user-images.githubusercontent.com/26864830/173855866-5628989f-856b-4725-a944-d6c09490b2df.png" width="500" />
### Pokemon DDPM
### Unconditional Pokemon
The command to train a DDPM UNet model on the Pokemon dataset:
```bash
python -m torch.distributed.launch \
--nproc_per_node 4 \
train_ddpm.py \
train_unconditional.py \
--dataset="huggan/pokemon" \
--resolution=64 \
--output_path="pokemon-ddpm" \
@ -43,6 +43,6 @@ python -m torch.distributed.launch \
--mixed_precision=no
```
A full ltraining run takes 2 hours on 4xV100 GPUs.
A full training run takes 2 hours on 4xV100 GPUs.
<img src="https://user-images.githubusercontent.com/26864830/173856733-4f117f8c-97bd-4f51-8002-56b488c96df9.png" width="500" />

View File

@ -19,6 +19,12 @@ from torchvision.transforms import (
)
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
from diffusers.modeling_utils import unwrap_model
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.utils import logging
logger = logging.get_logger(__name__)
def main(args):
@ -64,6 +70,21 @@ def main(args):
model, optimizer, train_dataloader, lr_scheduler
)
if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
# Train!
world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataloader.dataset)}")
logger.info(f" Num Epochs = {args.num_epochs}")
logger.info(f" Instantaneous batch size per device = {args.batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
for epoch in range(args.num_epochs):
model.train()
with tqdm(total=len(train_dataloader), unit="ba") as pbar:
@ -105,11 +126,11 @@ def main(args):
if args.local_rank in [-1, 0]:
model.eval()
with torch.no_grad():
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
pipeline.save_pretrained(args.output_path)
pipeline.save_pretrained(args.output_path)
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
@ -130,15 +151,16 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int)
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--output_path", type=str, default="ddpm-model")
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument(
"--mixed_precision",
type=str,

149
src/diffusers/hub_utils.py Normal file
View File

@ -0,0 +1,149 @@
from typing import Optional
from .utils import logging
from huggingface_hub import HfFolder, Repository, whoami
import yaml
import os
from pathlib import Path
import shutil
from diffusers import DiffusionPipeline
logger = logging.get_logger(__name__)
AUTOGENERATED_TRAINER_COMMENT = """
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def init_git_repo(args, at_init: bool = False):
"""
Initializes a git repo in `args.hub_model_id`.
Args:
at_init (`bool`, *optional*, defaults to `False`):
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
`True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
out.
"""
if args.local_rank not in [-1, 0]:
return
use_auth_token = True if args.hub_token is None else args.hub_token
if args.hub_model_id is None:
repo_name = Path(args.output_dir).absolute().name
else:
repo_name = args.hub_model_id
if "/" not in repo_name:
repo_name = get_full_repo_name(repo_name, token=args.hub_token)
try:
repo = Repository(
args.output_dir,
clone_from=repo_name,
use_auth_token=use_auth_token,
private=args.hub_private_repo,
)
except EnvironmentError:
if args.overwrite_output_dir and at_init:
# Try again after wiping output_dir
shutil.rmtree(args.output_dir)
repo = Repository(
args.output_dir,
clone_from=repo_name,
use_auth_token=use_auth_token,
)
else:
raise
repo.git_pull()
# By default, ignore the checkpoint folders
if (
not os.path.exists(os.path.join(args.output_dir, ".gitignore"))
and args.hub_strategy != "all_checkpoints"
):
with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
writer.writelines(["checkpoint-*/"])
return repo
def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
Parameters:
commit_message (`str`, *optional*, defaults to `"End of training"`):
Message to commit while pushing.
blocking (`bool`, *optional*, defaults to `True`):
Whether the function should return only when the `git push` has finished.
kwargs:
Additional keyword arguments passed along to [`create_model_card`].
Returns:
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
the commit and an object to track the progress of the commit if `blocking=True`
"""
if args.hub_model_id is None:
model_name = Path(args.output_dir).name
else:
model_name = args.hub_model_id.split("/")[-1]
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving pipeline checkpoint to {output_dir}")
pipeline.save_pretrained(output_dir)
# Only push from one node.
if args.local_rank not in [-1, 0]:
return
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
if blocking and len(repo.command_queue) > 0 and repo.command_queue[-1] is not None and not repo.command_queue[-1].is_done:
repo.command_queue[-1]._process.kill()
git_head_commit_url = repo.push_to_hub(
commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
)
# push separately the model card to be independent from the rest of the model
create_model_card(args, model_name=model_name)
try:
repo.push_to_hub(
commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
)
except EnvironmentError as exc:
logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
return git_head_commit_url
def create_model_card(args, model_name):
if args.local_rank not in [-1, 0]:
return
# TODO: replace this placeholder model card generation
model_card = ""
metadata = {
"license": "apache-2.0",
"tags": ["pytorch", "diffusers"]
}
metadata = yaml.dump(metadata, sort_keys=False)
if len(metadata) > 0:
model_card = f"---\n{metadata}---\n"
model_card += AUTOGENERATED_TRAINER_COMMENT
model_card += f"\n# {model_name}\n\n"
with open(os.path.join(args.output_dir, "README.md"), "w") as f:
f.write(model_card)

View File

@ -572,3 +572,17 @@ class ModelMixin(torch.nn.Module):
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
"""
Recursively unwraps a model from potential containers (as used in distributed training).
Args:
model (`torch.nn.Module`): The model to unwrap.
"""
# since there could be multiple levels of wrapping, unwrap recursively
if hasattr(model, "module"):
return unwrap_model(model.module)
else:
return model