Get diffusers ready 🚀🚀🚀 (#101)

* big purge

* more fixes

* finish for now
Patrick von Platen 2022-07-19 18:02:12 +02:00 committed by GitHub
parent 33344ed916
commit 8c31925b3b
19 changed files with 65 additions and 4533 deletions

View File

@ -1,110 +0,0 @@
#!/usr/bin/env python3
import json
import os
from regex import P
from diffusers import UNetUnconditionalModel
from scripts.convert_ncsnpp_original_checkpoint_to_diffusers import convert_ncsnpp_checkpoint
from huggingface_hub import hf_hub_download
import torch
def convert_checkpoint(model_id, subfolder=None, checkpoint="diffusion_model.pt", config="config.json"):
if subfolder is not None:
checkpoint = os.path.join(subfolder, checkpoint)
config = os.path.join(subfolder, config)
original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint), map_location="cpu")
config_path = hf_hub_download(model_id, config)
with open(config_path) as f:
config = json.load(f)
checkpoint = convert_ncsnpp_checkpoint(original_checkpoint, config)
def current_codebase_conversion(path):
model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, sde=True)
model.eval()
model.config.sde=False
model.save_config(path)
model.config.sde=True
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
output = model(noise, time_step)
return model.state_dict()
path = f"{model_id}_converted"
currently_converted_checkpoint = current_codebase_conversion(path)
def diff_between_checkpoints(ch_0, ch_1):
all_layers_included = False
if not set(ch_0.keys()) == set(ch_1.keys()):
print(f"Contained in ch_0 and not in ch_1 (Total: {len((set(ch_0.keys()) - set(ch_1.keys())))})")
for key in sorted(list((set(ch_0.keys()) - set(ch_1.keys())))):
print(f"\t{key}")
print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})")
for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))):
print(f"\t{key}")
else:
print("Keys are the same between the two checkpoints")
all_layers_included = True
keys = ch_0.keys()
non_equal_keys = []
if all_layers_included:
for key in keys:
try:
if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()):
non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}')
except RuntimeError as e:
print(e)
non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')
if len(non_equal_keys):
non_equal_keys = '\n\t'.join(non_equal_keys)
print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}")
else:
print("All keys are equal across checkpoints.")
diff_between_checkpoints(currently_converted_checkpoint, checkpoint)
os.makedirs(f"{model_id}_converted", exist_ok=True)
torch.save(checkpoint, f"{model_id}_converted/diffusion_model.pt")
model_ids = ["fusing/ffhq_ncsnpp","fusing/church_256-ncsnpp-ve", "fusing/celebahq_256-ncsnpp-ve",
"fusing/bedroom_256-ncsnpp-ve","fusing/ffhq_256-ncsnpp-ve","fusing/ncsnpp-ffhq-ve-dummy"
]
for model in model_ids:
print(f"converting {model}")
try:
convert_checkpoint(model)
except Exception as e:
print(e)
from tests.test_modeling_utils import PipelineTesterMixin, NCSNppModelTests
tester1 = NCSNppModelTests()
tester2 = PipelineTesterMixin()
os.environ["RUN_SLOW"] = '1'
cmd = "export RUN_SLOW=1; echo $RUN_SLOW" # or whatever command
os.system(cmd)
tester2.test_score_sde_ve_pipeline(f"{model_ids[0]}_converted")
tester1.test_output_pretrained_ve_mid(f"{model_ids[2]}_converted")
tester1.test_output_pretrained_ve_large(f"{model_ids[-1]}_converted")
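# For reference, the comparison above reduces to a key-by-key torch.allclose check between two
# state dicts. A minimal, self-contained sketch of that pattern (toy tensors, not the converted
# checkpoints):
def compare_state_dicts(sd_a, sd_b, atol=1e-5):
    only_a = sorted(set(sd_a) - set(sd_b))  # keys present only in the first checkpoint
    only_b = sorted(set(sd_b) - set(sd_a))  # keys present only in the second checkpoint
    mismatched = []
    for key in sorted(set(sd_a) & set(sd_b)):
        a, b = sd_a[key].cpu(), sd_b[key].cpu()
        if a.shape != b.shape:
            mismatched.append((key, f"shape {tuple(a.shape)} vs {tuple(b.shape)}"))
        elif not torch.allclose(a, b, atol=atol):
            mismatched.append((key, f"max abs diff {(a - b).abs().max().item():.3e}"))
    return only_a, only_b, mismatched

# toy usage with dummy tensors (torch is already imported above)
print(compare_state_dicts(
    {"weight": torch.ones(2, 2), "bias": torch.zeros(2)},
    {"weight": torch.ones(2, 2) * 1.1, "bias": torch.zeros(2)},
))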

run.py
View File

@ -1,153 +0,0 @@
#!/usr/bin/env python3
import numpy as np
import PIL
import torch
#from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
class NewReverseDiffusionPredictor:
def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
super().__init__()
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.probability_flow = probability_flow
self.score_fn = score_fn
def discretize(self, x, t):
timestep = (t * (self.N - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
self.discrete_sigmas[timestep - 1].to(t.device))
f = torch.zeros_like(x)
G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
result = self.score_fn(x, labels)
rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
rev_G = torch.zeros_like(G) if self.probability_flow else G
return rev_f, rev_G
def update_fn(self, x, t):
f, G = self.discretize(x, t)
z = torch.randn_like(x)
x_mean = x - f
x = x_mean + G[:, None, None, None] * z
return x, x_mean
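# For context: this predictor corresponds to the discrete reverse-diffusion step of the
# variance-exploding SDE used by score-SDE samplers. With G_i = sqrt(sigma_i^2 - sigma_{i-1}^2),
# the update above is
#     x_mean = x + G_i^2 * score(x, sigma_i)          (halved, and without noise, if probability_flow)
#     x      = x_mean + G_i * z,   z ~ N(0, I)
# so x_mean is the deterministic part of the step and x adds the predictor noise back in.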
class NewLangevinCorrector:
def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
super().__init__()
self.score_fn = score_fn
self.snr = snr
self.n_steps = n_steps
self.sigma_min = sigma_min
self.sigma_max = sigma_max
def update_fn(self, x, t):
score_fn = self.score_fn
n_steps = self.n_steps
target_snr = self.snr
# if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
# timestep = (t * (sde.N - 1) / sde.T).long()
# alpha = sde.alphas.to(t.device)[timestep]
# else:
alpha = torch.ones_like(t)
for i in range(n_steps):
labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
grad = score_fn(x, labels)
noise = torch.randn_like(x)
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
x_mean = x + step_size[:, None, None, None] * grad
x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
return x, x_mean
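# For context: this is the Langevin corrector of the predictor-corrector sampler. Each of the
# n_steps iterations takes a noisy gradient step along the score with a step size chosen so the
# update's signal-to-noise ratio matches `snr`:
#     step_size = 2 * alpha * (snr * ||z|| / ||score||)^2
#     x         = x + step_size * score + sqrt(2 * step_size) * z,   z ~ N(0, I)
# (alpha is 1 on the VE branch kept here; the VP/alpha branch is commented out above.)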
def save_image(x):
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("../images/hey.png")
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
#ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note: usually the EMA weights would need to be restored here;
# the checkpoint used below already has the EMA weights restored.
N = 2
sigma_min = 0.01
sigma_max = 1348
sampling_eps = 1e-5
batch_size = 1
centered = False
from diffusers import NCSNpp
model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device)
model = torch.nn.DataParallel(model)
img_size = model.module.config.image_size
channels = model.module.config.num_channels
shape = (batch_size, channels, img_size, img_size)
probability_flow = False
snr = 0.15
n_steps = 1
new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
with torch.no_grad():
# Initial sample
x = torch.randn(*shape) * sigma_max
x = x.to(device)
timesteps = torch.linspace(1, sampling_eps, N, device=device)
for i in range(N):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
x, x_mean = new_corrector.update_fn(x, vec_t)
x, x_mean = new_predictor.update_fn(x, vec_t)
x = x_mean
if centered:
x = (x + 1.) / 2.
# save_image(x)
# for 5 cifar10
x_sum = 106071.9922
x_mean = 34.52864456176758
# for 1000 cifar10
x_sum = 461.9700
x_mean = 0.1504
# for 2 for 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125
def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean)

View File

@ -1,113 +0,0 @@
import torch
from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel
from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
### Convert the text encoder
config = CLIPTextConfig(
vocab_size=50257,
max_position_embeddings=128,
hidden_size=512,
intermediate_size=2048,
num_hidden_layers=16,
num_attention_heads=8,
use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer(
"./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)
hf_encoder = model.text_model
hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]
hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
for layer_idx in range(config.num_hidden_layers):
hf_layer = hf_encoder.encoder.layers[layer_idx]
hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]
hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]
hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]
hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
### Convert the Text-to-Image UNet
text2im_model = GlideTextToImageUNetModel(
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0.1,
channel_mult=(1, 2, 3, 4),
num_heads=1,
num_head_channels=64,
num_heads_upsample=1,
use_scale_shift_norm=True,
resblock_updown=True,
transformer_dim=512,
)
text2im_model.load_state_dict(state_dict, strict=False)
text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
### Convert the Super-Resolution UNet
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")
superres_model = GlideSuperResUNetModel(
in_channels=6,
model_channels=192,
out_channels=6,
num_res_blocks=2,
attention_resolutions=(8, 16, 32),
dropout=0.1,
channel_mult=(1, 1, 2, 2, 4, 4),
num_heads=1,
num_head_channels=64,
num_heads_upsample=1,
use_scale_shift_norm=True,
resblock_updown=True,
)
superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = DDIMScheduler(
timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
)
glide = Glide(
text_unet=text2im_model,
text_noise_scheduler=text_scheduler,
text_encoder=model,
tokenizer=tokenizer,
upscale_unet=superres_model,
upscale_noise_scheduler=upscale_scheduler,
)
glide.save_pretrained("./glide-base")
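# The text-encoder conversion above copies weights attribute by attribute. An equivalent, often
# less error-prone pattern is to rename the original checkpoint's keys and hand the result to
# load_state_dict. A minimal sketch with toy modules (key names are illustrative, not the actual
# GLIDE keys; torch / nn are already imported at the top of this script):
toy_original = {"final_ln.weight": torch.ones(4), "final_ln.bias": torch.zeros(4)}
toy_key_map = {"final_ln.weight": "final_layer_norm.weight", "final_ln.bias": "final_layer_norm.bias"}

class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.final_layer_norm = nn.LayerNorm(4)

toy_target = ToyEncoder()
renamed = {toy_key_map[k]: v for k, v in toy_original.items()}
missing, unexpected = toy_target.load_state_dict(renamed, strict=False)
print(missing, unexpected)  # both lists are empty when every key has been mapped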

View File

@ -7,36 +7,13 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models import (
AutoencoderKL,
NCSNpp,
UNetConditionalModel,
UNetLDMModel,
UNetModel,
UNetUnconditionalModel,
VQModel,
)
from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import (
DDIMPipeline,
DDPMPipeline,
LatentDiffusionUncondPipeline,
PNDMPipeline,
ScoreSdeVePipeline,
ScoreSdeVpPipeline,
)
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
ScoreSdeVpScheduler,
)
from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
if is_transformers_available():
from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .pipelines import GlidePipeline, LatentDiffusionPipeline
from .pipelines import LatentDiffusionPipeline
else:
from .utils.dummy_transformers_objects import *
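# With this trim, the public API at 0.0.4 reduces to the compact import lists above, e.g. (a
# quick sanity check, assuming this commit of diffusers is installed):
#     from diffusers import DDPMPipeline, DDPMScheduler, ScoreSdeVePipeline, UNetUnconditionalModel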

View File

@ -16,10 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet import UNetModel
from .unet_conditional import UNetConditionalModel
from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .unet_ldm import UNetLDMModel
from .unet_sde_score_estimation import NCSNpp
from .unet_unconditional import UNetUnconditionalModel
from .vae import AutoencoderKL, VQModel

View File

@ -54,6 +54,43 @@ def get_timestep_embedding(
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, channel, time_embed_dim, act_fn="silu"):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def forward(self, sample):
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
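# A quick usage sketch of the two classes above: Timesteps produces the fixed sinusoidal features
# and TimestepEmbedding maps them to the learned conditioning vector consumed by the UNet blocks
# (shapes assume get_timestep_embedding's (batch, num_channels) output):
#     time_proj = Timesteps(num_channels=128, flip_sin_to_cos=True, downscale_freq_shift=0)
#     time_embedding = TimestepEmbedding(channel=128, time_embed_dim=512)
#     t_emb = time_proj(torch.tensor([1, 250, 999]))   # (3, 128) sinusoidal features
#     emb = time_embedding(t_emb)                      # (3, 512) learned embedding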
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""

View File

@ -1,195 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# helper functions
import torch
from torch import nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class UNetModel(ModelMixin, ConfigMixin):
def __init__(
self,
ch=128,
out_ch=3,
ch_mult=(1, 1, 2, 2, 4, 4),
num_res_blocks=2,
attn_resolutions=(16,),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=256,
):
super().__init__()
self.register_to_config(
ch=ch,
out_ch=out_ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=in_channels,
resolution=resolution,
)
ch_mult = tuple(ch_mult)
self.ch = ch
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList(
[
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
]
)
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
self.mid_new.resnets[0] = self.mid.block_1
self.mid_new.attentions[0] = self.mid.attn_1
self.mid_new.resnets[1] = self.mid.block_2
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock2D(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, sample, timesteps):
x = sample
assert x.shape[2] == x.shape[3] == self.resolution
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
# timestep embedding
temb = get_timestep_embedding(timesteps, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = self.mid_new(hs[-1], temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
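# In short, the forward pass above follows the classic DDPM UNet layout: conv_in, then the down
# blocks push every intermediate activation onto `hs` as skip connections, the middle block
# (mid_new) applies resnet -> attention -> resnet, and the up blocks consume `hs` in reverse via
# channel concatenation before the final GroupNorm / SiLU / conv_out projection.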

View File

@ -1,74 +1,12 @@
import functools
import math
from typing import Dict, Union
import numpy as np
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock, SpatialTransformer
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
class Combine(nn.Module):
"""Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"):
super().__init__()
# 1x1 convolution with DDPM initialization.
self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.method = method
# def forward(self, x, y):
# h = self.Conv_0(x)
# if self.method == "cat":
# return torch.cat([h, y], dim=1)
# elif self.method == "sum":
# return h + y
# else:
# raise ValueError(f"Method {self.method} not recognized.")
class TimestepEmbedding(nn.Module):
def __init__(self, channel, time_embed_dim, act_fn="silu"):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def forward(self, sample):
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
class UNetConditionalModel(ModelMixin, ConfigMixin):
@ -124,38 +62,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
downscale_freq_shift=0,
mid_block_scale_factor=1,
center_input_sample=False,
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
# LDM
attention_resolutions=(4, 2, 1),
# DDPM
out_ch=None,
resolution=None,
attn_resolutions=None,
resamp_with_conv=None,
ch_mult=None,
ch=None,
ddpm=False,
# SDE
sde=False,
nf=None,
fir=None,
progressive=None,
progressive_combine=None,
scale_by_sigma=None,
skip_rescale=None,
num_channels=None,
centered=False,
conditional=True,
conv_size=3,
fir_kernel=(1, 3, 3, 1),
fourier_scale=16,
init_scale=0.0,
progressive_input="input_skip",
resnet_num_groups=32,
continuous=True,
ldm=False,
resnet_num_groups=30,
):
super().__init__()
# register all __init__ params to be accessible via `self.config.<...>`
@ -175,21 +82,13 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
num_head_channels=num_head_channels,
flip_sin_to_cos=flip_sin_to_cos,
downscale_freq_shift=downscale_freq_shift,
attention_resolutions=attention_resolutions,
attn_resolutions=attn_resolutions,
mid_block_scale_factor=mid_block_scale_factor,
resnet_num_groups=resnet_num_groups,
center_input_sample=center_input_sample,
)
self.ldm = ldm
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
self.image_size = image_size
time_embed_dim = block_channels[0] * 4
# ======================================
# input
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
@ -264,57 +163,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
prev_output_channel = output_channel
# out
num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=resnet_num_groups, eps=resnet_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
# ======================== Out ====================
# =========== TO DELETE AFTER CONVERSION ==========
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
self.is_overwritten = False
if ldm:
num_heads = 8
num_head_channels = -1
transformer_depth = 1
use_spatial_transformer = True
context_dim = 1280
legacy = False
model_channels = block_channels[0]
channel_mult = tuple([x // model_channels for x in block_channels])
self.init_for_ldm(
in_channels,
model_channels,
channel_mult,
num_res_blocks,
dropout,
time_embed_dim,
attention_resolutions,
num_head_channels,
num_heads,
legacy,
False,
transformer_depth,
context_dim,
conv_resample,
out_channels,
)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
) -> Dict[str, torch.FloatTensor]:
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
if not self.is_overwritten:
self.set_weights()
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
@ -329,7 +189,6 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
emb = self.time_embedding(t_emb)
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down
@ -349,7 +208,6 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. up
skip_sample = None
for upsample_block in self.upsample_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
@ -374,259 +232,3 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
output = {"sample": sample}
return output
# !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
# =================================================================================================================================================
def set_weights(self):
self.is_overwritten = True
if self.ldm:
self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data
self.conv_in.weight.data = self.input_blocks[0][0].weight.data
self.conv_in.bias.data = self.input_blocks[0][0].bias.data
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
for i, input_layer in enumerate(self.input_blocks[1:]):
block_id = i // (self.config.num_res_blocks + 1)
layer_in_block_id = i % (self.config.num_res_blocks + 1)
if layer_in_block_id == 2:
self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
elif len(input_layer) > 1:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.mid.resnets[0].set_weight(self.middle_block[0])
self.mid.resnets[1].set_weight(self.middle_block[2])
self.mid.attentions[0].set_weight(self.middle_block[1])
for i, input_layer in enumerate(self.output_blocks):
block_id = i // (self.config.num_res_blocks + 1)
layer_in_block_id = i % (self.config.num_res_blocks + 1)
if len(input_layer) > 2:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
elif len(input_layer) > 1:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.conv_norm_out.weight.data = self.out[0].weight.data
self.conv_norm_out.bias.data = self.out[0].bias.data
self.conv_out.weight.data = self.out[2].weight.data
self.conv_out.bias.data = self.out[2].bias.data
self.remove_ldm()
def init_for_ldm(
self,
in_channels,
model_channels,
channel_mult,
num_res_blocks,
dropout,
time_embed_dim,
attention_resolutions,
num_head_channels,
num_heads,
legacy,
use_spatial_transformer,
transformer_depth,
context_dim,
conv_resample,
out_channels,
):
# TODO(PVP) - delete after weight conversion
class TimestepEmbedSequential(nn.Sequential):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
pass
# TODO(PVP) - delete after weight conversion
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
dims = 2
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResnetBlock2D(
in_channels=ch,
out_channels=mult * model_channels,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
),
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
if dim_head < 0:
dim_head = None
# TODO(Patrick) - delete after weight conversion
# init to be able to overwrite `self.mid`
self.middle_block = TimestepEmbedSequential(
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
),
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResnetBlock2D(
in_channels=ch + ich,
out_channels=model_channels * mult,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
nn.SiLU(),
nn.Conv2d(model_channels, out_channels, 3, padding=1),
)
def remove_ldm(self):
del self.time_embed
del self.input_blocks
del self.middle_block
del self.output_blocks
del self.out
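# In short: init_for_ldm rebuilds the original latent-diffusion module layout (input_blocks /
# middle_block / output_blocks / out), set_weights copies those weights into the new block-based
# API on the first forward call, and remove_ldm then drops the duplicated original modules. Per
# the TODOs above, this scaffolding exists only for weight conversion and is slated for deletion
# before release.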

View File

@ -1,554 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D
def convert_module_to_f16(l):
"""
Convert primitive modules to float16.
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
def convert_module_to_f32(l):
"""
Convert primitive modules to float32, undoing convert_module_to_f16().
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.float()
if l.bias is not None:
l.bias.data = l.bias.data.float()
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
self.swish = swish
def forward(self, x):
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
y = y * F.sigmoid(y * float(self.swish))
return y
def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class TimestepEmbedSequential(nn.Sequential):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
def forward(self, x, emb, encoder_out=None):
for layer in self:
if isinstance(layer, ResnetBlock2D) or isinstance(layer, TimestepEmbedSequential):
x = layer(x, emb)
elif isinstance(layer, AttentionBlock):
x = layer(x, encoder_out)
else:
x = layer(x)
return x
class GlideUNetModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set,
    list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
"""
def __init__(
self,
in_channels=3,
resolution=64,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
transformer_dim=None,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.resolution = resolution
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
# self.dtype = torch.float16 if use_fp16 else torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
ch = input_ch = int(channel_mult[0] * model_channels)
self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
self._feature_size = ch
input_block_chans = [ch]
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResnetBlock2D(
in_channels=ch,
out_channels=mult * model_channels,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
overwrite_for_glide=True,
)
]
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
encoder_channels=transformer_dim,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResnetBlock2D(
in_channels=ch,
out_channels=out_ch,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
overwrite_for_glide=True,
down=True,
)
if resblock_updown
else Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.mid = UNetMidBlock2D(
in_channels=ch,
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=1e-5,
resnet_act_fn="silu",
resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
attn_num_heads=num_heads,
attn_num_head_channels=num_head_channels,
attn_encoder_channels=transformer_dim,
)
# TODO(Patrick) - delete after weight conversion
# init to be able to overwrite `self.mid`
self.middle_block = TimestepEmbedSequential(
ResnetBlock2D(
in_channels=ch,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
overwrite_for_glide=True,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
encoder_channels=transformer_dim,
),
ResnetBlock2D(
in_channels=ch,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
overwrite_for_glide=True,
),
)
self.mid.resnets[0] = self.middle_block[0]
self.mid.attentions[0] = self.middle_block[1]
self.mid.resnets[1] = self.middle_block[2]
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResnetBlock2D(
in_channels=ch + ich,
out_channels=model_channels * mult,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
overwrite_for_glide=True,
),
]
ch = int(model_channels * mult)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads_upsample,
num_head_channels=num_head_channels,
encoder_channels=transformer_dim,
)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResnetBlock2D(
in_channels=ch,
out_channels=out_ch,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
overwrite_for_glide=True,
up=True,
)
if resblock_updown
else Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch, swish=1.0),
nn.Identity(),
zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
)
self.use_fp16 = use_fp16
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
hs = []
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.mid(h, emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
h = h.type(x.dtype)
return self.out(h)
class GlideTextToImageUNetModel(GlideUNetModel):
"""
A UNetModel that conditions the diffusion process on text.
Expects an extra kwarg `transformer_out` with the token embeddings produced by the text transformer.
"""
def __init__(
self,
in_channels=3,
resolution=64,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
transformer_dim=512,
):
super().__init__(
in_channels=in_channels,
resolution=resolution,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim,
)
self.register_to_config(
in_channels=in_channels,
resolution=resolution,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim,
)
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
def forward(self, sample, timestep, transformer_out=None):
timesteps = timestep
x = sample
hs = []
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
# project the last token
transformer_proj = self.transformer_proj(transformer_out[:, -1])
transformer_out = transformer_out.permute(0, 2, 1) # NLC -> NCL
emb = emb + transformer_proj.to(emb)
h = x
for module in self.input_blocks:
h = module(h, emb, transformer_out)
hs.append(h)
h = self.mid(h, emb, transformer_out)
for module in self.output_blocks:
other = hs.pop()
h = torch.cat([h, other], dim=1)
h = module(h, emb, transformer_out)
return self.out(h)
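# In short, text conditioning here happens in two ways: the transformer's last token embedding is
# projected (transformer_proj) and added to the timestep embedding, while the full token sequence
# (permuted NLC -> NCL) is passed to every block as encoder context for the attention layers.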
class GlideSuperResUNetModel(GlideUNetModel):
"""
A UNetModel that performs super-resolution.
Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""
def __init__(
self,
in_channels=3,
resolution=256,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
):
super().__init__(
in_channels=in_channels,
resolution=resolution,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
)
self.register_to_config(
in_channels=in_channels,
resolution=resolution,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
)
def forward(self, sample, timestep, low_res=None):
timesteps = timestep
x = sample
_, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
x = torch.cat([x, upsampled], dim=1)
hs = []
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
h = x
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.mid(h, emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
return self.out(h)
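# In short, the super-resolution variant conditions on the low-resolution image by bilinearly
# upsampling `low_res` to the target resolution and concatenating it with the noisy input along
# the channel dimension, which is why the conversion script instantiates it with in_channels=6
# (3 noisy + 3 upsampled channels).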

View File

@ -1,627 +0,0 @@
import math
from inspect import isfunction
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D
# from .resnet import ResBlock
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
def convert_module_to_f16(l):
"""
Convert primitive modules to float16.
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
def convert_module_to_f32(l):
"""
Convert primitive modules to float32, undoing convert_module_to_f16().
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.float()
if l.bias is not None:
l.bias.data = l.bias.data.float()
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
self.swish = swish
def forward(self, x):
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
y = y * F.sigmoid(y * float(self.swish))
return y
def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
class TimestepEmbedSequential(nn.Sequential):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
def forward(self, x, emb, context=None):
for layer in self:
if isinstance(layer, ResnetBlock2D) or isinstance(layer, TimestepEmbedSequential):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
else:
x = layer(x)
return x
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
macs, params = thop.profile(
model, inputs=(inputs, timestamps), custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += torch.DoubleTensor([matmul_ops])
class UNetLDMModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set,
    list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially increased efficiency.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
):
super().__init__()
# register all __init__ params with self.register
self.register_to_config(
image_size=image_size,
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
num_classes=num_classes,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
use_spatial_transformer=use_spatial_transformer,
transformer_depth=transformer_depth,
context_dim=context_dim,
n_embed=n_embed,
legacy=legacy,
)
if use_spatial_transformer:
assert (
context_dim is not None
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
if context_dim is not None:
assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"
if num_head_channels == -1:
assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.dtype_ = torch.float16 if use_fp16 else torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self.down_in_conv = self.input_blocks[0][0]
self.downsample_blocks = nn.ModuleList([])
self.upsample_blocks = nn.ModuleList([])
# ========================= Down (OLD) =================== #
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResnetBlock2D(
in_channels=ch,
out_channels=mult * model_channels,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=dim_head,
)
if not use_spatial_transformer
else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
input_channels = [model_channels * mult for mult in [1] + list(channel_mult[:-1])]
output_channels = [model_channels * mult for mult in channel_mult]
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if dim_head < 0:
dim_head = None
# ========================= MID (New) =================== #
self.mid = UNetMidBlock2D(
in_channels=ch,
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=1e-5,
resnet_act_fn="silu",
resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
attention_layer_type="self" if not use_spatial_transformer else "spatial",
attn_num_heads=num_heads,
attn_num_head_channels=dim_head,
attn_depth=transformer_depth,
attn_encoder_channels=context_dim,
)
# TODO(Patrick) - delete after weight conversion
# init to be able to overwrite `self.mid`
self.middle_block = TimestepEmbedSequential(
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=dim_head,
)
if not use_spatial_transformer
else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
)
self.mid.resnets[0] = self.middle_block[0]
self.mid.attentions[0] = self.middle_block[1]
self.mid.resnets[1] = self.middle_block[2]
self._feature_size += ch
# ========================= Up (Old) =================== #
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResnetBlock2D(
in_channels=ch + ich,
out_channels=model_channels * mult,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
ch,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
)
if not use_spatial_transformer
else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.

:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via cross-attention.
:param y: an [N] Tensor of labels, if the model is class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x.type(self.dtype_)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.mid(h, emb, context)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
return self.out(h)
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape it to (b, t, d). Then
apply standard transformer blocks. Finally, reshape back to the original image shape.
"""
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)
]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
        x = self.proj_in(x)
        inner_dim = x.shape[1]  # channel count after projection (n_heads * d_head); may differ from in_channels
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, inner_dim)
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = x.reshape(b, h, w, inner_dim).permute(0, 3, 1, 2)
x = self.proj_out(x)
return x + x_in
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def forward(self, x, context=None, mask=None):
batch_size, sequence_length, dim = x.shape
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v)
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
mask = mask.reshape(batch_size, -1)
max_neg_value = -torch.finfo(sim.dtype).max
mask = mask[:, None, :].repeat(h, 1, 1)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = self.reshape_batch_dim_to_heads(out)
return self.to_out(out)
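# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# A minimal shape check for the SpatialTransformer / CrossAttention classes above. The sizes
# are arbitrary placeholders chosen so that inner_dim (n_heads * d_head) equals in_channels,
# as in the legacy configuration, and the context mimics text-encoder hidden states.
if __name__ == "__main__":
    _st = SpatialTransformer(in_channels=64, n_heads=8, d_head=8, depth=1, context_dim=768)
    _x = torch.randn(2, 64, 16, 16)  # image-like features: (b, c, h, w)
    _ctx = torch.randn(2, 77, 768)  # e.g. text-encoder hidden states: (b, t, d)
    with torch.no_grad():
        _out = _st(_x, context=_ctx)
    assert _out.shape == _x.shape  # the residual connection preserves (b, c, h, w)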

View File

@ -1,471 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# helper functions
import functools
import math
import numpy as np
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D
class Combine(nn.Module):
"""Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"):
super().__init__()
# 1x1 convolution with DDPM initialization.
self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.method = method
def forward(self, x, y):
h = self.Conv_0(x)
if self.method == "cat":
return torch.cat([h, y], dim=1)
elif self.method == "sum":
return h + y
else:
raise ValueError(f"Method {self.method} not recognized.")
class NCSNpp(ModelMixin, ConfigMixin):
"""NCSN++ model"""
def __init__(
self,
image_size=1024,
num_channels=3,
centered=False,
attn_resolutions=(16,),
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
conditional=True,
conv_size=3,
dropout=0.0,
embedding_type="fourier",
fir=True,
fir_kernel=(1, 3, 3, 1),
fourier_scale=16,
init_scale=0.0,
nf=16,
num_res_blocks=1,
progressive="output_skip",
progressive_combine="sum",
progressive_input="input_skip",
resamp_with_conv=True,
scale_by_sigma=True,
skip_rescale=True,
continuous=True,
):
super().__init__()
self.register_to_config(
image_size=image_size,
num_channels=num_channels,
centered=centered,
attn_resolutions=attn_resolutions,
ch_mult=ch_mult,
conditional=conditional,
conv_size=conv_size,
dropout=dropout,
embedding_type=embedding_type,
fir=fir,
fir_kernel=fir_kernel,
fourier_scale=fourier_scale,
init_scale=init_scale,
nf=nf,
num_res_blocks=num_res_blocks,
progressive=progressive,
progressive_combine=progressive_combine,
progressive_input=progressive_input,
resamp_with_conv=resamp_with_conv,
scale_by_sigma=scale_by_sigma,
skip_rescale=skip_rescale,
continuous=continuous,
)
self.act = nn.SiLU()
        self.nf = nf
        self.fir = fir  # referenced below when choosing between FIR and standard resampling layers
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.num_resolutions = len(ch_mult)
self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)]
self.conditional = conditional
self.skip_rescale = skip_rescale
self.progressive = progressive
self.progressive_input = progressive_input
self.embedding_type = embedding_type
assert progressive in ["none", "output_skip", "residual"]
assert progressive_input in ["none", "input_skip", "residual"]
assert embedding_type in ["fourier", "positional"]
combine_method = progressive_combine.lower()
combiner = functools.partial(Combine, method=combine_method)
modules = []
# timestep/noise_level embedding; only for continuous training
if embedding_type == "fourier":
# Gaussian Fourier features embeddings.
modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale))
embed_dim = 2 * nf
elif embedding_type == "positional":
embed_dim = nf
else:
raise ValueError(f"embedding type {embedding_type} unknown.")
modules.append(nn.Linear(embed_dim, nf * 4))
modules.append(nn.Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
if self.fir:
Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Up_sample = functools.partial(Upsample2D, name="Conv2d_0")
if progressive == "output_skip":
self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
elif progressive == "residual":
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
if self.fir:
Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")
if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, use_conv=True)
channels = num_channels
if progressive_input != "none":
input_pyramid_ch = channels
modules.append(nn.Conv2d(channels, nf, kernel_size=3, padding=1))
hs_c = [nf]
in_ch = nf
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level]
modules.append(
ResnetBlock2D(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
hs_c.append(in_ch)
if i_level != self.num_resolutions - 1:
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
down=True,
kernel="fir" if self.fir else "sde_vp",
use_nin_shortcut=True,
)
)
if progressive_input == "input_skip":
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
if combine_method == "cat":
in_ch *= 2
elif progressive_input == "residual":
modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
input_pyramid_ch = in_ch
hs_c.append(in_ch)
# mid
self.mid = UNetMidBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=math.sqrt(2.0),
resnet_act_fn="silu",
resnet_groups=min(in_ch // 4, 32),
dropout=dropout,
)
in_ch = hs_c[-1]
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
modules.append(AttnBlock(channels=in_ch))
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
# self.mid.resnets[0] = modules[len(modules) - 3]
# self.mid.attentions[0] = modules[len(modules) - 2]
# self.mid.resnets[1] = modules[len(modules) - 1]
pyramid_ch = 0
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(num_res_blocks + 1):
out_ch = nf * ch_mult[i_level]
in_ch = in_ch + hs_c.pop()
modules.append(
ResnetBlock2D(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
if progressive != "none":
if i_level == self.num_resolutions - 1:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
pyramid_ch = channels
# elif progressive == "residual":
# modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
# modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
# pyramid_ch = in_ch
# else:
# raise ValueError(f"{progressive} is not a valid name.")
else:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1))
pyramid_ch = channels
# elif progressive == "residual":
# modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
# pyramid_ch = in_ch
# else:
# raise ValueError(f"{progressive} is not a valid name")
if i_level != 0:
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
up=True,
kernel="fir" if self.fir else "sde_vp",
use_nin_shortcut=True,
)
)
assert not hs_c
if progressive != "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
self.all_modules = nn.ModuleList(modules)
def forward(self, sample, timestep, sigmas=None):
timesteps = timestep
x = sample
# timestep/noise_level embedding; only for continuous training
modules = self.all_modules
m_idx = 0
if self.embedding_type == "fourier":
# Gaussian Fourier features embeddings.
used_sigmas = timesteps
temb = modules[m_idx](used_sigmas)
m_idx += 1
elif self.embedding_type == "positional":
# Sinusoidal positional embeddings.
used_sigmas = sigmas
temb = get_timestep_embedding(timesteps, self.nf)
else:
raise ValueError(f"embedding type {self.embedding_type} unknown.")
if self.conditional:
temb = modules[m_idx](temb)
m_idx += 1
temb = modules[m_idx](self.act(temb))
m_idx += 1
else:
temb = None
# If input data is in [0, 1]
if not self.config.centered:
x = 2 * x - 1.0
# Downsampling block
input_pyramid = None
if self.progressive_input != "none":
input_pyramid = x
hs = [modules[m_idx](x)]
m_idx += 1
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(self.num_res_blocks):
h = modules[m_idx](hs[-1], temb)
m_idx += 1
if h.shape[-1] in self.attn_resolutions:
h = modules[m_idx](h)
m_idx += 1
hs.append(h)
if i_level != self.num_resolutions - 1:
h = modules[m_idx](hs[-1], temb)
m_idx += 1
if self.progressive_input == "input_skip":
input_pyramid = self.pyramid_downsample(input_pyramid)
h = modules[m_idx](input_pyramid, h)
m_idx += 1
elif self.progressive_input == "residual":
input_pyramid = modules[m_idx](input_pyramid)
m_idx += 1
if self.skip_rescale:
input_pyramid = (input_pyramid + h) / np.sqrt(2.0)
else:
input_pyramid = input_pyramid + h
h = input_pyramid
hs.append(h)
h = hs[-1]
h = modules[m_idx](h, temb)
m_idx += 1
h = modules[m_idx](h)
m_idx += 1
h = modules[m_idx](h, temb)
m_idx += 1
pyramid = None
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
m_idx += 1
if h.shape[-1] in self.attn_resolutions:
h = modules[m_idx](h)
m_idx += 1
if self.progressive != "none":
if i_level == self.num_resolutions - 1:
if self.progressive == "output_skip":
pyramid = self.act(modules[m_idx](h))
m_idx += 1
pyramid = modules[m_idx](pyramid)
m_idx += 1
# elif self.progressive == "residual":
# pyramid = self.act(modules[m_idx](h))
# m_idx += 1
# pyramid = modules[m_idx](pyramid)
# m_idx += 1
# else:
# raise ValueError(f"{self.progressive} is not a valid name.")
else:
if self.progressive == "output_skip":
pyramid_h = self.act(modules[m_idx](h))
m_idx += 1
pyramid_h = modules[m_idx](pyramid_h)
m_idx += 1
skip_sample = self.pyramid_upsample(pyramid)
pyramid = skip_sample + pyramid_h
# elif self.progressive == "residual":
# pyramid = modules[m_idx](pyramid)
# m_idx += 1
# if self.skip_rescale:
# pyramid = (pyramid + h) / np.sqrt(2.0)
# else:
# pyramid = pyramid + h
# h = pyramid
# else:
# raise ValueError(f"{self.progressive} is not a valid name")
if i_level != 0:
h = modules[m_idx](h, temb)
m_idx += 1
assert not hs
if self.progressive == "output_skip":
h = pyramid
else:
h = self.act(modules[m_idx](h))
m_idx += 1
h = modules[m_idx](h)
m_idx += 1
assert m_idx == len(modules)
if self.config.scale_by_sigma:
used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
h = h / used_sigmas
return h
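# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# Instantiates a deliberately tiny NCSN++ and evaluates it once. The config values are
# arbitrary placeholders chosen so the shapes line up; real checkpoints use the defaults above.
if __name__ == "__main__":
    _model = NCSNpp(image_size=32, num_channels=3, ch_mult=(1, 2), nf=32, attn_resolutions=(16,), num_res_blocks=1)
    _sample = torch.randn(1, 3, 32, 32)
    _sigmas = torch.full((1,), 0.5)  # with the fourier embedding, the "timestep" input is the noise level
    with torch.no_grad():
        _score = _model(_sample, _sigmas)
    assert _score.shape == _sample.shape  # the predicted score has the same shape as the input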

File diff suppressed because it is too large

View File

@ -4,9 +4,7 @@ from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
from .pndm import PNDMPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .score_sde_vp import ScoreSdeVpPipeline
if is_transformers_available():
from .glide import GlidePipeline
from .latent_diffusion import LatentDiffusionPipeline

View File

@ -1,5 +0,0 @@
from ...utils import is_transformers_available
if is_transformers_available():
from .pipeline_glide import CLIPTextModel, GlidePipeline

View File

@ -1,843 +0,0 @@
# coding=utf-8
# Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch CLIP model."""
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from tqdm.auto import tqdm
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...models import GlideSuperResUNetModel, GlideTextToImageUNetModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import logging
#####################
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
#####################
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "fusing/glide-base"
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
"fusing/glide-base",
# See all CLIP models at https://huggingface.co/models?filter=clip
]
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
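# Hedged worked example (added; not part of the original file): for a padding mask
# [[1, 1, 0]] and dtype torch.float32, `_expand_mask` returns a tensor of shape
# (1, 1, 3, 3) that is 0.0 wherever attention is allowed and torch.finfo(torch.float32).min
# in the column of the padded token, so its softmax weight is driven to ~0.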
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
caption_loss = contrastive_loss(similarity)
image_loss = contrastive_loss(similarity.T)
return (caption_loss + image_loss) / 2.0
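# Hedged note (added; not part of the original file): `contrastive_loss` scores a square
# similarity matrix with cross-entropy against arange labels, i.e. row i's correct match is
# column i; `clip_loss` averages that loss over the matrix and its transpose, giving the
# symmetric text-to-image / image-to-text objective used by CLIP.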
@dataclass
class CLIPOutput(ModelOutput):
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Contrastive loss for image-text similarity.
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
similarity scores.
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
text_model_output(`BaseModelOutputWithPooling`):
The output of the [`CLIPTextModel`].
vision_model_output(`BaseModelOutputWithPooling`):
The output of the [`CLIPVisionModel`].
"""
loss: Optional[torch.FloatTensor] = None
logits_per_image: torch.FloatTensor = None
logits_per_text: torch.FloatTensor = None
text_embeds: torch.FloatTensor = None
image_embeds: torch.FloatTensor = None
text_model_output: BaseModelOutputWithPooling = None
vision_model_output: BaseModelOutputWithPooling = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class CLIPTextEmbeddings(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
self.use_padding_embeddings = config.use_padding_embeddings
if self.use_padding_embeddings:
self.padding_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
if self.use_padding_embeddings and attention_mask is not None:
padding_embeddings = self.padding_embedding(position_ids)
embeddings = torch.where(attention_mask.bool().unsqueeze(-1), embeddings, padding_embeddings)
return embeddings
class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = 1 / math.sqrt(math.sqrt(self.head_dim))
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
qkv_states = self.qkv_proj(hidden_states)
qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1)
query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1)
attn_weights = torch.einsum("bthc,bshc->bhts", query_states * self.scale, key_states * self.scale)
wdtype = attn_weights.dtype
attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype)
attn_output = torch.einsum("bhts,bshc->bthc", attn_weights, value_states)
attn_output = attn_output.reshape(bsz, tgt_len, -1)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
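# Hedged shape walkthrough (added; not part of the original file): for hidden_states of shape
# (bsz, tgt_len, embed_dim), qkv_proj yields (bsz, tgt_len, 3 * embed_dim), which is viewed as
# (bsz, tgt_len, num_heads, 3 * head_dim) and split into q, k, v of shape
# (bsz, tgt_len, num_heads, head_dim). The first einsum forms per-head attention weights of
# shape (bsz, num_heads, tgt_len, tgt_len), the second applies them to v, and the result is
# reshaped back to (bsz, tgt_len, embed_dim) before out_proj. Note that in this copy the
# attention_mask and causal_attention_mask arguments are accepted but never applied.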
class CLIPMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class CLIPEncoderLayer(nn.Module):
def __init__(self, config: CLIPConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = CLIPAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim)
self.mlp = CLIPMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class CLIPPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = CLIPConfig
base_model_prefix = "clip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_factor
if isinstance(module, CLIPTextEmbeddings):
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
if hasattr(module, "padding_embedding"):
module.padding_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
elif isinstance(module, CLIPVisionEmbeddings):
factor = self.config.initializer_factor
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
elif isinstance(module, CLIPAttention):
factor = self.config.initializer_factor
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
out_proj_std = (module.embed_dim**-0.5) * factor
nn.init.normal_(module.qkv_proj.weight, std=in_proj_std)
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
elif isinstance(module, CLIPMLP):
factor = self.config.initializer_factor
in_proj_std = (
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
)
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
nn.init.normal_(module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std)
elif isinstance(module, CLIPModel):
nn.init.normal_(
module.text_projection.weight,
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
)
nn.init.normal_(
module.visual_projection.weight,
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
)
if isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, CLIPEncoder):
module.gradient_checkpointing = value
CLIP_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
CLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
CLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
CLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class CLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`CLIPEncoderLayer`].
Args:
config: CLIPConfig
"""
def __init__(self, config: CLIPConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
causal_attention_mask,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class CLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim)
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask)
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
# causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=None,
causal_attention_mask=None,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, bsz, seq_len):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len)
mask.fill_(torch.tensor(float("-inf")))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class CLIPTextModel(CLIPPreTrainedModel):
config_class = CLIPTextConfig
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
self.text_model = CLIPTextTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, value):
self.text_model.embeddings.token_embedding = value
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Examples:
```python
>>> from transformers import CLIPTokenizer, CLIPTextModel
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
#####################
# END OF THE CLIP MODEL COPY-PASTE
#####################
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + torch.zeros(broadcast_shape, device=timesteps.device)
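# Hedged worked example (added; not part of the original file): with arr = np.array([0.1, 0.2, 0.3]),
# timesteps = torch.tensor([2, 0]) and broadcast_shape = (2, 3, 8, 8), the result has shape
# (2, 3, 8, 8) and is filled with 0.3 for the first batch element and 0.1 for the second.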
class GlidePipeline(DiffusionPipeline):
def __init__(
self,
text_unet: GlideTextToImageUNetModel,
text_scheduler: DDPMScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
upscale_unet: GlideSuperResUNetModel,
upscale_scheduler: DDIMScheduler,
):
super().__init__()
self.register_modules(
text_unet=text_unet,
text_scheduler=text_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
upscale_unet=upscale_unet,
upscale_scheduler=upscale_scheduler,
)
@torch.no_grad()
def __call__(
self,
prompt,
generator=None,
torch_device=None,
num_inference_steps_upscale=50,
guidance_scale=3.0,
eta=0.0,
upsample_temp=0.997,
):
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.text_unet.to(torch_device)
self.text_encoder.to(torch_device)
self.upscale_unet.to(torch_device)
def text_model_fn(x_t, timesteps, transformer_out, **kwargs):
half = x_t[: len(x_t) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
# 1. Sample gaussian noise
batch_size = 2 # second image is empty for classifier-free guidance
image = torch.randn(
(
batch_size,
self.text_unet.in_channels,
self.text_unet.resolution,
self.text_unet.resolution,
),
generator=generator,
).to(torch_device)
# 2. Encode tokens
# an empty input is needed to guide the model away from it
inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
input_ids = inputs["input_ids"].to(torch_device)
attention_mask = inputs["attention_mask"].to(torch_device)
transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
# 3. Run the text2image generation step
num_prediction_steps = len(self.text_scheduler)
for t in tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
with torch.no_grad():
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
model_output = text_model_fn(image, time_input, transformer_out)
noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
min_log = self.text_scheduler.get_variance(t, "fixed_small_log")
max_log = self.text_scheduler.get_variance(t, "fixed_large_log")
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
pred_prev_image = self.text_scheduler.step(noise_residual, image, t)
noise = torch.randn(image.shape, generator=generator).to(torch_device)
variance = torch.exp(0.5 * model_log_variance) * noise
# set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
# 4. Run the upscaling step
batch_size = 1
image = image[:1]
low_res = ((image + 1) * 127.5).round() / 127.5 - 1
# Sample gaussian noise to begin loop
image = torch.randn(
(
batch_size,
self.upscale_unet.in_channels // 2,
self.upscale_unet.resolution,
self.upscale_unet.resolution,
),
generator=generator,
).to(torch_device)
image = image * upsample_temp
num_trained_timesteps = self.upscale_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
for t in tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
# 1. predict noise residual
with torch.no_grad():
time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
model_output = self.upscale_unet(image, time_input, low_res)
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
# 2. predict previous mean of image x_t-1
pred_prev_image = self.upscale_scheduler.step(
noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
)
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(torch_device)
variance = self.upscale_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
image = image.clamp(-1, 1).permute(0, 2, 3, 1)
return image
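# --- Hedged illustration (added; not part of the original file) ---
# The classifier-free guidance step inside `text_model_fn` above mixes the conditional and
# unconditional noise predictions. The standalone sketch below reproduces only that mixing
# rule on dummy tensors; the function name and shapes are illustrative, not part of the API.
def _cfg_mix_sketch(cond_eps: torch.Tensor, uncond_eps: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # guided_eps = uncond + s * (cond - uncond); s=0 is unconditional, s=1 is conditional, s>1 amplifies the prompt
    return uncond_eps + guidance_scale * (cond_eps - uncond_eps)

if __name__ == "__main__":
    _cond, _uncond = torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64)
    assert torch.allclose(_cfg_mix_sketch(_cond, _uncond, 0.0), _uncond)  # scale 0 falls back to the unconditional branch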

View File

@ -1 +0,0 @@
from .pipeline_score_sde_vp import ScoreSdeVpPipeline

View File

@ -1,38 +0,0 @@
#!/usr/bin/env python3
import torch
from diffusers import DiffusionPipeline
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVpPipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
def __call__(self, num_inference_steps=1000, generator=None):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
channels = self.model.config.num_channels
shape = (1, channels, img_size, img_size)
model = self.model.to(device)
x = torch.randn(*shape).to(device)
self.scheduler.set_timesteps(num_inference_steps)
for t in self.scheduler.timesteps:
t = t * torch.ones(shape[0], device=device)
scaled_t = t * (num_inference_steps - 1)
# TODO add corrector
with torch.no_grad():
result = model(x, scaled_t)
x, x_mean = self.scheduler.step_pred(result, x, t)
x_mean = (x_mean + 1.0) / 2.0
return x_mean
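# Hedged usage note (added; not part of the original file): given a compatible VP-SDE model and
# scheduler (e.g. loaded via DiffusionPipeline.from_pretrained on a suitable checkpoint), sampling
# reduces to `image = ScoreSdeVpPipeline(model, scheduler)(num_inference_steps=1000)`, which
# returns x_mean rescaled from [-1, 1] to [0, 1].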

View File

@ -255,8 +255,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"ch": 32,
"ch_mult": (1, 2),
"block_channels": (32, 64),
"down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
"up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
@ -264,8 +262,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
"out_channels": 3,
"in_channels": 3,
"num_res_blocks": 2,
"attn_resolutions": (16,),
"resolution": 32,
"image_size": 32,
}
inputs_dict = self.dummy_input
@ -322,13 +318,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
"in_channels": 4,
"out_channels": 4,
"num_res_blocks": 2,
"attention_resolutions": (16,),
"block_channels": (32, 64),
"num_head_channels": 32,
"conv_resample": True,
"down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
"up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
"ldm": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@ -529,8 +523,8 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
"ch": 64,
"out_ch": 3,
"num_res_blocks": 1,
"attn_resolutions": [],
"in_channels": 3,
"attn_resolutions": [],
"resolution": 32,
"z_channels": 3,
"n_embed": 256,
@ -605,11 +599,11 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
"ch_mult": (1,),
"embed_dim": 4,
"in_channels": 3,
"attn_resolutions": [],
"num_res_blocks": 1,
"out_ch": 3,
"resolution": 32,
"z_channels": 4,
"attn_resolutions": [],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@ -655,7 +649,6 @@ class PipelineTesterMixin(unittest.TestCase):
model = UNetUnconditionalModel(
block_channels=(32, 64),
num_res_blocks=2,
attn_resolutions=(16,),
image_size=32,
in_channels=3,
out_channels=3,