Add the model card template (#43)
* add a metrics logger
* fix LatentDiffusionUncondPipeline
* add VQModel in init
* add image logging to tensorboard
* switch manual templates to the modelcards package
* hide ldm example

Co-authored-by: patil-suraj <surajp815@gmail.com>
parent f47066f707
commit 8cba133f36
@@ -4,14 +4,13 @@ import os
 import torch
 import torch.nn.functional as F
 
-import PIL.Image
 from accelerate import Accelerator
+from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNetModel
+from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
     Compose,
@@ -24,11 +23,12 @@ from torchvision.transforms import (
 from tqdm.auto import tqdm
 
 
-logger = logging.get_logger(__name__)
+logger = get_logger(__name__)
 
 
 def main(args):
-    accelerator = Accelerator(mixed_precision=args.mixed_precision)
+    logging_dir = os.path.join(args.output_dir, args.logging_dir)
+    accelerator = Accelerator(mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir)
 
     model = UNetModel(
         attn_resolutions=(16,),
@@ -39,8 +39,14 @@ def main(args):
         resamp_with_conv=True,
         resolution=args.resolution,
     )
-    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt")
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+    )
 
     augmentations = Compose(
         [
@@ -58,12 +64,12 @@ def main(args):
         return {"input": images}
 
     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
 
     lr_scheduler = get_scheduler(
-        "linear",
+        args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.warmup_steps,
+        num_warmup_steps=args.lr_warmup_steps,
         num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
     )
 
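For readers unfamiliar with `diffusers.optimization.get_scheduler`, the hunk above swaps the hard-coded `"linear"` schedule for a configurable one with warmup. A small stand-alone sketch (toy step counts, not the script's defaults) shows the call shape and the warmup behaviour:

```python
# Illustrative sketch only: how the schedule returned by get_scheduler behaves.
import torch
from diffusers.optimization import get_scheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)

lr_scheduler = get_scheduler(
    "cosine",              # plays the role of args.lr_scheduler above
    optimizer=optimizer,
    num_warmup_steps=10,   # stand-in for args.lr_warmup_steps
    num_training_steps=100,
)

for step in range(5):
    optimizer.step()       # the training loop calls scheduler.step() once per optimizer step
    lr_scheduler.step()
    print(step, optimizer.param_groups[0]["lr"])  # ramps up linearly during warmup
```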
@@ -76,15 +82,19 @@ def main(args):
     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)
 
+    if accelerator.is_main_process:
+        run = os.path.split(__file__)[-1].split(".")[0]
+        accelerator.init_trackers(run)
+
     # Train!
     is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size() if is_distributed else 1
-    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
+    total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
     max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len(train_dataloader.dataset)}")
     logger.info(f" Num Epochs = {args.num_epochs}")
-    logger.info(f" Instantaneous batch size per device = {args.batch_size}")
+    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f" Total optimization steps = {max_steps}")
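The tracker setup added here pairs with the `log_with="tensorboard"` argument from the earlier hunk and with the `accelerator.log(...)` call added in the next one. A minimal sketch of that flow, with an invented run name and placeholder loss values:

```python
# Minimal sketch of the Accelerate tracking flow used by the training script.
# Run name, logging directory and the logged values are illustrative only.
import torch
from accelerate import Accelerator

accelerator = Accelerator(log_with="tensorboard", logging_dir="logs")

if accelerator.is_main_process:
    accelerator.init_trackers("demo_run")  # creates logs/demo_run for TensorBoard

for step in range(3):
    fake_loss = torch.rand(1).item()  # stand-in for the real training loss
    accelerator.log({"train_loss": fake_loss}, step=step)

accelerator.end_training()  # flush and close the trackers
```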
@@ -92,65 +102,71 @@ def main(args):
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
-        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
-            pbar.set_description(f"Epoch {epoch}")
+        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
             noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
             bsz = clean_images.shape[0]
             timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
 
             # add noise onto the clean images according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
+            noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps)
 
             if step % args.gradient_accumulation_steps != 0:
                 with accelerator.no_sync(model):
                     output = model(noisy_images, timesteps)
                     # predict the noise residual
                     loss = F.mse_loss(output, noise_samples)
                     loss = loss / args.gradient_accumulation_steps
                     accelerator.backward(loss)
             else:
                 output = model(noisy_images, timesteps)
                 # predict the noise residual
                 loss = F.mse_loss(output, noise_samples)
                 loss = loss / args.gradient_accumulation_steps
                 accelerator.backward(loss)
                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
                 ema_model.step(model, global_step)
                 optimizer.zero_grad()
-            pbar.update(1)
-            pbar.set_postfix(
-                loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
-            )
+            progress_bar.update(1)
+            progress_bar.set_postfix(
+                loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
+            )
+            accelerator.log(
+                {
+                    "train_loss": loss.detach().item(),
+                    "epoch": epoch,
+                    "ema_decay": ema_model.decay,
+                    "step": global_step,
+                },
+                step=global_step,
+            )
             global_step += 1
+        progress_bar.close()
 
         accelerator.wait_for_everyone()
 
         # Generate a sample image for visual inspection
         if accelerator.is_main_process:
             with torch.no_grad():
-                pipeline = DDPMPipeline(
-                    unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
+                pipeline = DDIMPipeline(
+                    unet=accelerator.unwrap_model(ema_model.averaged_model),
+                    noise_scheduler=noise_scheduler,
                 )
 
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                image = pipeline(generator=generator)
+                images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50)
 
-                # process image to PIL
-                image_processed = image.cpu().permute(0, 2, 3, 1)
-                image_processed = (image_processed + 1.0) * 127.5
-                image_processed = image_processed.type(torch.uint8).numpy()
-                image_pil = PIL.Image.fromarray(image_processed[0])
+                # denormalize the images and save to tensorboard
+                images_processed = (images.cpu() + 1.0) * 127.5
+                images_processed = images_processed.clamp(0, 255).type(torch.uint8).numpy()
 
-                # save image
-                test_dir = os.path.join(args.output_dir, "test_samples")
-                os.makedirs(test_dir, exist_ok=True)
-                image_pil.save(f"{test_dir}/{epoch:04d}.png")
+                accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
 
             # save the model
             if args.push_to_hub:
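The sample-logging block above converts pipeline output from the model's [-1, 1] range to uint8 before writing it to TensorBoard. A stand-alone sketch of the same conversion, using a random tensor in place of real pipeline output and a plain `SummaryWriter` in place of the tracker owned by Accelerate:

```python
# Stand-alone sketch of the denormalization done above: [-1, 1] floats -> uint8 [0, 255].
# The random tensor stands in for the NCHW image batch returned by the pipeline.
import torch
from torch.utils.tensorboard import SummaryWriter

images = torch.rand(4, 3, 64, 64) * 2.0 - 1.0    # fake batch in [-1, 1]
images_processed = (images.cpu() + 1.0) * 127.5  # -1 -> 0, +1 -> 255
images_processed = images_processed.clamp(0, 255).type(torch.uint8).numpy()

writer = SummaryWriter("logs/demo")  # the object accelerator.trackers[0].writer exposes is a SummaryWriter
writer.add_images("test_samples", images_processed, global_step=0)  # expects NCHW by default
writer.close()
```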
@@ -159,6 +175,8 @@ def main(args):
                 pipeline.save_pretrained(args.output_dir)
         accelerator.wait_for_everyone()
 
+    accelerator.end_training()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -167,18 +185,25 @@ if __name__ == "__main__":
     parser.add_argument("--output_dir", type=str, default="ddpm-model")
     parser.add_argument("--overwrite_output_dir", action="store_true")
     parser.add_argument("--resolution", type=int, default=64)
-    parser.add_argument("--batch_size", type=int, default=16)
+    parser.add_argument("--train_batch_size", type=int, default=16)
+    parser.add_argument("--eval_batch_size", type=int, default=16)
     parser.add_argument("--num_epochs", type=int, default=100)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-    parser.add_argument("--lr", type=float, default=1e-4)
-    parser.add_argument("--warmup_steps", type=int, default=500)
+    parser.add_argument("--learning_rate", type=float, default=1e-4)
+    parser.add_argument("--lr_scheduler", type=str, default="cosine")
+    parser.add_argument("--lr_warmup_steps", type=int, default=500)
+    parser.add_argument("--adam_beta1", type=float, default=0.95)
+    parser.add_argument("--adam_beta2", type=float, default=0.999)
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
+    parser.add_argument("--adam_epsilon", type=float, default=1e-3)
     parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
     parser.add_argument("--ema_power", type=float, default=3 / 4)
-    parser.add_argument("--ema_max_decay", type=float, default=0.999)
+    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
     parser.add_argument("--push_to_hub", action="store_true")
     parser.add_argument("--hub_token", type=str, default=None)
     parser.add_argument("--hub_model_id", type=str, default=None)
     parser.add_argument("--hub_private_repo", action="store_true")
+    parser.add_argument("--logging_dir", type=str, default="logs")
     parser.add_argument(
         "--mixed_precision",
         type=str,
setup.py
@@ -87,6 +87,8 @@ _deps = [
     "regex!=2019.12.17",
     "requests",
     "torch>=1.4",
+    "tensorboard",
+    "modelcards==0.1.4",
 ]
 
 # this is a lookup table with items like:
@@ -172,6 +174,8 @@ install_requires = [
     deps["requests"],
     deps["torch"],
     deps["Pillow"],
+    deps["tensorboard"],
+    deps["modelcards"],
 ]
 
 setup(
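The `# this is a lookup table with items like:` comment above refers to how setup.py derives a name-to-specifier mapping from the single `_deps` list. A simplified illustration of that idea (the regex here is not the one the real file uses):

```python
# Simplified illustration of deriving a name -> specifier lookup from _deps.
# The real setup.py uses its own regex; this only shows the idea.
import re

_deps = ["regex!=2019.12.17", "requests", "torch>=1.4", "tensorboard", "modelcards==0.1.4"]

deps = {re.split(r"[!=<>~]", spec, maxsplit=1)[0]: spec for spec in _deps}

print(deps["modelcards"])  # -> "modelcards==0.1.4"
print(deps["torch"])       # -> "torch>=1.4"
```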
@@ -13,4 +13,5 @@ deps = {
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "torch": "torch>=1.4",
+    "tensorboard": "tensorboard",
 }
@@ -19,9 +19,9 @@ import shutil
 from pathlib import Path
 from typing import Optional
 
-import yaml
 from diffusers import DiffusionPipeline
 from huggingface_hub import HfFolder, Repository, whoami
+from modelcards import CardData, ModelCard
 
 from .utils import logging
 
@@ -29,10 +29,7 @@ from .utils import logging
 logger = logging.get_logger(__name__)
 
 
-AUTOGENERATED_TRAINER_COMMENT = """
-<!-- This model card has been generated automatically according to the information the Trainer had access to. You
-should probably proofread and complete it, then remove this comment. -->
-"""
+MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
 
 
 def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
@@ -152,17 +149,36 @@ def create_model_card(args, model_name):
     if args.local_rank not in [-1, 0]:
         return
 
-    # TODO: replace this placeholder model card generation
-    model_card = ""
+    repo_name = get_full_repo_name(model_name, token=args.hub_token)
 
-    metadata = {"license": "apache-2.0", "tags": ["pytorch", "diffusers"]}
-    metadata = yaml.dump(metadata, sort_keys=False)
-    if len(metadata) > 0:
-        model_card = f"---\n{metadata}---\n"
+    model_card = ModelCard.from_template(
+        card_data=CardData(  # Card metadata object that will be converted to YAML block
+            language="en",
+            license="apache-2.0",
+            library_name="diffusers",
+            tags=[],
+            datasets=args.dataset,
+            metrics=[],
+        ),
+        template_path=MODEL_CARD_TEMPLATE_PATH,
+        model_name=model_name,
+        repo_name=repo_name,
+        dataset_name=args.dataset,
+        learning_rate=args.learning_rate,
+        train_batch_size=args.train_batch_size,
+        eval_batch_size=args.eval_batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        adam_beta1=args.adam_beta1,
+        adam_beta2=args.adam_beta2,
+        adam_weight_decay=args.adam_weight_decay,
+        adam_epsilon=args.adam_epsilon,
+        lr_scheduler=args.lr_scheduler,
+        lr_warmup_steps=args.lr_warmup_steps,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
+        mixed_precision=args.mixed_precision,
+    )
 
-    model_card += AUTOGENERATED_TRAINER_COMMENT
-    model_card += f"\n# {model_name}\n\n"
-
-    with open(os.path.join(args.output_dir, "README.md"), "w") as f:
-        f.write(model_card)
+    card_path = os.path.join(args.output_dir, "README.md")
+    model_card.save(card_path)
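For context, a minimal sketch of the `modelcards` API that `create_model_card` now goes through: a Jinja template plus `CardData` metadata rendered via `ModelCard.from_template`. The tiny inline template and file names below are invented for illustration, not the shipped `model_card_template.md`:

```python
# Minimal sketch of the modelcards flow used above; template and paths are illustrative.
from pathlib import Path

from modelcards import CardData, ModelCard

template = Path("tiny_template.md")
template.write_text(
    "---\n{{ card_data }}\n---\n\n# {{ model_name }}\n\nTrained with lr={{ learning_rate }}.\n"
)

card = ModelCard.from_template(
    card_data=CardData(language="en", license="apache-2.0", library_name="diffusers"),
    template_path=str(template),
    model_name="ddpm-demo",   # extra kwargs are passed straight to the template
    learning_rate=1e-4,
)
card.save("README.md")        # same call the training utility uses
```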
@@ -1 +1 @@
-from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline
+from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline, VQModel
@@ -145,5 +145,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         return pred_prev_sample
 
+    def add_noise(self, original_samples, noise, timesteps):
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+        return noisy_samples
+
     def __len__(self):
         return self.config.timesteps
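`add_noise` implements the closed-form forward-diffusion step x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. A self-contained numeric sketch of the same computation, using a toy linear beta schedule rather than the scheduler's configured one:

```python
# Self-contained sketch of the forward-diffusion step that add_noise implements:
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
import torch

betas = torch.linspace(1e-4, 0.02, 1000)        # toy schedule, not the scheduler's own
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(2, 3, 8, 8)                    # fake clean images
noise = torch.randn_like(x0)
t = torch.tensor([10, 900])                     # one timestep per sample

sqrt_alpha_prod = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)              # broadcast like match_shape
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)

noisy = sqrt_alpha_prod * x0 + sqrt_one_minus_alpha_prod * noise
print(noisy.shape)  # torch.Size([2, 3, 8, 8]); large t -> mostly noise, small t -> mostly x0
```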
@@ -17,7 +17,6 @@
 import math
 
 import numpy as np
-import torch
 
 from ..configuration_utils import ConfigMixin
 from .scheduling_utils import SchedulerMixin
@@ -142,7 +141,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         return pred_prev_sample
 
-    def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
+    def add_noise(self, original_samples, noise, timesteps):
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
@@ -0,0 +1,50 @@
+---
+{{ card_data }}
+---
+
+<!-- This model card has been generated automatically according to the information the training script had access to. You
+should probably proofread and complete it, then remove this comment. -->
+
+# {{ model_name | default("Diffusion Model") }}
+
+## Model description
+
+This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
+on the `{{ dataset_name }}` dataset.
+
+## Intended uses & limitations
+
+#### How to use
+
+```python
+# TODO: add an example code snippet for running this diffusion pipeline
+```
+
+#### Limitations and bias
+
+[TODO: provide examples of latent issues and potential remediations]
+
+## Training data
+
+[TODO: describe the data used to train the model]
+
+### Training hyperparameters
+
+The following hyperparameters were used during training:
+- learning_rate: {{ learning_rate }}
+- train_batch_size: {{ train_batch_size }}
+- eval_batch_size: {{ eval_batch_size }}
+- gradient_accumulation_steps: {{ gradient_accumulation_steps }}
+- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
+- lr_scheduler: {{ lr_scheduler }}
+- lr_warmup_steps: {{ lr_warmup_steps }}
+- ema_inv_gamma: {{ ema_inv_gamma }}
+- ema_power: {{ ema_power }}
+- ema_max_decay: {{ ema_max_decay }}
+- mixed_precision: {{ mixed_precision }}
+
+### Training results
+
+📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
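The template leaves the "How to use" snippet as a TODO. One plausible fill-in for a checkpoint saved by this script, assuming the older diffusers API shown above where the pipeline call returns a raw image tensor; the checkpoint path is just the script's default `--output_dir`:

```python
# Hedged example of loading and running a pipeline saved by this training script.
# "ddpm-model" is the script's default --output_dir; point it at your own checkpoint.
import PIL.Image
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("ddpm-model")

generator = torch.manual_seed(0)
images = pipeline(generator=generator, batch_size=4, num_inference_steps=50)

# Same post-processing as the training script: [-1, 1] floats -> uint8 pixels.
images = (images.cpu() + 1.0) * 127.5
images = images.clamp(0, 255).type(torch.uint8).permute(0, 2, 3, 1).numpy()
PIL.Image.fromarray(images[0]).save("sample.png")
```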