fix setup
This commit is contained in:
parent 2234877e01
commit cbb19ee84e
@@ -1,10 +1,13 @@
#!/usr/bin/env python3
import os
import pathlib
from modeling_ddim import DDIM
import PIL.Image

import numpy as np

import PIL.Image
from modeling_ddim import DDIM


model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]

for model_id in model_ids:
@@ -14,13 +14,13 @@
# limitations under the License.


from diffusers import DiffusionPipeline
import tqdm
import torch

import tqdm
from diffusers import DiffusionPipeline


class DDIM(DiffusionPipeline):

def __init__(self, unet, noise_scheduler):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline):
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

self.unet.to(torch_device)
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)

for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# get actual t and t-1
train_step = inference_step_times[t]
prev_train_step = inference_step_times[t - 1] if t > 0 else - 1
prev_train_step = inference_step_times[t - 1] if t > 0 else -1

# compute alphas
alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
@@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline):
beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()

# compute relevant coefficients
coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt()
coeff_1 = (
(alpha_prod_t_prev - alpha_prod_t).sqrt()
* alpha_prod_t_prev_rsqrt
* beta_prod_t_prev_sqrt
/ beta_prod_t_sqrt
* eta
)
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()

# model forward
with torch.no_grad():
@@ -1,9 +1,11 @@
#!/usr/bin/env python3
# !pip install diffusers
from modeling_ddim import DDIM
import PIL.Image
import numpy as np

import PIL.Image
from modeling_ddim import DDIM


model_id = "fusing/ddpm-cifar10"
model_id = "fusing/ddpm-lsun-bedroom"
@@ -1,11 +1,25 @@
#!/usr/bin/env python3
import os
import pathlib
from modeling_ddpm import DDPM
import PIL.Image

import numpy as np

model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-cifar10", "ddpm-celeba-hq", "ddpm-celeba-hq-ema"]
import PIL.Image
from modeling_ddpm import DDPM


model_ids = [
"ddpm-lsun-cat",
"ddpm-lsun-cat-ema",
"ddpm-lsun-church-ema",
"ddpm-lsun-church",
"ddpm-lsun-bedroom",
"ddpm-lsun-bedroom-ema",
"ddpm-cifar10-ema",
"ddpm-cifar10",
"ddpm-celeba-hq",
"ddpm-celeba-hq-ema",
]

for model_id in model_ids:
path = os.path.join("/home/patrick/images/hf", model_id)
@@ -14,13 +14,13 @@
# limitations under the License.


from diffusers import DiffusionPipeline
import tqdm
import torch

import tqdm
from diffusers import DiffusionPipeline


class DDPM(DiffusionPipeline):

def __init__(self, unet, noise_scheduler):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline):

self.unet.to(torch_device)
# 1. Sample gaussian noise
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t))
image_coeff = (
(1 - self.noise_scheduler.get_alpha_prod(t - 1))
* torch.sqrt(self.noise_scheduler.get_alpha(t))
/ (1 - self.noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1))
* self.noise_scheduler.get_beta(t)
/ (1 - self.noise_scheduler.get_alpha_prod(t))
)

# ii) predict noise residual
with torch.no_grad():
@@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline):
prev_image = clipped_coeff * pred_mean + image_coeff * image

# iv) sample variance
prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
prev_variance = self.noise_scheduler.sample_variance(
t, prev_image.shape, device=torch_device, generator=generator
)

# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
@@ -1,7 +1,13 @@
import torch
from torch import nn

from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from diffusers import (
ClassifierFreeGuidanceScheduler,
CLIPTextModel,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer
@@ -22,7 +28,9 @@ config = CLIPTextConfig(
use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>")
tokenizer = GPT2Tokenizer(
"./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)

hf_encoder = model.text_model
@@ -97,10 +105,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False)

upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")

glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler)
glide = GLIDE(
text_unet=text2im_model,
text_noise_scheduler=text_scheduler,
text_encoder=model,
tokenizer=tokenizer,
upscale_unet=superres_model,
upscale_noise_scheduler=upscale_scheduler,
)

glide.save_pretrained("./glide-base")
@@ -18,7 +18,14 @@ import numpy as np
import torch

import tqdm
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from diffusers import (
ClassifierFreeGuidanceScheduler,
CLIPTextModel,
DiffusionPipeline,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from transformers import GPT2Tokenizer
@@ -46,12 +53,16 @@ class GLIDE(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel,
upscale_noise_scheduler: GlideDDIMScheduler
upscale_noise_scheduler: GlideDDIMScheduler,
):
super().__init__()
self.register_modules(
text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
text_unet=text_unet,
text_noise_scheduler=text_noise_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
upscale_unet=upscale_unet,
upscale_noise_scheduler=upscale_noise_scheduler,
)

def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
@@ -67,9 +78,7 @@ class GLIDE(DiffusionPipeline):
+ _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
scheduler.posterior_log_variance_clipped, t, x_t.shape
)
posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
@@ -190,19 +199,30 @@ class GLIDE(DiffusionPipeline):
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997

image = self.upscale_noise_scheduler.sample_noise(
(batch_size, 3, 256, 256), device=torch_device, generator=generator
) * upsample_temp
image = (
self.upscale_noise_scheduler.sample_noise(
(batch_size, 3, 256, 256), device=torch_device, generator=generator
)
* upsample_temp
)

num_timesteps = len(self.upscale_noise_scheduler)
for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
for t in tqdm.tqdm(
reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(
self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta(
t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
image_coeff = (
(1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
* torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
* self.upscale_noise_scheduler.get_beta(t)
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)

# ii) predict noise residual
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
@@ -216,8 +236,9 @@ class GLIDE(DiffusionPipeline):
prev_image = clipped_coeff * pred_mean + image_coeff * image

# iv) sample variance
prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device,
generator=generator)
prev_variance = self.upscale_noise_scheduler.sample_variance(
t, prev_image.shape, device=torch_device, generator=generator
)

# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
@@ -1,6 +1,8 @@
import torch
from diffusers import DiffusionPipeline

import PIL.Image
from diffusers import DiffusionPipeline


generator = torch.Generator()
generator = generator.manual_seed(0)
@@ -14,8 +16,8 @@ pipeline = DiffusionPipeline.from_pretrained(model_id)
img = pipeline("a clip art of a hugging face", generator)

# process image to PIL
img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
image_pil = PIL.Image.fromarray(img)

# save image
image_pil.save("test.png")
image_pil.save("test.png")
setup.py
@@ -84,6 +84,7 @@ _deps = [
"isort>=5.5.4",
"numpy",
"pytest",
"regex!=2019.12.17",
"requests",
"torch>=1.4",
"torchvision",
@@ -168,6 +169,7 @@ install_requires = [
deps["filelock"],
deps["huggingface-hub"],
deps["numpy"],
deps["regex"],
deps["requests"],
deps["torch"],
deps["torchvision"],
@@ -7,7 +7,7 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin
from .models.clip_text_transformer import CLIPTextModel
from .models.unet import UNetModel
from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline
@@ -23,13 +23,13 @@ import os
import re
from typing import Any, Dict, Tuple, Union

from requests import HTTPError
from huggingface_hub import hf_hub_download
from requests import HTTPError


from . import __version__
from .utils import (
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
@@ -37,9 +37,6 @@ from .utils import (
)


from . import __version__


logger = logging.get_logger(__name__)

_re_configuration_file = re.compile(r"config\.(.*)\.json")
@@ -95,9 +92,7 @@ class ConfigMixin:

@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
config_dict = cls.get_config_dict(
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
)
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)

init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
@@ -157,16 +152,16 @@ class ConfigMixin:

except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
" on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
" pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
" this model name. Check the model page at"
f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
@@ -174,14 +169,16 @@ class ConfigMixin:
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
" run the library in offline mode at"
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
@@ -195,9 +192,7 @@ class ConfigMixin:
# Load config dict
config_dict = cls._dict_from_json_file(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(
f"It looks like the config file at '{config_file}' is not a valid JSON file."
)
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

return config_dict
@@ -3,29 +3,15 @@
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.9.0",
"black": "black~=22.0,>=22.3",
"codecarbon": "codecarbon==1.2.0",
"dataclasses": "dataclasses",
"datasets": "datasets",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"huggingface-hub": "huggingface-hub",
"isort": "isort>=5.5.4",
"numpy": "numpy>=1.17",
"numpy": "numpy",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.7.0",
"regex": "regex!=2019.12.17",
"requests": "requests",
"sagemaker": "sagemaker>=2.31.0",
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
"torch": "torch>=1.4",
"torchaudio": "torchaudio",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
"uvicorn": "uvicorn",
"torchvision": "torchvision",
}
@@ -23,7 +23,8 @@ from pathlib import Path
from typing import Dict, Optional, Union

from huggingface_hub import cached_download
from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging

from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, device

from requests import HTTPError
from huggingface_hub import hf_hub_download
from requests import HTTPError

from .utils import (
CONFIG_NAME,
@@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module):
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.")
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
@@ -18,6 +18,6 @@

from .clip_text_transformer import CLIPTextModel
from .unet import UNetModel
from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .unet_ldm import UNetLDMModel
from .vqvae import VQModel
from .vqvae import VQModel
@@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils import data

from torchvision import transforms, utils
from PIL import Image
from torchvision import transforms, utils
from tqdm import tqdm

from ..configuration_utils import ConfigMixin
@@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin):

# dataset classes


class Dataset(data.Dataset):
def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']):
def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]

self.transform = transforms.Compose([
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor()
])
self.transform = transforms.Compose(
[
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
]
)

def __len__(self):
return len(self.paths)
@@ -359,7 +362,7 @@ class Dataset(data.Dataset):

# trainer class
class EMA():
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta
@@ -647,24 +647,24 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
"""

def __init__(
self,
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
transformer_dim=512
self,
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
transformer_dim=512,
):
super().__init__(
in_channels=in_channels,
@@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim
transformer_dim=transformer_dim,
)
self.register(
in_channels=in_channels,
@@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim
transformer_dim=transformer_dim,
)

self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
@@ -737,23 +737,23 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
"""

def __init__(
self,
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
self,
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
):
super().__init__(
in_channels=in_channels,
@@ -809,4 +809,4 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)

return self.out(h)
return self.out(h)
@@ -1,14 +1,15 @@
from inspect import isfunction
from abc import abstractmethod
import math
from abc import abstractmethod
from inspect import isfunction

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


try:
from einops import repeat, rearrange
from einops import rearrange, repeat
except:
print("Einops is not installed")
pass
@@ -16,12 +17,13 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


def exists(val):
return val is not None


def uniq(arr):
return{el: True for el in arr}.keys()
return {el: True for el in arr}.keys()


def default(val, d):
@@ -53,20 +55,13 @@ class GEGLU(nn.Module):

class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)

self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

def forward(self, x):
return self.net(x)
@@ -90,17 +85,17 @@ class LinearAttention(nn.Module):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)

def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
return self.to_out(out)
@@ -110,26 +105,10 @@ class SpatialSelfAttention(nn.Module):
self.in_channels = in_channels

self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

def forward(self, x):
h_ = x
@@ -139,41 +118,38 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_)

# compute attention
b,c,h,w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)

w_ = w_ * (int(c)**(-0.5))
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)

# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)

return x+h_
return x + h_


class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
self.heads = heads

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

def forward(self, x, context=None, mask=None):
h = self.heads
@@ -183,31 +159,34 @@ class CrossAttention(nn.Module):
k = self.to_k(context)
v = self.to_v(context)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)

out = torch.einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)


class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
)  # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
self.attn2 = CrossAttention(
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
)  # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
@@ -228,29 +207,23 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):

def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)

self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
[
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)
]
)

self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
@@ -258,13 +231,14 @@ class SpatialTransformer(nn.Module):
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
x = rearrange(x, "b c h w -> b (h w) c")
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x + x_in


def convert_module_to_f16(l):
"""
Convert primitive modules to float16.
@@ -386,7 +360,7 @@ class AttentionPool2d(nn.Module):
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
@@ -453,9 +427,7 @@ class Upsample(nn.Module):
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
@@ -472,7 +444,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""

def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -480,9 +452,7 @@ class Downsample(nn.Module):
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
)
self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@@ -558,17 +528,13 @@ class ResBlock(TimestepBlock):
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)

if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
@@ -686,7 +652,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += torch.DoubleTensor([matmul_ops])
@@ -710,9 +676,7 @@ class QKVAttentionLegacy(nn.Module):
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
"bct,bcs->bts", q * scale, k * scale
)  # More stable with f16 than dividing afterwards
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@@ -773,14 +737,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
use_spatial_transformer=False,  # custom transformer support
transformer_depth=1,  # custom transformer support
context_dim=None,  # custom transformer support
n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
):
super().__init__()

# register all __init__ params with self.register
self.register(
image_size=image_size,
@@ -810,19 +774,23 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
)

if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
assert (
context_dim is not None
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."

if num_heads_upsample == -1:
num_heads_upsample = num_heads

if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"

if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
assert num_heads != -1, "Either num_heads or num_head_channels has to be set"

self.image_size = image_size
self.in_channels = in_channels
@@ -852,11 +820,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)

self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
@@ -883,7 +847,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
@@ -892,7 +856,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
)
if not use_spatial_transformer
else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)
@@ -914,9 +880,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
)
ch = out_ch
@@ -930,7 +894,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
@@ -947,9 +911,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
),
)
if not use_spatial_transformer
else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
ResBlock(
ch,
time_embed_dim,
@@ -984,7 +948,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
@@ -993,7 +957,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
)
if not use_spatial_transformer
else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)
@@ -1024,10 +990,10 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)

def convert_to_fp16(self):
"""
@@ -1045,7 +1011,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)

def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@@ -1108,7 +1074,7 @@ class EncoderUNetModel(nn.Module):
use_new_attention_order=False,
pool="adaptive",
*args,
**kwargs
**kwargs,
):
super().__init__()
@@ -1137,11 +1103,7 @@ class EncoderUNetModel(nn.Module):
)

self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
@@ -1189,9 +1151,7 @@ class EncoderUNetModel(nn.Module):
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
)
ch = out_ch
@@ -1239,9 +1199,7 @@ class EncoderUNetModel(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
AttentionPool2d(
(image_size // ds), ch, num_head_channels, out_channels
),
AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
)
elif pool == "spatial":
self.out = nn.Sequential(
@@ -1296,4 +1254,3 @@ class EncoderUNetModel(nn.Module):
else:
h = h.type(x.dtype)
return self.out(h)
@@ -20,10 +20,9 @@ from typing import Optional, Union

from huggingface_hub import snapshot_download

from .utils import logging, DIFFUSERS_CACHE

from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .utils import DIFFUSERS_CACHE, logging


INDEX_FILE = "diffusion_model.pt"
@@ -106,7 +105,7 @@ class DiffusionPipeline(ConfigMixin):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Add docstrings
Add docstrings
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math

import torch
from torch import nn

from ..configuration_utils import ConfigMixin
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule


SAMPLING_CONFIG_NAME = "scheduler_config.json"
@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import torch
from torch import nn

from ..configuration_utils import ConfigMixin
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule


SAMPLING_CONFIG_NAME = "scheduler_config.json"
@@ -26,12 +26,7 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):

config_name = SAMPLING_CONFIG_NAME

def __init__(
self,
timesteps=1000,
beta_schedule="linear",
variance_type="fixed_large"
):
def __init__(self, timesteps=1000, beta_schedule="linear", variance_type="fixed_large"):
super().__init__()
self.register(
timesteps=timesteps,
@@ -93,4 +88,4 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
return torch.randn(shape, generator=generator).to(device)

def __len__(self):
return self.num_timesteps
return self.num_timesteps
@@ -5,6 +5,8 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

import os

# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from requests.exceptions import HTTPError
import os


hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
@@ -14,19 +14,19 @@
# limitations under the License.


import os
import random
import tempfile
import unittest
import os
from distutils.util import strtobool

import torch

from diffusers import GaussianDDPMScheduler, UNetModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from models.vision.ddpm.modeling_ddpm import DDPM
from diffusers.pipeline_utils import DiffusionPipeline
from models.vision.ddim.modeling_ddim import DDIM
from models.vision.ddpm.modeling_ddpm import DDPM


global_rng = random.Random()
@@ -85,7 +85,6 @@ class ConfigTester(unittest.TestCase):
ConfigMixin.from_config("dummy_path")

def test_save_load(self):

class SampleObject(ConfigMixin):
config_name = "config.json"
@@ -153,7 +152,6 @@ class ModelTesterMixin(unittest.TestCase):

class SamplerTesterMixin(unittest.TestCase):

@slow
def test_sample(self):
generator = torch.manual_seed(0)
@@ -163,15 +161,23 @@ class SamplerTesterMixin(unittest.TestCase):
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

# 2. Sample gaussian noise
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
image = scheduler.sample_noise(
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
)

# 3. Denoise
for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
image_coeff = (
(1 - scheduler.get_alpha_prod(t - 1))
* torch.sqrt(scheduler.get_alpha(t))
/ (1 - scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
)

# ii) predict noise residual
with torch.no_grad():
@@ -201,7 +207,9 @@ class SamplerTesterMixin(unittest.TestCase):

assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu()
expected_slice = torch.tensor([-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280])
expected_slice = torch.tensor(
[-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

def test_sample_fast(self):
@@ -212,15 +220,23 @@ class SamplerTesterMixin(unittest.TestCase):
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

# 2. Sample gaussian noise
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
image = scheduler.sample_noise(
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
)

# 3. Denoise
for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
image_coeff = (
(1 - scheduler.get_alpha_prod(t - 1))
* torch.sqrt(scheduler.get_alpha(t))
/ (1 - scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
)

# ii) predict noise residual
with torch.no_grad():
@@ -246,7 +262,6 @@ class SamplerTesterMixin(unittest.TestCase):

class PipelineTesterMixin(unittest.TestCase):

def test_from_pretrained_save_pretrained(self):
# 1. Load models
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
@@ -309,5 +324,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -1, -3:, -3:].cpu()

assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor([-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068])
expected_slice = torch.tensor(
[-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2