make style
This commit is contained in:
parent
1a099e5e0e
commit
9c82c32ba7
|
@ -8,6 +8,9 @@ import PIL.Image
|
|||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from diffusers import DDPM, DDPMScheduler, UNetModel
|
||||
from diffusers.hub_utils import init_git_repo, push_to_hub
|
||||
from diffusers.modeling_utils import unwrap_model
|
||||
from diffusers.utils import logging
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
Compose,
|
||||
|
@ -19,10 +22,7 @@ from torchvision.transforms import (
|
|||
)
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
from diffusers.modeling_utils import unwrap_model
|
||||
from diffusers.hub_utils import init_git_repo, push_to_hub
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from typing import Optional
|
||||
from .utils import logging
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
import yaml
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
from diffusers import DiffusionPipeline
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -68,17 +70,21 @@ def init_git_repo(args, at_init: bool = False):
|
|||
repo.git_pull()
|
||||
|
||||
# By default, ignore the checkpoint folders
|
||||
if (
|
||||
not os.path.exists(os.path.join(args.output_dir, ".gitignore"))
|
||||
and args.hub_strategy != "all_checkpoints"
|
||||
):
|
||||
if not os.path.exists(os.path.join(args.output_dir, ".gitignore")) and args.hub_strategy != "all_checkpoints":
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
|
||||
writer.writelines(["checkpoint-*/"])
|
||||
|
||||
return repo
|
||||
|
||||
|
||||
def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
|
||||
def push_to_hub(
|
||||
args,
|
||||
pipeline: DiffusionPipeline,
|
||||
repo: Repository,
|
||||
commit_message: Optional[str] = "End of training",
|
||||
blocking: bool = True,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
|
||||
Parameters:
|
||||
|
@ -108,18 +114,19 @@ def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_mess
|
|||
return
|
||||
|
||||
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
|
||||
if blocking and len(repo.command_queue) > 0 and repo.command_queue[-1] is not None and not repo.command_queue[-1].is_done:
|
||||
if (
|
||||
blocking
|
||||
and len(repo.command_queue) > 0
|
||||
and repo.command_queue[-1] is not None
|
||||
and not repo.command_queue[-1].is_done
|
||||
):
|
||||
repo.command_queue[-1]._process.kill()
|
||||
|
||||
git_head_commit_url = repo.push_to_hub(
|
||||
commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
|
||||
)
|
||||
git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
|
||||
# push separately the model card to be independent from the rest of the model
|
||||
create_model_card(args, model_name=model_name)
|
||||
try:
|
||||
repo.push_to_hub(
|
||||
commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
|
||||
)
|
||||
repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
|
||||
except EnvironmentError as exc:
|
||||
logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
|
||||
|
||||
|
@ -133,10 +140,7 @@ def create_model_card(args, model_name):
|
|||
# TODO: replace this placeholder model card generation
|
||||
model_card = ""
|
||||
|
||||
metadata = {
|
||||
"license": "apache-2.0",
|
||||
"tags": ["pytorch", "diffusers"]
|
||||
}
|
||||
metadata = {"license": "apache-2.0", "tags": ["pytorch", "diffusers"]}
|
||||
metadata = yaml.dump(metadata, sort_keys=False)
|
||||
if len(metadata) > 0:
|
||||
model_card = f"---\n{metadata}---\n"
|
||||
|
|
|
@ -5,6 +5,7 @@ import math
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
try:
|
||||
import einops
|
||||
from einops.layers.torch import Rearrange
|
||||
|
@ -103,7 +104,7 @@ class ResidualTemporalBlock(nn.Module):
|
|||
return out + self.residual_conv(x)
|
||||
|
||||
|
||||
class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
|
||||
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
horizon,
|
||||
|
@ -118,7 +119,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
|
|||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
|
||||
|
||||
|
||||
time_dim = dim
|
||||
self.time_mlp = nn.Sequential(
|
||||
SinusoidalPosEmb(dim),
|
||||
|
|
|
@ -137,8 +137,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
return pred_prev_sample
|
||||
|
||||
def forward_step(self, original_sample, noise, t):
|
||||
sqrt_alpha_prod = self.alpha_prod_t[t] ** 0.5
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alpha_prod_t[t]) ** 0.5
|
||||
sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5
|
||||
noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_sample
|
||||
|
||||
|
|
|
@ -33,9 +33,9 @@ from diffusers import (
|
|||
GLIDESuperResUNetModel,
|
||||
LatentDiffusion,
|
||||
PNDMScheduler,
|
||||
UNetModel,
|
||||
UNetLDMModel,
|
||||
UNetGradTTSModel,
|
||||
UNetLDMModel,
|
||||
UNetModel,
|
||||
)
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
|
@ -342,6 +342,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
|
|||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNetLDMModel
|
||||
|
||||
|
|
Loading…
Reference in New Issue