This commit is contained in:
patil-suraj 2022-06-09 14:12:43 +02:00
commit 9fc2b6c529
23 changed files with 356 additions and 318 deletions

View File

@ -1,10 +1,13 @@
#!/usr/bin/env python3
import os
import pathlib
from modeling_ddim import DDIM
import PIL.Image
import numpy as np
import PIL.Image
from modeling_ddim import DDIM
model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]
for model_id in model_ids:

View File

@ -14,13 +14,13 @@
# limitations under the License.
from diffusers import DiffusionPipeline
import tqdm
import torch
import tqdm
from diffusers import DiffusionPipeline
class DDIM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@ -34,7 +34,11 @@ class DDIM(DiffusionPipeline):
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device)
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# get actual t and t-1
@ -50,7 +54,13 @@ class DDIM(DiffusionPipeline):
beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
# compute relevant coefficients
coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
coeff_1 = (
(alpha_prod_t_prev - alpha_prod_t).sqrt()
* alpha_prod_t_prev_rsqrt
* beta_prod_t_prev_sqrt
/ beta_prod_t_sqrt
* eta
)
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
# model forward

View File

@ -1,9 +1,11 @@
#!/usr/bin/env python3
# !pip install diffusers
from modeling_ddim import DDIM
import PIL.Image
import numpy as np
import PIL.Image
from modeling_ddim import DDIM
model_id = "fusing/ddpm-cifar10"
model_id = "fusing/ddpm-lsun-bedroom"

View File

@ -1,11 +1,25 @@
#!/usr/bin/env python3
import os
import pathlib
from modeling_ddpm import DDPM
import PIL.Image
import numpy as np
model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-cifar10", "ddpm-celeba-hq", "ddpm-celeba-hq-ema"]
import PIL.Image
from modeling_ddpm import DDPM
model_ids = [
"ddpm-lsun-cat",
"ddpm-lsun-cat-ema",
"ddpm-lsun-church-ema",
"ddpm-lsun-church",
"ddpm-lsun-bedroom",
"ddpm-lsun-bedroom-ema",
"ddpm-cifar10-ema",
"ddpm-cifar10",
"ddpm-celeba-hq",
"ddpm-celeba-hq-ema",
]
for model_id in model_ids:
path = os.path.join("/home/patrick/images/hf", model_id)

View File

@ -14,13 +14,13 @@
# limitations under the License.
from diffusers import DiffusionPipeline
import tqdm
import torch
import tqdm
from diffusers import DiffusionPipeline
class DDPM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline):
self.unet.to(torch_device)
# 1. Sample gaussian noise
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t))
image_coeff = (
(1 - self.noise_scheduler.get_alpha_prod(t - 1))
* torch.sqrt(self.noise_scheduler.get_alpha(t))
/ (1 - self.noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1))
* self.noise_scheduler.get_beta(t)
/ (1 - self.noise_scheduler.get_alpha_prod(t))
)
# ii) predict noise residual
with torch.no_grad():
@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline):
prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
prev_variance = self.noise_scheduler.sample_variance(
t, prev_image.shape, device=torch_device, generator=generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance

View File

@ -1,7 +1,13 @@
import torch
from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from diffusers import (
ClassifierFreeGuidanceScheduler,
CLIPTextModel,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer
@ -22,7 +28,9 @@ config = CLIPTextConfig(
use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>")
tokenizer = GPT2Tokenizer(
"./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)
hf_encoder = model.text_model
@ -97,10 +105,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler)
glide = GLIDE(
text_unet=text2im_model,
text_noise_scheduler=text_scheduler,
text_encoder=model,
tokenizer=tokenizer,
upscale_unet=superres_model,
upscale_noise_scheduler=upscale_scheduler,
)
glide.save_pretrained("./glide-base")

View File

@ -18,7 +18,14 @@ import numpy as np
import torch
import tqdm
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from diffusers import (
ClassifierFreeGuidanceScheduler,
CLIPTextModel,
DiffusionPipeline,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from transformers import GPT2Tokenizer
@ -46,12 +53,16 @@ class GLIDE(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel,
upscale_noise_scheduler: GlideDDIMScheduler
upscale_noise_scheduler: GlideDDIMScheduler,
):
super().__init__()
self.register_modules(
text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
text_unet=text_unet,
text_noise_scheduler=text_noise_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
upscale_unet=upscale_unet,
upscale_noise_scheduler=upscale_noise_scheduler,
)
def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
@ -67,9 +78,7 @@ class GLIDE(DiffusionPipeline):
+ _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
scheduler.posterior_log_variance_clipped, t, x_t.shape
)
posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
@ -190,19 +199,30 @@ class GLIDE(DiffusionPipeline):
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997
image = self.upscale_noise_scheduler.sample_noise(
image = (
self.upscale_noise_scheduler.sample_noise(
(batch_size, 3, 256, 256), device=torch_device, generator=generator
) * upsample_temp
)
* upsample_temp
)
num_timesteps = len(self.upscale_noise_scheduler)
for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
for t in tqdm.tqdm(
reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(
self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta(
t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
image_coeff = (
(1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
* torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
* self.upscale_noise_scheduler.get_beta(t)
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
# ii) predict noise residual
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
@ -216,8 +236,9 @@ class GLIDE(DiffusionPipeline):
prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device,
generator=generator)
prev_variance = self.upscale_noise_scheduler.sample_variance(
t, prev_image.shape, device=torch_device, generator=generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance

View File

@ -1,6 +1,8 @@
import torch
from diffusers import DiffusionPipeline
import PIL.Image
from diffusers import DiffusionPipeline
generator = torch.Generator()
generator = generator.manual_seed(0)

View File

@ -84,6 +84,7 @@ _deps = [
"isort>=5.5.4",
"numpy",
"pytest",
"regex!=2019.12.17",
"requests",
"torch>=1.4",
"torchvision",
@ -168,6 +169,7 @@ install_requires = [
deps["filelock"],
deps["huggingface-hub"],
deps["numpy"],
deps["regex"],
deps["requests"],
deps["torch"],
deps["torchvision"],

View File

@ -7,7 +7,7 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin
from .models.clip_text_transformer import CLIPTextModel
from .models.unet import UNetModel
from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline

View File

@ -23,13 +23,13 @@ import os
import re
from typing import Any, Dict, Tuple, Union
from requests import HTTPError
from huggingface_hub import hf_hub_download
from requests import HTTPError
from . import __version__
from .utils import (
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
@ -37,9 +37,6 @@ from .utils import (
)
from . import __version__
logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json")
@ -95,9 +92,7 @@ class ConfigMixin:
@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
config_dict = cls.get_config_dict(
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
)
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
@ -157,16 +152,16 @@ class ConfigMixin:
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
" on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
" pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
" this model name. Check the model page at"
f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
@ -174,14 +169,16 @@ class ConfigMixin:
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
" run the library in offline mode at"
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
@ -195,9 +192,7 @@ class ConfigMixin:
# Load config dict
config_dict = cls._dict_from_json_file(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(
f"It looks like the config file at '{config_file}' is not a valid JSON file."
)
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
return config_dict

View File

@ -3,29 +3,15 @@
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.9.0",
"black": "black~=22.0,>=22.3",
"codecarbon": "codecarbon==1.2.0",
"dataclasses": "dataclasses",
"datasets": "datasets",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"huggingface-hub": "huggingface-hub",
"isort": "isort>=5.5.4",
"numpy": "numpy>=1.17",
"numpy": "numpy",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.7.0",
"regex": "regex!=2019.12.17",
"requests": "requests",
"sagemaker": "sagemaker>=2.31.0",
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
"torch": "torch>=1.4",
"torchaudio": "torchaudio",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
"uvicorn": "uvicorn",
"torchvision": "torchvision",
}

View File

@ -23,7 +23,8 @@ from pathlib import Path
from typing import Dict, Optional, Union
from huggingface_hub import cached_download
from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging
from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, device
from requests import HTTPError
from huggingface_hub import hf_hub_download
from requests import HTTPError
from .utils import (
CONFIG_NAME,
@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module):
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.")
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(

View File

@ -18,6 +18,6 @@
from .clip_text_transformer import CLIPTextModel
from .unet import UNetModel
from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .unet_ldm import UNetLDMModel
from .vqvae import VQModel

View File

@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image
from torchvision import transforms, utils
from tqdm import tqdm
from ..configuration_utils import ConfigMixin
@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin):
# dataset classes
class Dataset(data.Dataset):
def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']):
def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
self.transform = transforms.Compose([
self.transform = transforms.Compose(
[
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor()
])
transforms.ToTensor(),
]
)
def __len__(self):
return len(self.paths)
@ -359,7 +362,7 @@ class Dataset(data.Dataset):
# trainer class
class EMA():
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta

View File

@ -664,7 +664,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
transformer_dim=512
transformer_dim=512,
):
super().__init__(
in_channels=in_channels,
@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim
transformer_dim=transformer_dim,
)
self.register(
in_channels=in_channels,
@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim
transformer_dim=transformer_dim,
)
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)

View File

@ -1,14 +1,15 @@
from inspect import isfunction
from abc import abstractmethod
import math
from abc import abstractmethod
from inspect import isfunction
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from einops import repeat, rearrange
from einops import rearrange, repeat
except:
print("Einops is not installed")
pass
@ -16,6 +17,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
def exists(val):
return val is not None
@ -53,20 +55,13 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
@ -96,11 +91,11 @@ class LinearAttention(nn.Module):
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
return self.to_out(out)
@ -110,26 +105,10 @@ class SpatialSelfAttention(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
@ -140,25 +119,25 @@ class SpatialSelfAttention(nn.Module):
# compute attention
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@ -170,10 +149,7 @@ class CrossAttention(nn.Module):
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None):
h = self.heads
@ -183,31 +159,34 @@ class CrossAttention(nn.Module):
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.attn2 = CrossAttention(
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
@ -228,29 +207,23 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
[
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)
]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
@ -258,13 +231,14 @@ class SpatialTransformer(nn.Module):
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
x = rearrange(x, "b c h w -> b (h w) c")
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x + x_in
def convert_module_to_f16(l):
"""
Convert primitive modules to float16.
@ -453,9 +427,7 @@ class Upsample(nn.Module):
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
@ -480,9 +452,7 @@ class Downsample(nn.Module):
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
)
self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@ -558,17 +528,13 @@ class ResBlock(TimestepBlock):
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
@ -710,9 +676,7 @@ class QKVAttentionLegacy(nn.Module):
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@ -810,19 +774,23 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
)
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
assert (
context_dim is not None
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
self.image_size = image_size
self.in_channels = in_channels
@ -852,11 +820,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
@ -892,7 +856,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
)
if not use_spatial_transformer
else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)
@ -914,9 +880,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
)
ch = out_ch
@ -947,9 +911,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
),
)
if not use_spatial_transformer
else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
ResBlock(
ch,
time_embed_dim,
@ -993,7 +957,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
)
if not use_spatial_transformer
else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)
@ -1108,7 +1074,7 @@ class EncoderUNetModel(nn.Module):
use_new_attention_order=False,
pool="adaptive",
*args,
**kwargs
**kwargs,
):
super().__init__()
@ -1137,11 +1103,7 @@ class EncoderUNetModel(nn.Module):
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
@ -1189,9 +1151,7 @@ class EncoderUNetModel(nn.Module):
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
)
ch = out_ch
@ -1239,9 +1199,7 @@ class EncoderUNetModel(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
AttentionPool2d(
(image_size // ds), ch, num_head_channels, out_channels
),
AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
)
elif pool == "spatial":
self.out = nn.Sequential(
@ -1296,4 +1254,3 @@ class EncoderUNetModel(nn.Module):
else:
h = h.type(x.dtype)
return self.out(h)

View File

@ -20,10 +20,9 @@ from typing import Optional, Union
from huggingface_hub import snapshot_download
from .utils import logging, DIFFUSERS_CACHE
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .utils import DIFFUSERS_CACHE, logging
INDEX_FILE = "diffusion_model.pt"

View File

@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
import torch
from torch import nn
from ..configuration_utils import ConfigMixin
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json"

View File

@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import torch
from torch import nn
from ..configuration_utils import ConfigMixin
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json"
@ -26,12 +26,7 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
config_name = SAMPLING_CONFIG_NAME
def __init__(
self,
timesteps=1000,
beta_schedule="linear",
variance_type="fixed_large"
):
def __init__(self, timesteps=1000, beta_schedule="linear", variance_type="fixed_large"):
super().__init__()
self.register(
timesteps=timesteps,

View File

@ -5,6 +5,8 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
import os
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -19,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from requests.exceptions import HTTPError
import os
hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))

View File

@ -14,19 +14,19 @@
# limitations under the License.
import os
import random
import tempfile
import unittest
import os
from distutils.util import strtobool
import torch
from diffusers import GaussianDDPMScheduler, UNetModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from models.vision.ddpm.modeling_ddpm import DDPM
from diffusers.pipeline_utils import DiffusionPipeline
from models.vision.ddim.modeling_ddim import DDIM
from models.vision.ddpm.modeling_ddpm import DDPM
global_rng = random.Random()
@ -85,7 +85,6 @@ class ConfigTester(unittest.TestCase):
ConfigMixin.from_config("dummy_path")
def test_save_load(self):
class SampleObject(ConfigMixin):
config_name = "config.json"
@ -153,7 +152,6 @@ class ModelTesterMixin(unittest.TestCase):
class SamplerTesterMixin(unittest.TestCase):
@slow
def test_sample(self):
generator = torch.manual_seed(0)
@ -163,15 +161,23 @@ class SamplerTesterMixin(unittest.TestCase):
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
image = scheduler.sample_noise(
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
)
# 3. Denoise
for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
image_coeff = (
(1 - scheduler.get_alpha_prod(t - 1))
* torch.sqrt(scheduler.get_alpha(t))
/ (1 - scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
)
# ii) predict noise residual
with torch.no_grad():
@ -201,7 +207,9 @@ class SamplerTesterMixin(unittest.TestCase):
assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu()
expected_slice = torch.tensor([-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280])
expected_slice = torch.tensor(
[-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_sample_fast(self):
@ -212,15 +220,23 @@ class SamplerTesterMixin(unittest.TestCase):
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
image = scheduler.sample_noise(
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
)
# 3. Denoise
for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
image_coeff = (
(1 - scheduler.get_alpha_prod(t - 1))
* torch.sqrt(scheduler.get_alpha(t))
/ (1 - scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
)
# ii) predict noise residual
with torch.no_grad():
@ -246,7 +262,6 @@ class SamplerTesterMixin(unittest.TestCase):
class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):
# 1. Load models
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
@ -309,5 +324,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor([-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068])
expected_slice = torch.tensor(
[-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2