Merge branch 'main' of https://github.com/huggingface/diffusers into main
commit 74d2da9950
@@ -1,7 +1,7 @@
 import torch
 from torch import nn

-from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel
+from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 from modeling_glide import GLIDE
 from transformers import CLIPTextConfig, GPT2Tokenizer

@@ -51,9 +51,9 @@ for layer_idx in range(config.num_hidden_layers):
     hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
     hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]

-### Convert the UNet
+### Convert the Text-to-Image UNet

-unet_model = UNetGLIDEModel(
+text2im_model = GLIDETextToImageUNetModel(
     in_channels=3,
     model_channels=192,
     out_channels=6,
@@ -69,10 +69,38 @@ unet_model = UNetGLIDEModel(
     transformer_dim=512,
 )

-unet_model.load_state_dict(state_dict, strict=False)
+text2im_model.load_state_dict(state_dict, strict=False)

-scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
+text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")

-glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
+### Convert the Super-Resolution UNet
+
+# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
+ups_state_dict = torch.load("upsample.pt", map_location="cpu")
+
+superres_model = GLIDESuperResUNetModel(
+    in_channels=6,
+    model_channels=192,
+    out_channels=6,
+    num_res_blocks=2,
+    attention_resolutions=(8, 16, 32),
+    dropout=0.1,
+    channel_mult=(1, 1, 2, 2, 4, 4),
+    num_heads=1,
+    num_head_channels=64,
+    num_heads_upsample=1,
+    use_scale_shift_norm=True,
+    resblock_updown=True,
+)
+
+superres_model.load_state_dict(ups_state_dict, strict=False)
+
+upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
+
+glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
+              upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler)

 glide.save_pretrained("./glide-base")
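A quick sanity check of the conversion could look like the sketch below; the "./glide-base" folder and the attribute names come from the script above, but the reload itself is an assumption and not part of this commit (it presumes GLIDE inherits DiffusionPipeline.from_pretrained and that local directories resolve):

# Hedged sketch, not part of the commit: reload the saved pipeline and confirm
# that both UNets and both schedulers were registered and restored.
from modeling_glide import GLIDE

glide = GLIDE.from_pretrained("./glide-base")
assert glide.text_unet is not None and glide.upscale_unet is not None
assert glide.text_noise_scheduler is not None and glide.upscale_noise_scheduler is not None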
@@ -18,7 +18,7 @@ import numpy as np
 import torch

 import tqdm
-from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel
+from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 from transformers import GPT2Tokenizer

@@ -41,17 +41,20 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
 class GLIDE(DiffusionPipeline):
     def __init__(
         self,
-        unet: UNetGLIDEModel,
-        noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_unet: GLIDETextToImageUNetModel,
+        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
+        upscale_unet: GLIDESuperResUNetModel,
+        upscale_noise_scheduler: GlideDDIMScheduler
     ):
         super().__init__()
         self.register_modules(
-            unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
+            text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
+            upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
         )

-    def q_posterior_mean_variance(self, x_start, x_t, t):
+    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
         """
         Compute the mean and variance of the diffusion posterior:

@@ -60,12 +63,12 @@ class GLIDE(DiffusionPipeline):
         """
         assert x_start.shape == x_t.shape
         posterior_mean = (
-            _extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
+            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
         )
-        posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
+        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
         posterior_log_variance_clipped = _extract_into_tensor(
-            self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
+            scheduler.posterior_log_variance_clipped, t, x_t.shape
         )
         assert (
             posterior_mean.shape[0]
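For reference, the coefficients read off the scheduler here are the standard DDPM posterior quantities for q(x_{t-1} | x_t, x_0) (a restatement, not part of the diff):

q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(\tilde\mu_t(x_t, x_0),\ \tilde\beta_t I\big),\qquad
\tilde\mu_t(x_t, x_0) = \underbrace{\frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}}_{\texttt{posterior\_mean\_coef1}} x_0
  + \underbrace{\frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}}_{\texttt{posterior\_mean\_coef2}} x_t,\qquad
\tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t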
@@ -75,7 +78,7 @@ class GLIDE(DiffusionPipeline):
         )
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
+    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
         """
         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
         the initial x, x_0.
@@ -93,51 +96,60 @@ class GLIDE(DiffusionPipeline):
                  - 'log_variance': the log of 'variance'.
                  - 'pred_xstart': the prediction for x_0.
         """
-        if model_kwargs is None:
-            model_kwargs = {}
-
         B, C = x.shape[:2]
         assert t.shape == (B,)
-        model_output = model(x, t, transformer_out)
+        if transformer_out is None:
+            # super-res model
+            model_output = model(x, t, low_res)
+        else:
+            # text2image model
+            model_output = model(x, t, transformer_out)
+
         assert model_output.shape == (B, C * 2, *x.shape[2:])
         model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
+        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
+        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
         # The model_var_values is [-1, 1] for [min_var, max_var].
         frac = (model_var_values + 1) / 2
         model_log_variance = frac * max_log + (1 - frac) * min_log
         model_variance = torch.exp(model_log_variance)

-        pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
         if clip_denoised:
             pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)

         assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
         return model_mean, model_variance, model_log_variance, pred_xstart

-    def _predict_xstart_from_eps(self, x_t, t, eps):
+    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
         assert x_t.shape == eps.shape
         return (
-            _extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
         )

+    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
+        return (
+            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
     @torch.no_grad()
     def __call__(self, prompt, generator=None, torch_device=None):
         torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-        self.unet.to(torch_device)
+        self.text_unet.to(torch_device)
         self.text_encoder.to(torch_device)
+        self.upscale_unet.to(torch_device)

         # Create a classifier-free guidance sampling function
         guidance_scale = 3.0

-        def model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, ts, transformer_out, **kwargs):
             half = x_t[: len(x_t) // 2]
             combined = torch.cat([half, half], dim=0)
-            model_out = self.unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
             eps, rest = model_out[:, :3], model_out[:, 3:]
             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
             half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
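Two relations implicit in this hunk, spelled out for reference (a restatement, not part of the diff): _predict_xstart_from_eps inverts the forward-noising relation, and text_model_fn applies classifier-free guidance with scale s = guidance_scale = 3.0:

\hat{x}_0 = \sqrt{\tfrac{1}{\bar\alpha_t}}\,x_t - \sqrt{\tfrac{1}{\bar\alpha_t} - 1}\;\epsilon_\theta(x_t),\qquad
\hat\epsilon = \epsilon_\theta(x_t, \varnothing) + s\,\big(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)\big)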
@@ -146,8 +158,8 @@ class GLIDE(DiffusionPipeline):

         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = self.noise_scheduler.sample_noise(
-            (batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
+        image = self.text_noise_scheduler.sample_noise(
+            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
         )

         # 2. Encode tokens
@@ -157,14 +169,60 @@ class GLIDE(DiffusionPipeline):
         attention_mask = inputs["attention_mask"].to(torch_device)
         transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

-        num_timesteps = len(self.noise_scheduler)
+        # 3. Run the text2image generation step
+        num_timesteps = len(self.text_noise_scheduler)
         for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
             t = torch.tensor([i] * image.shape[0], device=torch_device)
-            mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out)
-            noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
+            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
+                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
+            )
+            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
             nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
             image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

+        # 4. Run the upscaling step
+        batch_size = 1
+        image = image[:1]
+        low_res = ((image + 1) * 127.5).round() / 127.5 - 1
+        eta = 0.0
+
+        # Tune this parameter to control the sharpness of 256x256 images.
+        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
+        upsample_temp = 0.997
+
+        image = self.upscale_noise_scheduler.sample_noise(
+            (batch_size, 3, 256, 256), device=torch_device, generator=generator
+        ) * upsample_temp
+
+        num_timesteps = len(self.upscale_noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
+            # i) define coefficients for time step t
+            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
+            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
+            image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(
+                self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta(
+                t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+
+            # ii) predict noise residual
+            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+            model_output = self.upscale_unet(image, time_input, low_res)
+            noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
+
+            # iii) compute predicted image from residual
+            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
+            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            prev_image = clipped_coeff * pred_mean + image_coeff * image
+
+            # iv) sample variance
+            prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device,
+                                                                         generator=generator)
+
+            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
+            sampled_prev_image = prev_image + prev_variance
+            image = sampled_prev_image
+
         image = image[0].permute(1, 2, 0)

         return image
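Read off the upscaling loop above, the four per-step coefficients are the usual ancestral-sampling quantities (a restatement of the code, assuming the scheduler's get_alpha_prod, get_alpha and get_beta return \bar\alpha_t, \alpha_t and \beta_t):

\hat{x}_0 = \operatorname{clip}\!\Big(\underbrace{\tfrac{1}{\sqrt{\bar\alpha_t}}}_{\texttt{clipped\_image\_coeff}} x_t - \underbrace{\sqrt{\tfrac{1}{\bar\alpha_t}-1}}_{\texttt{clipped\_noise\_coeff}}\,\hat\epsilon,\ -1,\ 1\Big),\qquad
\mu_{t-1} = \underbrace{\tfrac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}}_{\texttt{clipped\_coeff}}\,\hat{x}_0 + \underbrace{\tfrac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}}_{\texttt{image\_coeff}}\,x_t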
@@ -1,20 +1,21 @@
 import torch

-from modeling_glide import GLIDE
-import matplotlib
-import matplotlib.pyplot as plt
-matplotlib.rcParams['interactive'] = True
+from diffusers import DiffusionPipeline
+import PIL.Image

 generator = torch.Generator()
 generator = generator.manual_seed(0)

-# 1. Load models
-pipeline = GLIDE.from_pretrained("fusing/glide-base")
+model_id = "fusing/glide-base"

-img = pipeline("an oil painting of a corgi", generator)
+# load model and scheduler
+pipeline = DiffusionPipeline.from_pretrained(model_id)
+
+# run inference (text-conditioned denoising + upscaling)
+img = pipeline("a clip art of a hugging face", generator)
+
+# process image to PIL
+img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+image_pil = PIL.Image.fromarray(img)

-plt.figure(figsize=(8, 8))
-plt.imshow(img)
-plt.show()
+# save image
+image_pil.save("test.png")
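One detail of the post-processing worth spelling out (a restatement, not new behavior): the pipeline returns pixel values in [-1, 1], which the line above maps to 8-bit intensities before the PIL conversion:

p = \operatorname{round}\big(127.5\,(x + 1)\big) \in \{0, \dots, 255\}\quad\text{for } x \in [-1, 1],\ \text{so } -1 \mapsto 0 \text{ and } +1 \mapsto 255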
@@ -7,9 +7,10 @@ __version__ = "0.0.1"
 from .modeling_utils import ModelMixin
 from .models.clip_text_transformer import CLIPTextModel
 from .models.unet import UNetModel
-from .models.unet_glide import UNetGLIDEModel
+from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .models.vqvae import VQModel
 from .pipeline_utils import DiffusionPipeline
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
+from .schedulers.glide_ddim import GlideDDIMScheduler
@@ -18,6 +18,6 @@
 from .clip_text_transformer import CLIPTextModel
 from .unet import UNetModel
-from .unet_glide import UNetGLIDEModel
+from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 from .unet_ldm import UNetLDMModel
 from .vqvae import VQModel
@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
         return a.reshape(bs, -1, length)


-class UNetGLIDEModel(ModelMixin, ConfigMixin):
+class GLIDEUNetModel(ModelMixin, ConfigMixin):
     """
     The full UNet model with attention and timestep embedding.

@@ -419,11 +419,11 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):

     def __init__(
         self,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
         dropout=0,
         channel_mult=(1, 2, 4, 8),
         conv_resample=True,
@@ -435,28 +435,9 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         num_heads_upsample=-1,
         use_scale_shift_norm=False,
         resblock_updown=False,
-        transformer_dim=512,
+        transformer_dim=None,
     ):
         super().__init__()
-        self.register(
-            in_channels=in_channels,
-            model_channels=model_channels,
-            out_channels=out_channels,
-            num_res_blocks=num_res_blocks,
-            attention_resolutions=attention_resolutions,
-            dropout=dropout,
-            channel_mult=channel_mult,
-            conv_resample=conv_resample,
-            dims=dims,
-            use_checkpoint=use_checkpoint,
-            use_fp16=use_fp16,
-            num_heads=num_heads,
-            num_head_channels=num_head_channels,
-            num_heads_upsample=num_heads_upsample,
-            use_scale_shift_norm=use_scale_shift_norm,
-            resblock_updown=resblock_updown,
-            transformer_dim=transformer_dim,
-        )

         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
@@ -482,8 +463,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
             linear(time_embed_dim, time_embed_dim),
         )

-        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
-
         ch = input_ch = int(channel_mult[0] * model_channels)
         self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
         self._feature_size = ch
@@ -635,7 +614,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)

-    def forward(self, x, timesteps, transformer_out):
+    def forward(self, x, timesteps):
         """
         Apply the model to an input batch.

@@ -644,6 +623,91 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         :param y: an [N] Tensor of labels, if class-conditional.
         :return: an [N x C x ...] Tensor of outputs.
         """

+        hs = []
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb)
+            hs.append(h)
+        h = self.middle_block(h, emb)
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = module(h, emb)
+        h = h.type(x.dtype)
+        return self.out(h)
+
+
+class GLIDETextToImageUNetModel(GLIDEUNetModel):
+    """
+    A UNetModel that performs super-resolution.
+
+    Expects an extra kwarg `low_res` to condition on a low-resolution image.
+    """
+
+    def __init__(
+        self,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        transformer_dim=512
+    ):
+        super().__init__(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            transformer_dim=transformer_dim
+        )
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            transformer_dim=transformer_dim
+        )
+
+        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
+
+    def forward(self, x, timesteps, transformer_out=None):
         hs = []
         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

@@ -663,3 +727,86 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
                 h = torch.cat([h, other], dim=1)
             h = module(h, emb, transformer_out)
         return self.out(h)
+
+
+class GLIDESuperResUNetModel(GLIDEUNetModel):
+    """
+    A UNetModel that performs super-resolution.
+
+    Expects an extra kwarg `low_res` to condition on a low-resolution image.
+    """
+
+    def __init__(
+        self,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+    ):
+        super().__init__(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+        )
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+        )
+
+    def forward(self, x, timesteps, low_res=None):
+        _, _, new_height, new_width = x.shape
+        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+        x = torch.cat([x, upsampled], dim=1)
+
+        hs = []
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+        h = x
+        for module in self.input_blocks:
+            h = module(h, emb)
+            hs.append(h)
+        h = self.middle_block(h, emb)
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = module(h, emb)
+
+        return self.out(h)
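A small, self-contained illustration of the channel bookkeeping in GLIDESuperResUNetModel.forward above (the shapes and tensors here are made up for the example): the low-res conditioning image is bilinearly upsampled to the target resolution and concatenated on the channel axis, which is why the super-res UNet is constructed with in_channels=6.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 256, 256)       # noisy high-res sample
low_res = torch.randn(1, 3, 64, 64)   # low-res conditioning image
upsampled = F.interpolate(low_res, (256, 256), mode="bilinear")
unet_input = torch.cat([x, upsampled], dim=1)
print(unet_input.shape)  # torch.Size([1, 6, 256, 256])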
@@ -39,9 +39,10 @@ LOADABLE_CLASSES = {
         "CLIPTextModel": ["save_pretrained", "from_pretrained"],  # TODO (Anton): move to transformers
         "GaussianDDPMScheduler": ["save_config", "from_config"],
         "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
+        "GlideDDIMScheduler": ["save_config", "from_config"],
     },
     "transformers": {
         "GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
     },
 }

@@ -18,3 +18,4 @@

 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .gaussian_ddpm import GaussianDDPMScheduler
+from .glide_ddim import GlideDDIMScheduler
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch
+import math
 import numpy as np
 from torch import nn

 from ..configuration_utils import ConfigMixin
@@ -22,36 +22,30 @@ from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
 SAMPLING_CONFIG_NAME = "scheduler_config.json"


-class GaussianDDPMScheduler(nn.Module, ConfigMixin):
+class GlideDDIMScheduler(nn.Module, ConfigMixin):

     config_name = SAMPLING_CONFIG_NAME

     def __init__(
         self,
         timesteps=1000,
-        beta_start=0.0001,
-        beta_end=0.02,
         beta_schedule="linear",
-        variance_type="fixed_small",
+        variance_type="fixed_large"
     ):
         super().__init__()
         self.register(
             timesteps=timesteps,
-            beta_start=beta_start,
-            beta_end=beta_end,
             beta_schedule=beta_schedule,
             variance_type=variance_type,
         )
         self.num_timesteps = int(timesteps)

         if beta_schedule == "linear":
+            # Linear schedule from Ho et al, extended to work for any number of
+            # diffusion steps.
+            scale = 1000 / self.num_timesteps
+            beta_start = scale * 0.0001
+            beta_end = scale * 0.02
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
             betas = betas_for_alpha_bar(
                 timesteps,
                 lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
             )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
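For reference, the squaredcos_cap_v2 branch builds betas from the GLIDE cosine schedule via the alpha-bar lambda above; the consecutive-ratio construction and the 0.999 cap are the usual behavior of betas_for_alpha_bar and are stated here as an assumption, since they are not visible in this hunk:

\bar\alpha(s) = \cos^{2}\!\Big(\frac{s + 0.008}{1.008}\cdot\frac{\pi}{2}\Big),\qquad
\beta_i = \min\!\Big(1 - \frac{\bar\alpha\big((i+1)/T\big)}{\bar\alpha\big(i/T\big)},\ 0.999\Big),

while the "linear" branch spaces \beta_i linearly between 0.0001 \cdot (1000/T) and 0.02 \cdot (1000/T).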
@@ -99,4 +93,4 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         return torch.randn(shape, generator=generator).to(device)

     def __len__(self):
-        return self.num_timesteps
+        return self.num_timesteps