fix setup

parent 2234877e01
commit cbb19ee84e

@@ -1,10 +1,13 @@
 #!/usr/bin/env python3
 import os
 import pathlib
-from modeling_ddim import DDIM
-import PIL.Image
 import numpy as np
+
+import PIL.Image
+
+from modeling_ddim import DDIM
 
 
 model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]
 
 for model_id in model_ids:

@@ -14,13 +14,13 @@
 # limitations under the License.
 
 
-from diffusers import DiffusionPipeline
-import tqdm
 import torch
 
+import tqdm
+
+from diffusers import DiffusionPipeline
 
 
 class DDIM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline):
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
 
         self.unet.to(torch_device)
-        image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
+        image = self.noise_scheduler.sample_noise(
+            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            device=torch_device,
+            generator=generator,
+        )
 
         for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
             # get actual t and t-1
             train_step = inference_step_times[t]
-            prev_train_step = inference_step_times[t - 1] if t > 0 else - 1
+            prev_train_step = inference_step_times[t - 1] if t > 0 else -1
 
             # compute alphas
             alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)

@@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline):
             beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
 
             # compute relevant coefficients
-            coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
-            coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt()
+            coeff_1 = (
+                (alpha_prod_t_prev - alpha_prod_t).sqrt()
+                * alpha_prod_t_prev_rsqrt
+                * beta_prod_t_prev_sqrt
+                / beta_prod_t_sqrt
+                * eta
+            )
+            coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
 
             # model forward
             with torch.no_grad():
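Aside (a reading aid, not part of the diff): assuming get_alpha_prod(t) returns the cumulative product \bar\alpha_t, the wrapped coeff_1 above equals the DDIM noise scale
\sigma_t = \eta \sqrt{(1 - \bar\alpha_{t-1}) / (1 - \bar\alpha_t)} \, \sqrt{1 - \bar\alpha_t / \bar\alpha_{t-1}},
and coeff_2 = \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2} is the weight on the predicted noise in the DDIM update rule (Song et al., 2020).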
@@ -1,9 +1,11 @@
 #!/usr/bin/env python3
 # !pip install diffusers
-from modeling_ddim import DDIM
-import PIL.Image
 import numpy as np
+
+import PIL.Image
+
+from modeling_ddim import DDIM
 
 
 model_id = "fusing/ddpm-cifar10"
 model_id = "fusing/ddpm-lsun-bedroom"
@@ -1,11 +1,25 @@
 #!/usr/bin/env python3
 import os
 import pathlib
-from modeling_ddpm import DDPM
-import PIL.Image
 import numpy as np
 
-model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-cifar10", "ddpm-celeba-hq", "ddpm-celeba-hq-ema"]
+import PIL.Image
+
+from modeling_ddpm import DDPM
+
+
+model_ids = [
+    "ddpm-lsun-cat",
+    "ddpm-lsun-cat-ema",
+    "ddpm-lsun-church-ema",
+    "ddpm-lsun-church",
+    "ddpm-lsun-bedroom",
+    "ddpm-lsun-bedroom-ema",
+    "ddpm-cifar10-ema",
+    "ddpm-cifar10",
+    "ddpm-celeba-hq",
+    "ddpm-celeba-hq-ema",
+]
 
 for model_id in model_ids:
     path = os.path.join("/home/patrick/images/hf", model_id)
@@ -14,13 +14,13 @@
 # limitations under the License.
 
 
-from diffusers import DiffusionPipeline
-import tqdm
 import torch
 
+import tqdm
+
+from diffusers import DiffusionPipeline
 
 
 class DDPM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline):
 
         self.unet.to(torch_device)
         # 1. Sample gaussian noise
-        image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
+        image = self.noise_scheduler.sample_noise(
+            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            device=torch_device,
+            generator=generator,
+        )
         for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
             # i) define coefficients for time step t
             clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
             clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
-            image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t))
-            clipped_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t))
+            image_coeff = (
+                (1 - self.noise_scheduler.get_alpha_prod(t - 1))
+                * torch.sqrt(self.noise_scheduler.get_alpha(t))
+                / (1 - self.noise_scheduler.get_alpha_prod(t))
+            )
+            clipped_coeff = (
+                torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1))
+                * self.noise_scheduler.get_beta(t)
+                / (1 - self.noise_scheduler.get_alpha_prod(t))
+            )
 
             # ii) predict noise residual
             with torch.no_grad():
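Aside (a reading aid, not part of the diff): with \bar\alpha_t = get_alpha_prod(t), \alpha_t = get_alpha(t) and \beta_t = get_beta(t), the two wrapped expressions are the DDPM posterior-mean weights
clipped_coeff = \sqrt{\bar\alpha_{t-1}} \, \beta_t / (1 - \bar\alpha_t) and image_coeff = (1 - \bar\alpha_{t-1}) \sqrt{\alpha_t} / (1 - \bar\alpha_t),
so that \tilde\mu_t(x_t, \hat x_0) = clipped_coeff \cdot \hat x_0 + image_coeff \cdot x_t (Ho et al., 2020), which is exactly how they are combined into prev_image in the next hunk.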
@@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline):
             prev_image = clipped_coeff * pred_mean + image_coeff * image
 
             # iv) sample variance
-            prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
+            prev_variance = self.noise_scheduler.sample_variance(
+                t, prev_image.shape, device=torch_device, generator=generator
+            )
 
             # v) sample x_{t-1} ~ N(prev_image, prev_variance)
             sampled_prev_image = prev_image + prev_variance
@@ -1,7 +1,13 @@
 import torch
 from torch import nn
 
-from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from diffusers import (
+    ClassifierFreeGuidanceScheduler,
+    CLIPTextModel,
+    GlideDDIMScheduler,
+    GLIDESuperResUNetModel,
+    GLIDETextToImageUNetModel,
+)
 from modeling_glide import GLIDE
 from transformers import CLIPTextConfig, GPT2Tokenizer
 

@@ -22,7 +28,9 @@ config = CLIPTextConfig(
     use_padding_embeddings=True,
 )
 model = CLIPTextModel(config).eval()
-tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>")
+tokenizer = GPT2Tokenizer(
+    "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
+)
 
 hf_encoder = model.text_model
 
@@ -97,10 +105,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False)
 
 upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
 
-glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
-              upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler)
+glide = GLIDE(
+    text_unet=text2im_model,
+    text_noise_scheduler=text_scheduler,
+    text_encoder=model,
+    tokenizer=tokenizer,
+    upscale_unet=superres_model,
+    upscale_noise_scheduler=upscale_scheduler,
+)
 
 glide.save_pretrained("./glide-base")
 
@@ -18,7 +18,14 @@ import numpy as np
 import torch
 
 import tqdm
-from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from diffusers import (
+    ClassifierFreeGuidanceScheduler,
+    CLIPTextModel,
+    DiffusionPipeline,
+    GlideDDIMScheduler,
+    GLIDESuperResUNetModel,
+    GLIDETextToImageUNetModel,
+)
 from transformers import GPT2Tokenizer
 
 
@@ -46,12 +53,16 @@ class GLIDE(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
-        upscale_noise_scheduler: GlideDDIMScheduler
+        upscale_noise_scheduler: GlideDDIMScheduler,
     ):
         super().__init__()
         self.register_modules(
-            text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
-            upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
+            text_unet=text_unet,
+            text_noise_scheduler=text_noise_scheduler,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            upscale_unet=upscale_unet,
+            upscale_noise_scheduler=upscale_noise_scheduler,
         )
 
     def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
@@ -67,9 +78,7 @@ class GLIDE(DiffusionPipeline):
             + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
         )
         posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(
-            scheduler.posterior_log_variance_clipped, t, x_t.shape
-        )
+        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
         assert (
             posterior_mean.shape[0]
             == posterior_variance.shape[0]
@@ -190,19 +199,30 @@ class GLIDE(DiffusionPipeline):
         # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
         upsample_temp = 0.997
 
-        image = self.upscale_noise_scheduler.sample_noise(
-            (batch_size, 3, 256, 256), device=torch_device, generator=generator
-        ) * upsample_temp
+        image = (
+            self.upscale_noise_scheduler.sample_noise(
+                (batch_size, 3, 256, 256), device=torch_device, generator=generator
+            )
+            * upsample_temp
+        )
 
         num_timesteps = len(self.upscale_noise_scheduler)
-        for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
+        for t in tqdm.tqdm(
+            reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
+        ):
             # i) define coefficients for time step t
             clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
             clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
-            image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(
-                self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
-            clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta(
-                t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            image_coeff = (
+                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
+            clipped_coeff = (
+                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * self.upscale_noise_scheduler.get_beta(t)
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
 
             # ii) predict noise residual
             time_input = torch.tensor([t] * image.shape[0], device=torch_device)
@@ -216,8 +236,9 @@ class GLIDE(DiffusionPipeline):
             prev_image = clipped_coeff * pred_mean + image_coeff * image
 
             # iv) sample variance
-            prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device,
-                                                                         generator=generator)
+            prev_variance = self.upscale_noise_scheduler.sample_variance(
+                t, prev_image.shape, device=torch_device, generator=generator
+            )
 
             # v) sample x_{t-1} ~ N(prev_image, prev_variance)
             sampled_prev_image = prev_image + prev_variance
@@ -1,6 +1,8 @@
 import torch
-from diffusers import DiffusionPipeline
+
 import PIL.Image
+
+from diffusers import DiffusionPipeline
 
 
 generator = torch.Generator()
 generator = generator.manual_seed(0)

@@ -14,8 +16,8 @@ pipeline = DiffusionPipeline.from_pretrained(model_id)
 img = pipeline("a clip art of a hugging face", generator)
 
 # process image to PIL
-img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
 image_pil = PIL.Image.fromarray(img)
 
 # save image
 image_pil.save("test.png")
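Aside (not in the diff): the post-processing line above maps the model output from [-1, 1] to 8-bit pixels, x_{uint8} = clip(round((x + 1) * 127.5), 0, 255), before PIL.Image.fromarray builds the image.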
setup.py
@@ -84,6 +84,7 @@ _deps = [
     "isort>=5.5.4",
     "numpy",
     "pytest",
+    "regex!=2019.12.17",
     "requests",
     "torch>=1.4",
     "torchvision",

@@ -168,6 +169,7 @@ install_requires = [
     deps["filelock"],
     deps["huggingface-hub"],
     deps["numpy"],
+    deps["regex"],
     deps["requests"],
     deps["torch"],
     deps["torchvision"],
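The two setup.py additions go together because of the dependency-table pattern visible in this commit: a single _deps list holds the full version specifiers, a lookup dict keyed by bare package name feeds install_requires, and `make deps_table_update` regenerates the deps = { ... } table further down from it. A minimal, self-contained sketch of that pattern (the parsing regex here is illustrative, not copied from the repo):

import re

# full specifiers, as in setup.py's _deps
_deps = ["numpy", "pytest", "regex!=2019.12.17", "torch>=1.4"]

# bare package name -> full specifier (illustrative way to build the lookup)
deps = {re.match(r"^[^!=<>~ ]+", d).group(0): d for d in _deps}

# runtime requirements pull from the table by name, so adding "regex" needs both edits
install_requires = [deps["numpy"], deps["regex"], deps["torch"]]

print(install_requires)  # ['numpy', 'regex!=2019.12.17', 'torch>=1.4']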
@@ -7,7 +7,7 @@ __version__ = "0.0.1"
 from .modeling_utils import ModelMixin
 from .models.clip_text_transformer import CLIPTextModel
 from .models.unet import UNetModel
-from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .models.vqvae import VQModel
 from .pipeline_utils import DiffusionPipeline
@@ -23,13 +23,13 @@ import os
 import re
 from typing import Any, Dict, Tuple, Union
 
-from requests import HTTPError
 from huggingface_hub import hf_hub_download
+from requests import HTTPError
 
+from . import __version__
 from .utils import (
-    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     DIFFUSERS_CACHE,
+    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     EntryNotFoundError,
     RepositoryNotFoundError,
     RevisionNotFoundError,

@@ -37,9 +37,6 @@ from .utils import (
 )
 
 
-from . import __version__
-
-
 logger = logging.get_logger(__name__)
 
 _re_configuration_file = re.compile(r"config\.(.*)\.json")
@@ -95,9 +92,7 @@ class ConfigMixin:
 
     @classmethod
     def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
-        config_dict = cls.get_config_dict(
-            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
-        )
+        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
 
         init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
 
@@ -157,16 +152,16 @@ class ConfigMixin:
 
         except RepositoryNotFoundError:
             raise EnvironmentError(
-                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
-                "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
-                "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
-                "`use_auth_token=True`."
+                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
+                " on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
+                " having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
+                " pass `use_auth_token=True`."
             )
         except RevisionNotFoundError:
             raise EnvironmentError(
-                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
-                f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
-                "available revisions."
+                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+                " this model name. Check the model page at"
+                f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
             )
         except EntryNotFoundError:
             raise EnvironmentError(

@@ -174,14 +169,16 @@ class ConfigMixin:
             )
         except HTTPError as err:
             raise EnvironmentError(
-                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+                "There was a specific connection error when trying to load"
+                f" {pretrained_model_name_or_path}:\n{err}"
             )
         except ValueError:
             raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
-                f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
-                f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the"
-                " library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+                " run the library in offline mode at"
+                " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
             )
         except EnvironmentError:
             raise EnvironmentError(
@@ -195,9 +192,7 @@ class ConfigMixin:
             # Load config dict
             config_dict = cls._dict_from_json_file(config_file)
         except (json.JSONDecodeError, UnicodeDecodeError):
-            raise EnvironmentError(
-                f"It looks like the config file at '{config_file}' is not a valid JSON file."
-            )
+            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
 
         return config_dict
 
@@ -3,29 +3,15 @@
 # 2. run `make deps_table_update``
 deps = {
     "Pillow": "Pillow",
-    "accelerate": "accelerate>=0.9.0",
     "black": "black~=22.0,>=22.3",
-    "codecarbon": "codecarbon==1.2.0",
-    "dataclasses": "dataclasses",
-    "datasets": "datasets",
-    "GitPython": "GitPython<3.1.19",
-    "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
-    "importlib_metadata": "importlib_metadata",
+    "filelock": "filelock",
+    "flake8": "flake8>=3.8.3",
+    "huggingface-hub": "huggingface-hub",
     "isort": "isort>=5.5.4",
-    "numpy": "numpy>=1.17",
+    "numpy": "numpy",
     "pytest": "pytest",
-    "pytest-timeout": "pytest-timeout",
-    "pytest-xdist": "pytest-xdist",
-    "python": "python>=3.7.0",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
-    "sagemaker": "sagemaker>=2.31.0",
-    "tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
     "torch": "torch>=1.4",
-    "torchaudio": "torchaudio",
-    "tqdm": "tqdm>=4.27",
-    "unidic": "unidic>=1.0.2",
-    "unidic_lite": "unidic_lite>=1.0.7",
-    "uvicorn": "uvicorn",
+    "torchvision": "torchvision",
 }
@@ -23,7 +23,8 @@ from pathlib import Path
 from typing import Dict, Optional, Union
 
 from huggingface_hub import cached_download
-from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging
+
+from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 from torch import Tensor, device
 
-from requests import HTTPError
 from huggingface_hub import hf_hub_download
+from requests import HTTPError
 
 from .utils import (
     CONFIG_NAME,
@@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module):
                 f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
             )
         except EntryNotFoundError:
-            raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.")
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
+            )
         except HTTPError as err:
             raise EnvironmentError(
-                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+                "There was a specific connection error when trying to load"
+                f" {pretrained_model_name_or_path}:\n{err}"
             )
         except ValueError:
             raise EnvironmentError(
@@ -18,6 +18,6 @@
 
 from .clip_text_transformer import CLIPTextModel
 from .unet import UNetModel
-from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .unet_ldm import UNetLDMModel
 from .vqvae import VQModel
 

@@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast
 from torch.optim import Adam
 from torch.utils import data
 
-from torchvision import transforms, utils
 from PIL import Image
+from torchvision import transforms, utils
 from tqdm import tqdm
 
 from ..configuration_utils import ConfigMixin
@@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin):
 
 # dataset classes
 
 
 class Dataset(data.Dataset):
-    def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']):
+    def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
         super().__init__()
         self.folder = folder
         self.image_size = image_size
-        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
+        self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
 
-        self.transform = transforms.Compose([
-            transforms.Resize(image_size),
-            transforms.RandomHorizontalFlip(),
-            transforms.CenterCrop(image_size),
-            transforms.ToTensor()
-        ])
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(image_size),
+                transforms.RandomHorizontalFlip(),
+                transforms.CenterCrop(image_size),
+                transforms.ToTensor(),
+            ]
+        )
 
     def __len__(self):
         return len(self.paths)

@@ -359,7 +362,7 @@ class Dataset(data.Dataset):
 
 
 # trainer class
-class EMA():
+class EMA:
     def __init__(self, beta):
         super().__init__()
         self.beta = beta
@@ -647,24 +647,24 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
     """
 
     def __init__(
         self,
         in_channels=3,
         model_channels=192,
         out_channels=6,
         num_res_blocks=3,
         attention_resolutions=(2, 4, 8),
         dropout=0,
         channel_mult=(1, 2, 4, 8),
         conv_resample=True,
         dims=2,
         use_checkpoint=False,
         use_fp16=False,
         num_heads=1,
         num_head_channels=-1,
         num_heads_upsample=-1,
         use_scale_shift_norm=False,
         resblock_updown=False,
-        transformer_dim=512
+        transformer_dim=512,
     ):
         super().__init__(
             in_channels=in_channels,

@@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
             num_heads_upsample=num_heads_upsample,
             use_scale_shift_norm=use_scale_shift_norm,
             resblock_updown=resblock_updown,
-            transformer_dim=transformer_dim
+            transformer_dim=transformer_dim,
         )
         self.register(
             in_channels=in_channels,

@@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
             num_heads_upsample=num_heads_upsample,
             use_scale_shift_norm=use_scale_shift_norm,
             resblock_updown=resblock_updown,
-            transformer_dim=transformer_dim
+            transformer_dim=transformer_dim,
         )
 
         self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
@@ -737,23 +737,23 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
     """
 
     def __init__(
         self,
         in_channels=3,
         model_channels=192,
         out_channels=6,
         num_res_blocks=3,
         attention_resolutions=(2, 4, 8),
         dropout=0,
         channel_mult=(1, 2, 4, 8),
         conv_resample=True,
         dims=2,
         use_checkpoint=False,
         use_fp16=False,
         num_heads=1,
         num_head_channels=-1,
         num_heads_upsample=-1,
         use_scale_shift_norm=False,
         resblock_updown=False,
     ):
         super().__init__(
             in_channels=in_channels,

@@ -809,4 +809,4 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
             h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
 
         return self.out(h)
@@ -1,14 +1,15 @@
-from inspect import isfunction
-from abc import abstractmethod
 import math
+from abc import abstractmethod
+from inspect import isfunction
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
 try:
-    from einops import repeat, rearrange
+    from einops import rearrange, repeat
 except:
     print("Einops is not installed")
     pass

@@ -16,12 +17,13 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 
 
 def exists(val):
     return val is not None
 
 
 def uniq(arr):
-    return{el: True for el in arr}.keys()
+    return {el: True for el in arr}.keys()
 
 
 def default(val, d):
@@ -53,20 +55,13 @@ class GEGLU(nn.Module):
 
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
-        project_in = nn.Sequential(
-            nn.Linear(dim, inner_dim),
-            nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim)
+        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
 
-        self.net = nn.Sequential(
-            project_in,
-            nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out)
-        )
+        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
     def forward(self, x):
         return self.net(x)
@@ -90,17 +85,17 @@ class LinearAttention(nn.Module):
         super().__init__()
         self.heads = heads
         hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
 
     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
         k = k.softmax(dim=-1)
-        context = torch.einsum('bhdn,bhen->bhde', k, v)
-        out = torch.einsum('bhde,bhdn->bhen', context, q)
-        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+        context = torch.einsum("bhdn,bhen->bhde", k, v)
+        out = torch.einsum("bhde,bhdn->bhen", context, q)
+        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
         return self.to_out(out)
 
 
@@ -110,26 +105,10 @@ class SpatialSelfAttention(nn.Module):
         self.in_channels = in_channels
 
         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
-                                        in_channels,
-                                        kernel_size=1,
-                                        stride=1,
-                                        padding=0)
+        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
 
     def forward(self, x):
         h_ = x
@@ -139,41 +118,38 @@ class SpatialSelfAttention(nn.Module):
         v = self.v(h_)
 
         # compute attention
-        b,c,h,w = q.shape
-        q = rearrange(q, 'b c h w -> b (h w) c')
-        k = rearrange(k, 'b c h w -> b c (h w)')
-        w_ = torch.einsum('bij,bjk->bik', q, k)
+        b, c, h, w = q.shape
+        q = rearrange(q, "b c h w -> b (h w) c")
+        k = rearrange(k, "b c h w -> b c (h w)")
+        w_ = torch.einsum("bij,bjk->bik", q, k)
 
-        w_ = w_ * (int(c)**(-0.5))
+        w_ = w_ * (int(c) ** (-0.5))
         w_ = torch.nn.functional.softmax(w_, dim=2)
 
         # attend to values
-        v = rearrange(v, 'b c h w -> b c (h w)')
-        w_ = rearrange(w_, 'b i j -> b j i')
-        h_ = torch.einsum('bij,bjk->bik', v, w_)
-        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+        v = rearrange(v, "b c h w -> b c (h w)")
+        w_ = rearrange(w_, "b i j -> b j i")
+        h_ = torch.einsum("bij,bjk->bik", v, w_)
+        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
         h_ = self.proj_out(h_)
 
-        return x+h_
+        return x + h_
 
 
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
 
-        self.scale = dim_head ** -0.5
+        self.scale = dim_head**-0.5
         self.heads = heads
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
-            nn.Dropout(dropout)
-        )
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
 
     def forward(self, x, context=None, mask=None):
         h = self.heads
@@ -183,31 +159,34 @@ class CrossAttention(nn.Module):
         k = self.to_k(context)
         v = self.to_v(context)
 
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
 
-        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
 
         if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
+            mask = rearrange(mask, "b ... -> b (...)")
             max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            mask = repeat(mask, "b j -> (b h) () j", h=h)
             sim.masked_fill_(~mask, max_neg_value)
 
         # attention, what we cannot get enough of
         attn = sim.softmax(dim=-1)
 
-        out = torch.einsum('b i j, b j d -> b i d', attn, v)
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+        out = torch.einsum("b i j, b j d -> b i d", attn, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)
 
 
 class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
+    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
         super().__init__()
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
+        self.attn1 = CrossAttention(
+            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
+        self.attn2 = CrossAttention(
+            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is self-attn if context is none
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
         self.norm3 = nn.LayerNorm(dim)
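Aside (a reading aid, not part of the diff): CrossAttention.forward above is standard scaled dot-product attention with self.scale = d_head^{-1/2},
Attention(Q, K, V) = softmax(Q K^T / \sqrt{d_head}) V,
computed per head after the "b n (h d) -> (b h) n d" reshape, with masked positions filled with the most negative finite value before the softmax.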
@@ -228,29 +207,23 @@ class SpatialTransformer(nn.Module):
     Then apply standard transformer action.
     Finally, reshape to image
     """
-    def __init__(self, in_channels, n_heads, d_head,
-                 depth=1, dropout=0., context_dim=None):
+
+    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
         super().__init__()
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = Normalize(in_channels)
 
-        self.proj_in = nn.Conv2d(in_channels,
-                                 inner_dim,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
+        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
 
         self.transformer_blocks = nn.ModuleList(
-            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
-                for d in range(depth)]
+            [
+                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+                for d in range(depth)
+            ]
         )
 
-        self.proj_out = zero_module(nn.Conv2d(inner_dim,
-                                              in_channels,
-                                              kernel_size=1,
-                                              stride=1,
-                                              padding=0))
+        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
 
     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
@@ -258,13 +231,14 @@ class SpatialTransformer(nn.Module):
         x_in = x
         x = self.norm(x)
         x = self.proj_in(x)
-        x = rearrange(x, 'b c h w -> b (h w) c')
+        x = rearrange(x, "b c h w -> b (h w) c")
         for block in self.transformer_blocks:
             x = block(x, context=context)
-        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
         x = self.proj_out(x)
         return x + x_in
 
 
 def convert_module_to_f16(l):
     """
     Convert primitive modules to float16.
@@ -386,7 +360,7 @@ class AttentionPool2d(nn.Module):
         output_dim: int = None,
     ):
         super().__init__()
-        self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+        self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
         self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
         self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
         self.num_heads = embed_dim // num_heads_channels
@@ -453,9 +427,7 @@ class Upsample(nn.Module):
     def forward(self, x):
         assert x.shape[1] == self.channels
         if self.dims == 3:
-            x = F.interpolate(
-                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
-            )
+            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
         else:
             x = F.interpolate(x, scale_factor=2, mode="nearest")
         if self.use_conv:
@@ -472,7 +444,7 @@ class Downsample(nn.Module):
     downsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels

@@ -480,9 +452,7 @@
         self.dims = dims
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
-            self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
-            )
+            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
         else:
             assert self.channels == self.out_channels
             self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@@ -558,17 +528,13 @@ class ResBlock(TimestepBlock):
             normalization(self.out_channels),
             nn.SiLU(),
             nn.Dropout(p=dropout),
-            zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
-            ),
+            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
         )
 
         if self.out_channels == channels:
             self.skip_connection = nn.Identity()
         elif use_conv:
-            self.skip_connection = conv_nd(
-                dims, channels, self.out_channels, 3, padding=1
-            )
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
         else:
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
@@ -686,7 +652,7 @@ def count_flops_attn(model, _x, y):
     # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
     # the combination of the value vectors.
-    matmul_ops = 2 * b * (num_spatial ** 2) * c
+    matmul_ops = 2 * b * (num_spatial**2) * c
     model.total_ops += torch.DoubleTensor([matmul_ops])
 
 
@@ -710,9 +676,7 @@ class QKVAttentionLegacy(nn.Module):
         ch = width // (3 * self.n_heads)
         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum(
-            "bct,bcs->bts", q * scale, k * scale
-        )  # More stable with f16 than dividing afterwards
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         a = torch.einsum("bts,bcs->bct", weight, v)
         return a.reshape(bs, -1, length)
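Aside (a reading aid, not part of the diff): with scale = 1 / \sqrt{\sqrt{ch}}, the product (q \cdot scale)(k \cdot scale)^T already carries the usual attention factor 1/\sqrt{ch} while keeping both operands small, i.e. weight = q k^T / \sqrt{ch}, which is why the inline comment calls this more stable in float16 than dividing after the matmul.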
@@ -773,14 +737,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         use_scale_shift_norm=False,
         resblock_updown=False,
         use_new_attention_order=False,
         use_spatial_transformer=False,  # custom transformer support
         transformer_depth=1,  # custom transformer support
         context_dim=None,  # custom transformer support
         n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
     ):
         super().__init__()
 
         # register all __init__ params with self.register
         self.register(
             image_size=image_size,
@@ -810,19 +774,23 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         )
 
         if use_spatial_transformer:
-            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+            assert (
+                context_dim is not None
+            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
 
         if context_dim is not None:
-            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+            assert (
+                use_spatial_transformer
+            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
 
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
 
         if num_heads == -1:
-            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+            assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"
 
         if num_head_channels == -1:
-            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+            assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
 
         self.image_size = image_size
         self.in_channels = in_channels
@@ -852,11 +820,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)

self.input_blocks = nn.ModuleList(
-[
-TimestepEmbedSequential(
-conv_nd(dims, in_channels, model_channels, 3, padding=1)
-)
-]
+[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]

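The collapsed stem above is a single 3x3 convolution wrapped so it also accepts a timestep embedding. As a hedged sketch (conv_nd is presumably defined elsewhere in this module; the version below mirrors the usual dims-based dispatch and is illustrative, not the exact implementation):

from torch import nn


def conv_nd(dims, *args, **kwargs):
    # choose a 1D/2D/3D convolution depending on the spatial dimensionality
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    if dims == 2:
        return nn.Conv2d(*args, **kwargs)
    if dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


# e.g. the UNet stem: map in_channels to model_channels with a 3x3 convolution
stem = conv_nd(2, 3, 224, 3, padding=1)
print(stem)  # Conv2d(3, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
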
@@ -883,7 +847,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
-#num_heads = 1
+# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(

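A small standalone sketch of the head bookkeeping in this hunk (plain Python; names mirror the diff, and the num_head_channels == -1 branch is the usual counterpart that falls outside this hunk): either a fixed head count or a fixed per-head channel width is configured, the other quantity is derived from the current channel count ch, and the legacy flag keeps the old per-head width for the non-transformer attention blocks.

def resolve_heads(ch, num_heads, num_head_channels, use_spatial_transformer, legacy):
    if num_head_channels == -1:
        dim_head = ch // num_heads  # fixed head count, derive the per-head width
    else:
        num_heads = ch // num_head_channels  # fixed per-head width, derive the head count
        dim_head = num_head_channels
    if legacy:
        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
    return num_heads, dim_head


# e.g. 512 channels with 64 channels per head -> 8 heads of width 64
print(resolve_heads(512, -1, 64, use_spatial_transformer=True, legacy=True))  # (8, 64)
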
@@ -892,7 +856,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
-) if not use_spatial_transformer else SpatialTransformer(
+)
+if not use_spatial_transformer
+else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)

@@ -914,9 +880,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
down=True,
)
if resblock_updown
-else Downsample(
-ch, conv_resample, dims=dims, out_channels=out_ch
-)
+else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
)
ch = out_ch

@@ -930,7 +894,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
-#num_heads = 1
+# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(

@@ -947,9 +911,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
-) if not use_spatial_transformer else SpatialTransformer(
-ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
-),
+)
+if not use_spatial_transformer
+else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
ResBlock(
ch,
time_embed_dim,

@@ -984,7 +948,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
-#num_heads = 1
+# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(

@@ -993,7 +957,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
-) if not use_spatial_transformer else SpatialTransformer(
+)
+if not use_spatial_transformer
+else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)

@@ -1024,10 +990,10 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
-#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)

def convert_to_fp16(self):
"""

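For the n_embed branch above, a hedged sketch of how such a codebook-id head is typically trained (sizes and names below are illustrative, not taken from this model): the 1x1 convolution emits unnormalized logits over the first-stage VQ codebook and is paired with cross_entropy, which applies the log-softmax itself, hence the commented-out nn.LogSoftmax.

import torch
from torch import nn
from torch.nn import functional as F

n_embed, model_channels = 8192, 256  # illustrative sizes
id_predictor = nn.Conv2d(model_channels, n_embed, kernel_size=1)  # per-position logits

features = torch.randn(2, model_channels, 32, 32)
target_ids = torch.randint(0, n_embed, (2, 32, 32))  # VQ indices from the first-stage model

logits = id_predictor(features)  # [B, n_embed, H, W], not normalized
loss = F.cross_entropy(logits, target_ids)
print(loss.item())
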
@@ -1045,7 +1011,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)

-def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.

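A hedged usage sketch for this signature (shapes below are illustrative; context is only consumed when the model was built with the spatial transformer / cross-attention path, and y only when it is class-conditional):

import torch

x = torch.randn(2, 4, 32, 32)  # [N x C x ...] input batch, e.g. latents
timesteps = torch.randint(0, 1000, (2,))  # one diffusion step index per sample
context = torch.randn(2, 77, 768)  # cross-attention conditioning, e.g. text embeddings

# assuming `model` is a UNetLDMModel configured with use_spatial_transformer=True:
# eps = model(x, timesteps=timesteps, context=context)  # output has the same shape as x
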
@@ -1108,7 +1074,7 @@ class EncoderUNetModel(nn.Module):
use_new_attention_order=False,
pool="adaptive",
*args,
-**kwargs
+**kwargs,
):
super().__init__()

@@ -1137,11 +1103,7 @@ class EncoderUNetModel(nn.Module):
)

self.input_blocks = nn.ModuleList(
-[
-TimestepEmbedSequential(
-conv_nd(dims, in_channels, model_channels, 3, padding=1)
-)
-]
+[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]

@@ -1189,9 +1151,7 @@ class EncoderUNetModel(nn.Module):
down=True,
)
if resblock_updown
-else Downsample(
-ch, conv_resample, dims=dims, out_channels=out_ch
-)
+else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
)
ch = out_ch

@@ -1239,9 +1199,7 @@ class EncoderUNetModel(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
-AttentionPool2d(
-(image_size // ds), ch, num_head_channels, out_channels
-),
+AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
)
elif pool == "spatial":
self.out = nn.Sequential(

@@ -1296,4 +1254,3 @@ class EncoderUNetModel(nn.Module):
else:
h = h.type(x.dtype)
return self.out(h)
-

@@ -20,10 +20,9 @@ from typing import Optional, Union

from huggingface_hub import snapshot_download

-from .utils import logging, DIFFUSERS_CACHE
-
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
+from .utils import DIFFUSERS_CACHE, logging


INDEX_FILE = "diffusion_model.pt"

@@ -106,7 +105,7 @@ class DiffusionPipeline(ConfigMixin):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Add docstrings
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)

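A minimal sketch of the kwargs-popping pattern used here (standalone and illustrative, not the actual DiffusionPipeline implementation): optional download settings are peeled off **kwargs with defaults before the rest is forwarded.

def from_pretrained_sketch(pretrained_model_name_or_path, **kwargs):
    cache_dir = kwargs.pop("cache_dir", None)  # None -> fall back to the default cache location
    resume_download = kwargs.pop("resume_download", False)
    # ... download the repo snapshot here, then hand the remaining kwargs to the sub-model loaders
    return {"cache_dir": cache_dir, "resume_download": resume_download, "remaining": kwargs}


print(from_pretrained_sketch("fusing/ddpm-lsun-church", resume_download=True))
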
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
import math
+
+import torch
from torch import nn

from ..configuration_utils import ConfigMixin
-from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
+from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule


SAMPLING_CONFIG_NAME = "scheduler_config.json"

@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
import numpy as np
+import torch
from torch import nn

from ..configuration_utils import ConfigMixin
-from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
+from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule


SAMPLING_CONFIG_NAME = "scheduler_config.json"

@@ -26,12 +26,7 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):

config_name = SAMPLING_CONFIG_NAME

-def __init__(
-self,
-timesteps=1000,
-beta_schedule="linear",
-variance_type="fixed_large"
-):
+def __init__(self, timesteps=1000, beta_schedule="linear", variance_type="fixed_large"):
super().__init__()
self.register(
timesteps=timesteps,

@@ -93,4 +88,4 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
return torch.randn(shape, generator=generator).to(device)

def __len__(self):
return self.num_timesteps

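For reference, a hedged sketch of what sample_noise amounts to in the hunk above: draw standard Gaussian noise, optionally from a seeded generator for reproducibility, and move it to the target device.

import torch


def sample_noise(shape, device=None, generator=None):
    # a seeded torch.Generator makes the draw reproducible across runs
    return torch.randn(shape, generator=generator).to(device)


generator = torch.manual_seed(0)  # returns the (now seeded) default CPU generator
noise = sample_noise((1, 3, 256, 256), device="cpu", generator=generator)
print(noise.shape)  # torch.Size([1, 3, 256, 256])
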
@@ -5,6 +5,8 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

+import os
+
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

@@ -19,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from requests.exceptions import HTTPError
-import os

hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))

@@ -14,19 +14,19 @@
# limitations under the License.


+import os
import random
import tempfile
import unittest
-import os
from distutils.util import strtobool

import torch

from diffusers import GaussianDDPMScheduler, UNetModel
-from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
-from models.vision.ddpm.modeling_ddpm import DDPM
+from diffusers.pipeline_utils import DiffusionPipeline
from models.vision.ddim.modeling_ddim import DDIM
+from models.vision.ddpm.modeling_ddpm import DDPM


global_rng = random.Random()

@@ -85,7 +85,6 @@ class ConfigTester(unittest.TestCase):
ConfigMixin.from_config("dummy_path")

def test_save_load(self):
-
class SampleObject(ConfigMixin):
config_name = "config.json"

@@ -153,7 +152,6 @@ class ModelTesterMixin(unittest.TestCase):


class SamplerTesterMixin(unittest.TestCase):
-
@slow
def test_sample(self):
generator = torch.manual_seed(0)

@@ -163,15 +161,23 @@ class SamplerTesterMixin(unittest.TestCase):
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

# 2. Sample gaussian noise
-image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
+image = scheduler.sample_noise(
+(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
+)

# 3. Denoise
for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
-image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
-clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
+image_coeff = (
+(1 - scheduler.get_alpha_prod(t - 1))
+* torch.sqrt(scheduler.get_alpha(t))
+/ (1 - scheduler.get_alpha_prod(t))
+)
+clipped_coeff = (
+torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
+)

# ii) predict noise residual
with torch.no_grad():

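The four coefficients in this loop are the standard DDPM quantities: clipped_image_coeff and clipped_noise_coeff recover a clipped estimate of x_0 from x_t and the predicted noise (x_0 ≈ x_t / sqrt(alpha_bar_t) - sqrt(1 / alpha_bar_t - 1) * eps), while clipped_coeff and image_coeff are the posterior-mean weights on x_0 and x_t. A hedged sketch with self-consistent dummy schedule values (the test reads these from the scheduler instead):

import torch

alpha_t = torch.tensor(0.999)  # illustrative per-step alpha
beta_t = 1 - alpha_t
alpha_prod_t_prev = torch.tensor(0.6)  # illustrative cumulative product at t-1
alpha_prod_t = alpha_prod_t_prev * alpha_t

# recover x_0 from x_t and the predicted noise eps
clipped_image_coeff = 1 / torch.sqrt(alpha_prod_t)
clipped_noise_coeff = torch.sqrt(1 / alpha_prod_t - 1)

# posterior mean mu(x_t, x_0) = clipped_coeff * x_0 + image_coeff * x_t
image_coeff = (1 - alpha_prod_t_prev) * torch.sqrt(alpha_t) / (1 - alpha_prod_t)
clipped_coeff = torch.sqrt(alpha_prod_t_prev) * beta_t / (1 - alpha_prod_t)
print(image_coeff.item(), clipped_coeff.item())
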
@@ -201,7 +207,9 @@ class SamplerTesterMixin(unittest.TestCase):

assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu()
-expected_slice = torch.tensor([-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280])
+expected_slice = torch.tensor(
+[-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280]
+)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

def test_sample_fast(self):

|
||||||
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
|
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
|
||||||
|
|
||||||
# 2. Sample gaussian noise
|
# 2. Sample gaussian noise
|
||||||
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
|
image = scheduler.sample_noise(
|
||||||
|
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
|
||||||
|
)
|
||||||
|
|
||||||
# 3. Denoise
|
# 3. Denoise
|
||||||
for t in reversed(range(len(scheduler))):
|
for t in reversed(range(len(scheduler))):
|
||||||
# i) define coefficients for time step t
|
# i) define coefficients for time step t
|
||||||
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
|
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
|
||||||
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
|
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
|
||||||
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
|
image_coeff = (
|
||||||
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
|
(1 - scheduler.get_alpha_prod(t - 1))
|
||||||
|
* torch.sqrt(scheduler.get_alpha(t))
|
||||||
|
/ (1 - scheduler.get_alpha_prod(t))
|
||||||
|
)
|
||||||
|
clipped_coeff = (
|
||||||
|
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
|
||||||
|
)
|
||||||
|
|
||||||
# ii) predict noise residual
|
# ii) predict noise residual
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -246,7 +262,6 @@ class SamplerTesterMixin(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class PipelineTesterMixin(unittest.TestCase):
|
class PipelineTesterMixin(unittest.TestCase):
|
||||||
|
|
||||||
def test_from_pretrained_save_pretrained(self):
|
def test_from_pretrained_save_pretrained(self):
|
||||||
# 1. Load models
|
# 1. Load models
|
||||||
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
|
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
|
||||||
|
@ -309,5 +324,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||||
|
|
||||||
assert image.shape == (1, 3, 32, 32)
|
assert image.shape == (1, 3, 32, 32)
|
||||||
expected_slice = torch.tensor([-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068])
|
expected_slice = torch.tensor(
|
||||||
|
[-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068]
|
||||||
|
)
|
||||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||||
|
|