Add the model card template (#43)

* add a metrics logger

* fix LatentDiffusionUncondPipeline

* add VQModel in init

* add image logging to tensorboard

* switch manual templates to the modelcards package

* hide ldm example

Co-authored-by: patil-suraj <surajp815@gmail.com>
This commit is contained in:
Anton Lozhkov 2022-06-29 15:37:23 +02:00 committed by GitHub
parent f47066f707
commit 8cba133f36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 180 additions and 76 deletions

View File

@ -4,14 +4,13 @@ import os
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNetModel from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import logging
from torchvision.transforms import ( from torchvision.transforms import (
CenterCrop, CenterCrop,
Compose, Compose,
@ -24,11 +23,12 @@ from torchvision.transforms import (
from tqdm.auto import tqdm from tqdm.auto import tqdm
logger = logging.get_logger(__name__) logger = get_logger(__name__)
def main(args): def main(args):
accelerator = Accelerator(mixed_precision=args.mixed_precision) logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator = Accelerator(mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir)
model = UNetModel( model = UNetModel(
attn_resolutions=(16,), attn_resolutions=(16,),
@ -39,8 +39,14 @@ def main(args):
resamp_with_conv=True, resamp_with_conv=True,
resolution=args.resolution, resolution=args.resolution,
) )
noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt") noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
augmentations = Compose( augmentations = Compose(
[ [
@ -58,12 +64,12 @@ def main(args):
return {"input": images} return {"input": images}
dataset.set_transform(transforms) dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
"linear", args.lr_scheduler,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.warmup_steps, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
) )
@ -76,15 +82,19 @@ def main(args):
if args.push_to_hub: if args.push_to_hub:
repo = init_git_repo(args, at_init=True) repo = init_git_repo(args, at_init=True)
if accelerator.is_main_process:
run = os.path.split(__file__)[-1].split(".")[0]
accelerator.init_trackers(run)
# Train! # Train!
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size() if is_distributed else 1 world_size = torch.distributed.get_world_size() if is_distributed else 1
total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataloader.dataset)}") logger.info(f" Num examples = {len(train_dataloader.dataset)}")
logger.info(f" Num Epochs = {args.num_epochs}") logger.info(f" Num Epochs = {args.num_epochs}")
logger.info(f" Instantaneous batch size per device = {args.batch_size}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}") logger.info(f" Total optimization steps = {max_steps}")
@ -92,8 +102,8 @@ def main(args):
global_step = 0 global_step = 0
for epoch in range(args.num_epochs): for epoch in range(args.num_epochs):
model.train() model.train()
with tqdm(total=len(train_dataloader), unit="ba") as pbar: progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
pbar.set_description(f"Epoch {epoch}") progress_bar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
clean_images = batch["input"] clean_images = batch["input"]
noise_samples = torch.randn(clean_images.shape).to(clean_images.device) noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
@ -102,7 +112,7 @@ def main(args):
# add noise onto the clean images according to the noise magnitude at each timestep # add noise onto the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)
noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps) noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps)
if step % args.gradient_accumulation_steps != 0: if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model): with accelerator.no_sync(model):
@ -122,35 +132,41 @@ def main(args):
lr_scheduler.step() lr_scheduler.step()
ema_model.step(model, global_step) ema_model.step(model, global_step)
optimizer.zero_grad() optimizer.zero_grad()
pbar.update(1) progress_bar.update(1)
pbar.set_postfix( progress_bar.set_postfix(
loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
) )
accelerator.log(
{
"train_loss": loss.detach().item(),
"epoch": epoch,
"ema_decay": ema_model.decay,
"step": global_step,
},
step=global_step,
)
global_step += 1 global_step += 1
progress_bar.close()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# Generate a sample image for visual inspection # Generate a sample image for visual inspection
if accelerator.is_main_process: if accelerator.is_main_process:
with torch.no_grad(): with torch.no_grad():
pipeline = DDPMPipeline( pipeline = DDIMPipeline(
unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler unet=accelerator.unwrap_model(ema_model.averaged_model),
noise_scheduler=noise_scheduler,
) )
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise) # run pipeline in inference (sample random noise and denoise)
image = pipeline(generator=generator) images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50)
# process image to PIL # denormalize the images and save to tensorboard
image_processed = image.cpu().permute(0, 2, 3, 1) images_processed = (images.cpu() + 1.0) * 127.5
image_processed = (image_processed + 1.0) * 127.5 images_processed = images_processed.clamp(0, 255).type(torch.uint8).numpy()
image_processed = image_processed.type(torch.uint8).numpy()
image_pil = PIL.Image.fromarray(image_processed[0])
# save image accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
test_dir = os.path.join(args.output_dir, "test_samples")
os.makedirs(test_dir, exist_ok=True)
image_pil.save(f"{test_dir}/{epoch:04d}.png")
# save the model # save the model
if args.push_to_hub: if args.push_to_hub:
@ -159,6 +175,8 @@ def main(args):
pipeline.save_pretrained(args.output_dir) pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
accelerator.end_training()
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
@ -167,18 +185,25 @@ if __name__ == "__main__":
parser.add_argument("--output_dir", type=str, default="ddpm-model") parser.add_argument("--output_dir", type=str, default="ddpm-model")
parser.add_argument("--overwrite_output_dir", action="store_true") parser.add_argument("--overwrite_output_dir", action="store_true")
parser.add_argument("--resolution", type=int, default=64) parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=100) parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500) parser.add_argument("--lr_scheduler", type=str, default="cosine")
parser.add_argument("--lr_warmup_steps", type=int, default=500)
parser.add_argument("--adam_beta1", type=float, default=0.95)
parser.add_argument("--adam_beta2", type=float, default=0.999)
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
parser.add_argument("--adam_epsilon", type=float, default=1e-3)
parser.add_argument("--ema_inv_gamma", type=float, default=1.0) parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
parser.add_argument("--ema_power", type=float, default=3 / 4) parser.add_argument("--ema_power", type=float, default=3 / 4)
parser.add_argument("--ema_max_decay", type=float, default=0.999) parser.add_argument("--ema_max_decay", type=float, default=0.9999)
parser.add_argument("--push_to_hub", action="store_true") parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--hub_token", type=str, default=None) parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None) parser.add_argument("--hub_model_id", type=str, default=None)
parser.add_argument("--hub_private_repo", action="store_true") parser.add_argument("--hub_private_repo", action="store_true")
parser.add_argument("--logging_dir", type=str, default="logs")
parser.add_argument( parser.add_argument(
"--mixed_precision", "--mixed_precision",
type=str, type=str,

View File

@ -87,6 +87,8 @@ _deps = [
"regex!=2019.12.17", "regex!=2019.12.17",
"requests", "requests",
"torch>=1.4", "torch>=1.4",
"tensorboard",
"modelcards=0.1.4"
] ]
# this is a lookup table with items like: # this is a lookup table with items like:
@ -172,6 +174,8 @@ install_requires = [
deps["requests"], deps["requests"],
deps["torch"], deps["torch"],
deps["Pillow"], deps["Pillow"],
deps["tensorboard"],
deps["modelcards"],
] ]
setup( setup(

View File

@ -13,4 +13,5 @@ deps = {
"regex": "regex!=2019.12.17", "regex": "regex!=2019.12.17",
"requests": "requests", "requests": "requests",
"torch": "torch>=1.4", "torch": "torch>=1.4",
"tensorboard": "tensorboard",
} }

View File

@ -19,9 +19,9 @@ import shutil
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import yaml
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, whoami
from modelcards import CardData, ModelCard
from .utils import logging from .utils import logging
@ -29,10 +29,7 @@ from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
AUTOGENERATED_TRAINER_COMMENT = """ MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
@ -152,17 +149,36 @@ def create_model_card(args, model_name):
if args.local_rank not in [-1, 0]: if args.local_rank not in [-1, 0]:
return return
# TODO: replace this placeholder model card generation repo_name = get_full_repo_name(model_name, token=args.hub_token)
model_card = ""
metadata = {"license": "apache-2.0", "tags": ["pytorch", "diffusers"]} model_card = ModelCard.from_template(
metadata = yaml.dump(metadata, sort_keys=False) card_data=CardData( # Card metadata object that will be converted to YAML block
if len(metadata) > 0: language="en",
model_card = f"---\n{metadata}---\n" license="apache-2.0",
library_name="diffusers",
tags=[],
datasets=args.dataset,
metrics=[],
),
template_path=MODEL_CARD_TEMPLATE_PATH,
model_name=model_name,
repo_name=repo_name,
dataset_name=args.dataset,
learning_rate=args.learning_rate,
train_batch_size=args.train_batch_size,
eval_batch_size=args.eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
adam_beta1=args.adam_beta1,
adam_beta2=args.adam_beta2,
adam_weight_decay=args.adam_weight_decay,
adam_epsilon=args.adam_epsilon,
lr_scheduler=args.lr_scheduler,
lr_warmup_steps=args.lr_warmup_steps,
ema_inv_gamma=args.ema_inv_gamma,
ema_power=args.ema_power,
ema_max_decay=args.ema_max_decay,
mixed_precision=args.mixed_precision,
)
model_card += AUTOGENERATED_TRAINER_COMMENT card_path = os.path.join(args.output_dir, "README.md")
model_card.save(card_path)
model_card += f"\n# {model_name}\n\n"
with open(os.path.join(args.output_dir, "README.md"), "w") as f:
f.write(model_card)

View File

@ -1 +1 @@
from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline, VQModel

View File

@ -145,5 +145,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_sample return pred_prev_sample
def add_noise(self, original_samples, noise, timesteps):
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self): def __len__(self):
return self.config.timesteps return self.config.timesteps

View File

@ -17,7 +17,6 @@
import math import math
import numpy as np import numpy as np
import torch
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@ -142,7 +141,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_sample return pred_prev_sample
def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor): def add_noise(self, original_samples, noise, timesteps):
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5

View File

@ -0,0 +1,50 @@
---
{{ card_data }}
---
<!-- This model card has been generated automatically according to the information the training script had access to. You
should probably proofread and complete it, then remove this comment. -->
# {{ model_name | default("Diffusion Model") }}
## Model description
This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
on the `{{ dataset_name }}` dataset.
## Intended uses & limitations
#### How to use
```python
# TODO: add an example code snippet for running this diffusion pipeline
```
#### Limitations and bias
[TODO: provide examples of latent issues and potential remediations]
## Training data
[TODO: describe the data used to train the model]
### Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: {{ learning_rate }}
- train_batch_size: {{ train_batch_size }}
- eval_batch_size: {{ eval_batch_size }}
- gradient_accumulation_steps: {{ gradient_accumulation_steps }}
- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
- lr_scheduler: {{ lr_scheduler }}
- lr_warmup_steps: {{ lr_warmup_steps }}
- ema_inv_gamma: {{ ema_inv_gamma }}
- ema_inv_gamma: {{ ema_power }}
- ema_inv_gamma: {{ ema_max_decay }}
- mixed_precision: {{ mixed_precision }}
### Training results
📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)