fix setup

This commit is contained in:
Patrick von Platen 2022-06-09 14:06:58 +02:00
parent 2234877e01
commit cbb19ee84e
23 changed files with 356 additions and 318 deletions

View File

@ -1,10 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
import pathlib import pathlib
from modeling_ddim import DDIM
import PIL.Image
import numpy as np import numpy as np
import PIL.Image
from modeling_ddim import DDIM
model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"] model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]
for model_id in model_ids: for model_id in model_ids:

View File

@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
from diffusers import DiffusionPipeline
import tqdm
import torch import torch
import tqdm
from diffusers import DiffusionPipeline
class DDIM(DiffusionPipeline): class DDIM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler): def __init__(self, unet, noise_scheduler):
super().__init__() super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline):
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device) self.unet.to(torch_device)
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# get actual t and t-1 # get actual t and t-1
train_step = inference_step_times[t] train_step = inference_step_times[t]
prev_train_step = inference_step_times[t - 1] if t > 0 else - 1 prev_train_step = inference_step_times[t - 1] if t > 0 else -1
# compute alphas # compute alphas
alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step) alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline):
beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt() beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
# compute relevant coefficients # compute relevant coefficients
coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta coeff_1 = (
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt() (alpha_prod_t_prev - alpha_prod_t).sqrt()
* alpha_prod_t_prev_rsqrt
* beta_prod_t_prev_sqrt
/ beta_prod_t_sqrt
* eta
)
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
# model forward # model forward
with torch.no_grad(): with torch.no_grad():

View File

@ -1,9 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# !pip install diffusers # !pip install diffusers
from modeling_ddim import DDIM
import PIL.Image
import numpy as np import numpy as np
import PIL.Image
from modeling_ddim import DDIM
model_id = "fusing/ddpm-cifar10" model_id = "fusing/ddpm-cifar10"
model_id = "fusing/ddpm-lsun-bedroom" model_id = "fusing/ddpm-lsun-bedroom"

View File

@ -1,11 +1,25 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
import pathlib import pathlib
from modeling_ddpm import DDPM
import PIL.Image
import numpy as np import numpy as np
model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-cifar10", "ddpm-celeba-hq", "ddpm-celeba-hq-ema"] import PIL.Image
from modeling_ddpm import DDPM
model_ids = [
"ddpm-lsun-cat",
"ddpm-lsun-cat-ema",
"ddpm-lsun-church-ema",
"ddpm-lsun-church",
"ddpm-lsun-bedroom",
"ddpm-lsun-bedroom-ema",
"ddpm-cifar10-ema",
"ddpm-cifar10",
"ddpm-celeba-hq",
"ddpm-celeba-hq-ema",
]
for model_id in model_ids: for model_id in model_ids:
path = os.path.join("/home/patrick/images/hf", model_id) path = os.path.join("/home/patrick/images/hf", model_id)

View File

@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
from diffusers import DiffusionPipeline
import tqdm
import torch import torch
import tqdm
from diffusers import DiffusionPipeline
class DDPM(DiffusionPipeline): class DDPM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler): def __init__(self, unet, noise_scheduler):
super().__init__() super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline):
self.unet.to(torch_device) self.unet.to(torch_device)
# 1. Sample gaussian noise # 1. Sample gaussian noise
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)): for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
# i) define coefficients for time step t # i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t)) image_coeff = (
clipped_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t)) (1 - self.noise_scheduler.get_alpha_prod(t - 1))
* torch.sqrt(self.noise_scheduler.get_alpha(t))
/ (1 - self.noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1))
* self.noise_scheduler.get_beta(t)
/ (1 - self.noise_scheduler.get_alpha_prod(t))
)
# ii) predict noise residual # ii) predict noise residual
with torch.no_grad(): with torch.no_grad():
@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline):
prev_image = clipped_coeff * pred_mean + image_coeff * image prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance # iv) sample variance
prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator) prev_variance = self.noise_scheduler.sample_variance(
t, prev_image.shape, device=torch_device, generator=generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance) # v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance sampled_prev_image = prev_image + prev_variance

View File

@ -1,7 +1,13 @@
import torch import torch
from torch import nn from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel from diffusers import (
ClassifierFreeGuidanceScheduler,
CLIPTextModel,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from modeling_glide import GLIDE from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer from transformers import CLIPTextConfig, GPT2Tokenizer
@ -22,7 +28,9 @@ config = CLIPTextConfig(
use_padding_embeddings=True, use_padding_embeddings=True,
) )
model = CLIPTextModel(config).eval() model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>") tokenizer = GPT2Tokenizer(
"./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)
hf_encoder = model.text_model hf_encoder = model.text_model
@ -97,10 +105,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear") upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer, glide = GLIDE(
upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler) text_unet=text2im_model,
text_noise_scheduler=text_scheduler,
text_encoder=model,
tokenizer=tokenizer,
upscale_unet=superres_model,
upscale_noise_scheduler=upscale_scheduler,
)
glide.save_pretrained("./glide-base") glide.save_pretrained("./glide-base")

View File

@ -18,7 +18,14 @@ import numpy as np
import torch import torch
import tqdm import tqdm
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel from diffusers import (
ClassifierFreeGuidanceScheduler,
CLIPTextModel,
DiffusionPipeline,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from transformers import GPT2Tokenizer from transformers import GPT2Tokenizer
@ -46,12 +53,16 @@ class GLIDE(DiffusionPipeline):
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer, tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel, upscale_unet: GLIDESuperResUNetModel,
upscale_noise_scheduler: GlideDDIMScheduler upscale_noise_scheduler: GlideDDIMScheduler,
): ):
super().__init__() super().__init__()
self.register_modules( self.register_modules(
text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer, text_unet=text_unet,
upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler text_noise_scheduler=text_noise_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
upscale_unet=upscale_unet,
upscale_noise_scheduler=upscale_noise_scheduler,
) )
def q_posterior_mean_variance(self, scheduler, x_start, x_t, t): def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
@ -67,9 +78,7 @@ class GLIDE(DiffusionPipeline):
+ _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
) )
posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape) posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor( posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
scheduler.posterior_log_variance_clipped, t, x_t.shape
)
assert ( assert (
posterior_mean.shape[0] posterior_mean.shape[0]
== posterior_variance.shape[0] == posterior_variance.shape[0]
@ -190,19 +199,30 @@ class GLIDE(DiffusionPipeline):
# A value of 1.0 is sharper, but sometimes results in grainy artifacts. # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997 upsample_temp = 0.997
image = self.upscale_noise_scheduler.sample_noise( image = (
(batch_size, 3, 256, 256), device=torch_device, generator=generator self.upscale_noise_scheduler.sample_noise(
) * upsample_temp (batch_size, 3, 256, 256), device=torch_device, generator=generator
)
* upsample_temp
)
num_timesteps = len(self.upscale_noise_scheduler) num_timesteps = len(self.upscale_noise_scheduler)
for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)): for t in tqdm.tqdm(
reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
):
# i) define coefficients for time step t # i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt( image_coeff = (
self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t)) (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta( * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
* self.upscale_noise_scheduler.get_beta(t)
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
# ii) predict noise residual # ii) predict noise residual
time_input = torch.tensor([t] * image.shape[0], device=torch_device) time_input = torch.tensor([t] * image.shape[0], device=torch_device)
@ -216,8 +236,9 @@ class GLIDE(DiffusionPipeline):
prev_image = clipped_coeff * pred_mean + image_coeff * image prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance # iv) sample variance
prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, prev_variance = self.upscale_noise_scheduler.sample_variance(
generator=generator) t, prev_image.shape, device=torch_device, generator=generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance) # v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance sampled_prev_image = prev_image + prev_variance

View File

@ -1,6 +1,8 @@
import torch import torch
from diffusers import DiffusionPipeline
import PIL.Image import PIL.Image
from diffusers import DiffusionPipeline
generator = torch.Generator() generator = torch.Generator()
generator = generator.manual_seed(0) generator = generator.manual_seed(0)
@ -14,8 +16,8 @@ pipeline = DiffusionPipeline.from_pretrained(model_id)
img = pipeline("a clip art of a hugging face", generator) img = pipeline("a clip art of a hugging face", generator)
# process image to PIL # process image to PIL
img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy() img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
image_pil = PIL.Image.fromarray(img) image_pil = PIL.Image.fromarray(img)
# save image # save image
image_pil.save("test.png") image_pil.save("test.png")

View File

@ -84,6 +84,7 @@ _deps = [
"isort>=5.5.4", "isort>=5.5.4",
"numpy", "numpy",
"pytest", "pytest",
"regex!=2019.12.17",
"requests", "requests",
"torch>=1.4", "torch>=1.4",
"torchvision", "torchvision",
@ -168,6 +169,7 @@ install_requires = [
deps["filelock"], deps["filelock"],
deps["huggingface-hub"], deps["huggingface-hub"],
deps["numpy"], deps["numpy"],
deps["regex"],
deps["requests"], deps["requests"],
deps["torch"], deps["torch"],
deps["torchvision"], deps["torchvision"],

View File

@ -7,7 +7,7 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models.clip_text_transformer import CLIPTextModel from .models.clip_text_transformer import CLIPTextModel
from .models.unet import UNetModel from .models.unet import UNetModel
from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel from .models.unet_ldm import UNetLDMModel
from .models.vqvae import VQModel from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline

View File

@ -23,13 +23,13 @@ import os
import re import re
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, Tuple, Union
from requests import HTTPError
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from requests import HTTPError
from . import __version__
from .utils import ( from .utils import (
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DIFFUSERS_CACHE, DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError, EntryNotFoundError,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
@ -37,9 +37,6 @@ from .utils import (
) )
from . import __version__
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json") _re_configuration_file = re.compile(r"config\.(.*)\.json")
@ -95,9 +92,7 @@ class ConfigMixin:
@classmethod @classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
config_dict = cls.get_config_dict( config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
@ -157,16 +152,16 @@ class ConfigMixin:
except RepositoryNotFoundError: except RepositoryNotFoundError:
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on " f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having " " on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass " " having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
"`use_auth_token=True`." " pass `use_auth_token=True`."
) )
except RevisionNotFoundError: except RevisionNotFoundError:
raise EnvironmentError( raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for " " this model name. Check the model page at"
"available revisions." f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
) )
except EntryNotFoundError: except EntryNotFoundError:
raise EnvironmentError( raise EnvironmentError(
@ -174,14 +169,16 @@ class ConfigMixin:
) )
except HTTPError as err: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" "There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
) )
except ValueError: except ValueError:
raise EnvironmentError( raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in" f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory" f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the" f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." " run the library in offline mode at"
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
) )
except EnvironmentError: except EnvironmentError:
raise EnvironmentError( raise EnvironmentError(
@ -195,9 +192,7 @@ class ConfigMixin:
# Load config dict # Load config dict
config_dict = cls._dict_from_json_file(config_file) config_dict = cls._dict_from_json_file(config_file)
except (json.JSONDecodeError, UnicodeDecodeError): except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError( raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
f"It looks like the config file at '{config_file}' is not a valid JSON file."
)
return config_dict return config_dict

View File

@ -3,29 +3,15 @@
# 2. run `make deps_table_update`` # 2. run `make deps_table_update``
deps = { deps = {
"Pillow": "Pillow", "Pillow": "Pillow",
"accelerate": "accelerate>=0.9.0",
"black": "black~=22.0,>=22.3", "black": "black~=22.0,>=22.3",
"codecarbon": "codecarbon==1.2.0", "filelock": "filelock",
"dataclasses": "dataclasses", "flake8": "flake8>=3.8.3",
"datasets": "datasets", "huggingface-hub": "huggingface-hub",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
"numpy": "numpy>=1.17", "numpy": "numpy",
"pytest": "pytest", "pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.7.0",
"regex": "regex!=2019.12.17", "regex": "regex!=2019.12.17",
"requests": "requests", "requests": "requests",
"sagemaker": "sagemaker>=2.31.0",
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
"torch": "torch>=1.4", "torch": "torch>=1.4",
"torchaudio": "torchaudio", "torchvision": "torchvision",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
"uvicorn": "uvicorn",
} }

View File

@ -23,7 +23,8 @@ from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from huggingface_hub import cached_download from huggingface_hub import cached_download
from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging
from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor, device from torch import Tensor, device
from requests import HTTPError
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from requests import HTTPError
from .utils import ( from .utils import (
CONFIG_NAME, CONFIG_NAME,
@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module):
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
) )
except EntryNotFoundError: except EntryNotFoundError:
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.") raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
)
except HTTPError as err: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" "There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
) )
except ValueError: except ValueError:
raise EnvironmentError( raise EnvironmentError(

View File

@ -18,6 +18,6 @@
from .clip_text_transformer import CLIPTextModel from .clip_text_transformer import CLIPTextModel
from .unet import UNetModel from .unet import UNetModel
from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .unet_ldm import UNetLDMModel from .unet_ldm import UNetLDMModel
from .vqvae import VQModel from .vqvae import VQModel

View File

@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam from torch.optim import Adam
from torch.utils import data from torch.utils import data
from torchvision import transforms, utils
from PIL import Image from PIL import Image
from torchvision import transforms, utils
from tqdm import tqdm from tqdm import tqdm
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin):
# dataset classes # dataset classes
class Dataset(data.Dataset): class Dataset(data.Dataset):
def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']): def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
super().__init__() super().__init__()
self.folder = folder self.folder = folder
self.image_size = image_size self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
self.transform = transforms.Compose([ self.transform = transforms.Compose(
transforms.Resize(image_size), [
transforms.RandomHorizontalFlip(), transforms.Resize(image_size),
transforms.CenterCrop(image_size), transforms.RandomHorizontalFlip(),
transforms.ToTensor() transforms.CenterCrop(image_size),
]) transforms.ToTensor(),
]
)
def __len__(self): def __len__(self):
return len(self.paths) return len(self.paths)
@ -359,7 +362,7 @@ class Dataset(data.Dataset):
# trainer class # trainer class
class EMA(): class EMA:
def __init__(self, beta): def __init__(self, beta):
super().__init__() super().__init__()
self.beta = beta self.beta = beta

View File

@ -647,24 +647,24 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
""" """
def __init__( def __init__(
self, self,
in_channels=3, in_channels=3,
model_channels=192, model_channels=192,
out_channels=6, out_channels=6,
num_res_blocks=3, num_res_blocks=3,
attention_resolutions=(2, 4, 8), attention_resolutions=(2, 4, 8),
dropout=0, dropout=0,
channel_mult=(1, 2, 4, 8), channel_mult=(1, 2, 4, 8),
conv_resample=True, conv_resample=True,
dims=2, dims=2,
use_checkpoint=False, use_checkpoint=False,
use_fp16=False, use_fp16=False,
num_heads=1, num_heads=1,
num_head_channels=-1, num_head_channels=-1,
num_heads_upsample=-1, num_heads_upsample=-1,
use_scale_shift_norm=False, use_scale_shift_norm=False,
resblock_updown=False, resblock_updown=False,
transformer_dim=512 transformer_dim=512,
): ):
super().__init__( super().__init__(
in_channels=in_channels, in_channels=in_channels,
@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=num_heads_upsample, num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown, resblock_updown=resblock_updown,
transformer_dim=transformer_dim transformer_dim=transformer_dim,
) )
self.register( self.register(
in_channels=in_channels, in_channels=in_channels,
@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=num_heads_upsample, num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown, resblock_updown=resblock_updown,
transformer_dim=transformer_dim transformer_dim=transformer_dim,
) )
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4) self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
@ -737,23 +737,23 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
""" """
def __init__( def __init__(
self, self,
in_channels=3, in_channels=3,
model_channels=192, model_channels=192,
out_channels=6, out_channels=6,
num_res_blocks=3, num_res_blocks=3,
attention_resolutions=(2, 4, 8), attention_resolutions=(2, 4, 8),
dropout=0, dropout=0,
channel_mult=(1, 2, 4, 8), channel_mult=(1, 2, 4, 8),
conv_resample=True, conv_resample=True,
dims=2, dims=2,
use_checkpoint=False, use_checkpoint=False,
use_fp16=False, use_fp16=False,
num_heads=1, num_heads=1,
num_head_channels=-1, num_head_channels=-1,
num_heads_upsample=-1, num_heads_upsample=-1,
use_scale_shift_norm=False, use_scale_shift_norm=False,
resblock_updown=False, resblock_updown=False,
): ):
super().__init__( super().__init__(
in_channels=in_channels, in_channels=in_channels,
@ -809,4 +809,4 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
h = torch.cat([h, hs.pop()], dim=1) h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb) h = module(h, emb)
return self.out(h) return self.out(h)

View File

@ -1,14 +1,15 @@
from inspect import isfunction
from abc import abstractmethod
import math import math
from abc import abstractmethod
from inspect import isfunction
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
try: try:
from einops import repeat, rearrange from einops import rearrange, repeat
except: except:
print("Einops is not installed") print("Einops is not installed")
pass pass
@ -16,12 +17,13 @@ except:
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
def exists(val): def exists(val):
return val is not None return val is not None
def uniq(arr): def uniq(arr):
return{el: True for el in arr}.keys() return {el: True for el in arr}.keys()
def default(val, d): def default(val, d):
@ -53,20 +55,13 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential( self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@ -90,17 +85,17 @@ class LinearAttention(nn.Module):
super().__init__() super().__init__()
self.heads = heads self.heads = heads
hidden_dim = dim_head * heads hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1) self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x): def forward(self, x):
b, c, h, w = x.shape b, c, h, w = x.shape
qkv = self.to_qkv(x) qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
k = k.softmax(dim=-1) k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v) context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q) out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
return self.to_out(out) return self.to_out(out)
@ -110,26 +105,10 @@ class SpatialSelfAttention(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
in_channels, self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
kernel_size=1, self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
stride=1, self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x): def forward(self, x):
h_ = x h_ = x
@ -139,41 +118,38 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_) v = self.v(h_)
# compute attention # compute attention
b,c,h,w = q.shape b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c') q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, 'b c h w -> b c (h w)') k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum('bij,bjk->bik', q, k) w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c)**(-0.5)) w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2) w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values # attend to values
v = rearrange(v, 'b c h w -> b c (h w)') v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, 'b i j -> b j i') w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum('bij,bjk->bik', v, w_) h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_) h_ = self.proj_out(h_)
return x+h_ return x + h_
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5 self.scale = dim_head**-0.5
self.heads = heads self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None): def forward(self, x, context=None, mask=None):
h = self.heads h = self.heads
@ -183,31 +159,34 @@ class CrossAttention(nn.Module):
k = self.to_k(context) k = self.to_k(context)
v = self.to_v(context) v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask): if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)') mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h) mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value) sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of # attention, what we cannot get enough of
attn = sim.softmax(dim=-1) attn = sim.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', attn, v) out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h) out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out) return self.to_out(out)
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
super().__init__() super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, self.attn2 = CrossAttention(
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim) self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim)
@ -228,29 +207,23 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action. Then apply standard transformer action.
Finally, reshape to image Finally, reshape to image
""" """
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None): def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels, self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList( self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) [
for d in range(depth)] BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)
]
) )
self.proj_out = zero_module(nn.Conv2d(inner_dim, self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
in_channels,
kernel_size=1,
stride=1,
padding=0))
def forward(self, x, context=None): def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention # note: if no context is given, cross-attention defaults to self-attention
@ -258,13 +231,14 @@ class SpatialTransformer(nn.Module):
x_in = x x_in = x
x = self.norm(x) x = self.norm(x)
x = self.proj_in(x) x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c') x = rearrange(x, "b c h w -> b (h w) c")
for block in self.transformer_blocks: for block in self.transformer_blocks:
x = block(x, context=context) x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x) x = self.proj_out(x)
return x + x_in return x + x_in
def convert_module_to_f16(l): def convert_module_to_f16(l):
""" """
Convert primitive modules to float16. Convert primitive modules to float16.
@ -386,7 +360,7 @@ class AttentionPool2d(nn.Module):
output_dim: int = None, output_dim: int = None,
): ):
super().__init__() super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels self.num_heads = embed_dim // num_heads_channels
@ -453,9 +427,7 @@ class Upsample(nn.Module):
def forward(self, x): def forward(self, x):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
if self.dims == 3: if self.dims == 3:
x = F.interpolate( x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else: else:
x = F.interpolate(x, scale_factor=2, mode="nearest") x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv: if self.use_conv:
@ -472,7 +444,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions. downsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
@ -480,9 +452,7 @@ class Downsample(nn.Module):
self.dims = dims self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2) stride = 2 if dims != 3 else (1, 2, 2)
if use_conv: if use_conv:
self.op = conv_nd( self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
)
else: else:
assert self.channels == self.out_channels assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@ -558,17 +528,13 @@ class ResBlock(TimestepBlock):
normalization(self.out_channels), normalization(self.out_channels),
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
) )
if self.out_channels == channels: if self.out_channels == channels:
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
elif use_conv: elif use_conv:
self.skip_connection = conv_nd( self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
dims, channels, self.out_channels, 3, padding=1
)
else: else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
@ -686,7 +652,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops. # We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes # The first computes the weight matrix, the second computes
# the combination of the value vectors. # the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += torch.DoubleTensor([matmul_ops]) model.total_ops += torch.DoubleTensor([matmul_ops])
@ -710,9 +676,7 @@ class QKVAttentionLegacy(nn.Module):
ch = width // (3 * self.n_heads) ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch)) scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum( weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v) a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
@ -773,14 +737,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
use_scale_shift_norm=False, use_scale_shift_norm=False,
resblock_updown=False, resblock_updown=False,
use_new_attention_order=False, use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True, legacy=True,
): ):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register(
image_size=image_size, image_size=image_size,
@ -810,19 +774,23 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
) )
if use_spatial_transformer: if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' assert (
context_dim is not None
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
if context_dim is not None: if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
if num_heads_upsample == -1: if num_heads_upsample == -1:
num_heads_upsample = num_heads num_heads_upsample = num_heads
if num_heads == -1: if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"
if num_head_channels == -1: if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
self.image_size = image_size self.image_size = image_size
self.in_channels = in_channels self.in_channels = in_channels
@ -852,11 +820,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self.label_emb = nn.Embedding(num_classes, time_embed_dim) self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList(
[ [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
) )
self._feature_size = model_channels self._feature_size = model_channels
input_block_chans = [model_channels] input_block_chans = [model_channels]
@ -883,7 +847,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads = ch // num_head_channels num_heads = ch // num_head_channels
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
#num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append( layers.append(
AttentionBlock( AttentionBlock(
@ -892,7 +856,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads, num_heads=num_heads,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( )
if not use_spatial_transformer
else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
) )
) )
@ -914,9 +880,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
down=True, down=True,
) )
if resblock_updown if resblock_updown
else Downsample( else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
ch, conv_resample, dims=dims, out_channels=out_ch
)
) )
) )
ch = out_ch ch = out_ch
@ -930,7 +894,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads = ch // num_head_channels num_heads = ch // num_head_channels
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
#num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential( self.middle_block = TimestepEmbedSequential(
ResBlock( ResBlock(
@ -947,9 +911,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads, num_heads=num_heads,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( )
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim if not use_spatial_transformer
), else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
ResBlock( ResBlock(
ch, ch,
time_embed_dim, time_embed_dim,
@ -984,7 +948,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads = ch // num_head_channels num_heads = ch // num_head_channels
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
#num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append( layers.append(
AttentionBlock( AttentionBlock(
@ -993,7 +957,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads_upsample, num_heads=num_heads_upsample,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( )
if not use_spatial_transformer
else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
) )
) )
@ -1024,10 +990,10 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
) )
if self.predict_codebook_ids: if self.predict_codebook_ids:
self.id_predictor = nn.Sequential( self.id_predictor = nn.Sequential(
normalization(ch), normalization(ch),
conv_nd(dims, model_channels, n_embed, 1), conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
) )
def convert_to_fp16(self): def convert_to_fp16(self):
""" """
@ -1045,7 +1011,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self.middle_block.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs): def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
""" """
Apply the model to an input batch. Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs. :param x: an [N x C x ...] Tensor of inputs.
@ -1108,7 +1074,7 @@ class EncoderUNetModel(nn.Module):
use_new_attention_order=False, use_new_attention_order=False,
pool="adaptive", pool="adaptive",
*args, *args,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
@ -1137,11 +1103,7 @@ class EncoderUNetModel(nn.Module):
) )
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList(
[ [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
) )
self._feature_size = model_channels self._feature_size = model_channels
input_block_chans = [model_channels] input_block_chans = [model_channels]
@ -1189,9 +1151,7 @@ class EncoderUNetModel(nn.Module):
down=True, down=True,
) )
if resblock_updown if resblock_updown
else Downsample( else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
ch, conv_resample, dims=dims, out_channels=out_ch
)
) )
) )
ch = out_ch ch = out_ch
@ -1239,9 +1199,7 @@ class EncoderUNetModel(nn.Module):
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch),
nn.SiLU(), nn.SiLU(),
AttentionPool2d( AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
(image_size // ds), ch, num_head_channels, out_channels
),
) )
elif pool == "spatial": elif pool == "spatial":
self.out = nn.Sequential( self.out = nn.Sequential(
@ -1296,4 +1254,3 @@ class EncoderUNetModel(nn.Module):
else: else:
h = h.type(x.dtype) h = h.type(x.dtype)
return self.out(h) return self.out(h)

View File

@ -20,10 +20,9 @@ from typing import Optional, Union
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from .utils import logging, DIFFUSERS_CACHE
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
from .utils import DIFFUSERS_CACHE, logging
INDEX_FILE = "diffusion_model.pt" INDEX_FILE = "diffusion_model.pt"
@ -106,7 +105,7 @@ class DiffusionPipeline(ConfigMixin):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r""" r"""
Add docstrings Add docstrings
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)

View File

@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch
import math import math
import torch
from torch import nn from torch import nn
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json" SAMPLING_CONFIG_NAME = "scheduler_config.json"

View File

@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch
import numpy as np import numpy as np
import torch
from torch import nn from torch import nn
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json" SAMPLING_CONFIG_NAME = "scheduler_config.json"
@ -26,12 +26,7 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
config_name = SAMPLING_CONFIG_NAME config_name = SAMPLING_CONFIG_NAME
def __init__( def __init__(self, timesteps=1000, beta_schedule="linear", variance_type="fixed_large"):
self,
timesteps=1000,
beta_schedule="linear",
variance_type="fixed_large"
):
super().__init__() super().__init__()
self.register( self.register(
timesteps=timesteps, timesteps=timesteps,
@ -93,4 +88,4 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
return torch.randn(shape, generator=generator).to(device) return torch.randn(shape, generator=generator).to(device)
def __len__(self): def __len__(self):
return self.num_timesteps return self.num_timesteps

View File

@ -5,6 +5,8 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this # There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all. # module, but to preserve other warnings. So, don't check this module at all.
import os
# Copyright 2021 The HuggingFace Inc. team. All rights reserved. # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -19,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
import os
hf_cache_home = os.path.expanduser( hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))

View File

@ -14,19 +14,19 @@
# limitations under the License. # limitations under the License.
import os
import random import random
import tempfile import tempfile
import unittest import unittest
import os
from distutils.util import strtobool from distutils.util import strtobool
import torch import torch
from diffusers import GaussianDDPMScheduler, UNetModel from diffusers import GaussianDDPMScheduler, UNetModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from models.vision.ddpm.modeling_ddpm import DDPM from diffusers.pipeline_utils import DiffusionPipeline
from models.vision.ddim.modeling_ddim import DDIM from models.vision.ddim.modeling_ddim import DDIM
from models.vision.ddpm.modeling_ddpm import DDPM
global_rng = random.Random() global_rng = random.Random()
@ -85,7 +85,6 @@ class ConfigTester(unittest.TestCase):
ConfigMixin.from_config("dummy_path") ConfigMixin.from_config("dummy_path")
def test_save_load(self): def test_save_load(self):
class SampleObject(ConfigMixin): class SampleObject(ConfigMixin):
config_name = "config.json" config_name = "config.json"
@ -153,7 +152,6 @@ class ModelTesterMixin(unittest.TestCase):
class SamplerTesterMixin(unittest.TestCase): class SamplerTesterMixin(unittest.TestCase):
@slow @slow
def test_sample(self): def test_sample(self):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
@ -163,15 +161,23 @@ class SamplerTesterMixin(unittest.TestCase):
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise # 2. Sample gaussian noise
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) image = scheduler.sample_noise(
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
)
# 3. Denoise # 3. Denoise
for t in reversed(range(len(scheduler))): for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t # i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) image_coeff = (
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) (1 - scheduler.get_alpha_prod(t - 1))
* torch.sqrt(scheduler.get_alpha(t))
/ (1 - scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
)
# ii) predict noise residual # ii) predict noise residual
with torch.no_grad(): with torch.no_grad():
@ -201,7 +207,9 @@ class SamplerTesterMixin(unittest.TestCase):
assert image.shape == (1, 3, 256, 256) assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -1, -3:, -3:].cpu()
expected_slice = torch.tensor([-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280]) expected_slice = torch.tensor(
[-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_sample_fast(self): def test_sample_fast(self):
@ -212,15 +220,23 @@ class SamplerTesterMixin(unittest.TestCase):
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise # 2. Sample gaussian noise
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) image = scheduler.sample_noise(
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
)
# 3. Denoise # 3. Denoise
for t in reversed(range(len(scheduler))): for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t # i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) image_coeff = (
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) (1 - scheduler.get_alpha_prod(t - 1))
* torch.sqrt(scheduler.get_alpha(t))
/ (1 - scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
)
# ii) predict noise residual # ii) predict noise residual
with torch.no_grad(): with torch.no_grad():
@ -246,7 +262,6 @@ class SamplerTesterMixin(unittest.TestCase):
class PipelineTesterMixin(unittest.TestCase): class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self): def test_from_pretrained_save_pretrained(self):
# 1. Load models # 1. Load models
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
@ -309,5 +324,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 32, 32) assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor([-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068]) expected_slice = torch.tensor(
[-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2