Get diffusers ready 🚀🚀🚀 (#101)

* big purge
* more fixes
* finish for now

parent 33344ed916
commit 8c31925b3b

@@ -1,110 +0,0 @@
#!/usr/bin/env python3
import json
import os

import torch
from huggingface_hub import hf_hub_download

from diffusers import UNetUnconditionalModel
from scripts.convert_ncsnpp_original_checkpoint_to_diffusers import convert_ncsnpp_checkpoint


def convert_checkpoint(model_id, subfolder=None, checkpoint="diffusion_model.pt", config="config.json"):
    if subfolder is not None:
        checkpoint = os.path.join(subfolder, checkpoint)
        config = os.path.join(subfolder, config)

    # download the original NCSN++ weights and config from the Hub
    original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint), map_location="cpu")
    config_path = hf_hub_download(model_id, config)

    with open(config_path) as f:
        config = json.load(f)

    # convert the original state dict to the diffusers layout
    checkpoint = convert_ncsnpp_checkpoint(original_checkpoint, config)

    def current_codebase_conversion(path):
        # reference conversion produced by the current codebase, used for comparison
        model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, sde=True)
        model.eval()
        model.config.sde = False
        model.save_config(path)
        model.config.sde = True

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
            model(noise, time_step)

        return model.state_dict()

    path = f"{model_id}_converted"
    currently_converted_checkpoint = current_codebase_conversion(path)

    def diff_between_checkpoints(ch_0, ch_1):
        all_layers_included = False

        if not set(ch_0.keys()) == set(ch_1.keys()):
            print(f"Contained in ch_0 and not in ch_1 (Total: {len(set(ch_0.keys()) - set(ch_1.keys()))})")
            for key in sorted(set(ch_0.keys()) - set(ch_1.keys())):
                print(f"\t{key}")

            print(f"Contained in ch_1 and not in ch_0 (Total: {len(set(ch_1.keys()) - set(ch_0.keys()))})")
            for key in sorted(set(ch_1.keys()) - set(ch_0.keys())):
                print(f"\t{key}")
        else:
            print("Keys are the same between the two checkpoints")
            all_layers_included = True

        keys = ch_0.keys()
        non_equal_keys = []

        if all_layers_included:
            for key in keys:
                try:
                    if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()):
                        non_equal_keys.append(f"{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}")
                except RuntimeError as e:
                    print(e)
                    non_equal_keys.append(f"{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}")

        if len(non_equal_keys):
            non_equal_keys = "\n\t".join(non_equal_keys)
            print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}")
        else:
            print("All keys are equal across checkpoints.")

    diff_between_checkpoints(currently_converted_checkpoint, checkpoint)

    os.makedirs(f"{model_id}_converted", exist_ok=True)
    torch.save(checkpoint, f"{model_id}_converted/diffusion_model.pt")


model_ids = [
    "fusing/ffhq_ncsnpp",
    "fusing/church_256-ncsnpp-ve",
    "fusing/celebahq_256-ncsnpp-ve",
    "fusing/bedroom_256-ncsnpp-ve",
    "fusing/ffhq_256-ncsnpp-ve",
    "fusing/ncsnpp-ffhq-ve-dummy",
]

for model in model_ids:
    print(f"converting {model}")
    try:
        convert_checkpoint(model)
    except Exception as e:
        print(e)

from tests.test_modeling_utils import NCSNppModelTests, PipelineTesterMixin

tester1 = NCSNppModelTests()
tester2 = PipelineTesterMixin()

# enable the slow tests for the checks below
os.environ["RUN_SLOW"] = "1"
cmd = "export RUN_SLOW=1; echo $RUN_SLOW"  # or whatever command
os.system(cmd)

tester2.test_score_sde_ve_pipeline(f"{model_ids[0]}_converted")
tester1.test_output_pretrained_ve_mid(f"{model_ids[2]}_converted")
tester1.test_output_pretrained_ve_large(f"{model_ids[-1]}_converted")
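Aside (not part of the original script): the converted weights land next to the config written by `save_config`, so a quick sanity check is to reload the file the script just wrote; a minimal sketch, assuming the first repo id from the list above:

    # sketch: reload the converted state dict written by convert_checkpoint
    reloaded = torch.load("fusing/ffhq_ncsnpp_converted/diffusion_model.pt", map_location="cpu")
    print(len(reloaded), "tensors in the converted checkpoint")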
run.py (153 lines removed)
@@ -1,153 +0,0 @@
#!/usr/bin/env python3
import numpy as np
import PIL.Image
import torch

# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs


device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)


class NewReverseDiffusionPredictor:
    # Reverse-diffusion (ancestral) predictor step for the variance-exploding SDE.
    def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.N = N
        self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))

        self.probability_flow = probability_flow
        self.score_fn = score_fn

    def discretize(self, x, t):
        timestep = (t * (self.N - 1)).long()
        sigma = self.discrete_sigmas.to(t.device)[timestep]
        adjacent_sigma = torch.where(
            timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device)
        )
        f = torch.zeros_like(x)
        G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)

        labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        result = self.score_fn(x, labels)

        rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.0)
        rev_G = torch.zeros_like(G) if self.probability_flow else G
        return rev_f, rev_G

    def update_fn(self, x, t):
        f, G = self.discretize(x, t)
        z = torch.randn_like(x)
        x_mean = x - f
        x = x_mean + G[:, None, None, None] * z
        return x, x_mean


class NewLangevinCorrector:
    # Langevin-dynamics corrector that nudges samples along the score at a target SNR.
    def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
        super().__init__()
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def update_fn(self, x, t):
        score_fn = self.score_fn
        n_steps = self.n_steps
        target_snr = self.snr
        # if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
        #     timestep = (t * (sde.N - 1) / sde.T).long()
        #     alpha = sde.alphas.to(t.device)[timestep]
        # else:
        alpha = torch.ones_like(t)

        for i in range(n_steps):
            labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
            grad = score_fn(x, labels)
            noise = torch.randn_like(x)
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
            noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
            step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
            x_mean = x + step_size[:, None, None, None] * grad
            x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise

        return x, x_mean


def save_image(x):
    image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
    image_pil = PIL.Image.fromarray(image_processed[0])
    image_pil.save("../images/hey.png")


# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below

N = 2
sigma_min = 0.01
sigma_max = 1348
sampling_eps = 1e-5
batch_size = 1
centered = False

from diffusers import NCSNpp

model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device)
model = torch.nn.DataParallel(model)

img_size = model.module.config.image_size
channels = model.module.config.num_channels
shape = (batch_size, channels, img_size, img_size)
probability_flow = False
snr = 0.15
n_steps = 1


new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)

with torch.no_grad():
    # Initial sample
    x = torch.randn(*shape) * sigma_max
    x = x.to(device)
    timesteps = torch.linspace(1, sampling_eps, N, device=device)

    # predictor-corrector sampling loop
    for i in range(N):
        t = timesteps[i]
        vec_t = torch.ones(shape[0], device=t.device) * t
        x, x_mean = new_corrector.update_fn(x, vec_t)
        x, x_mean = new_predictor.update_fn(x, vec_t)

    # use the noise-free mean of the last step as the final sample
    x = x_mean
    if centered:
        x = (x + 1.0) / 2.0


# save_image(x)

# for 5 cifar10
x_sum = 106071.9922
x_mean = 34.52864456176758

# for 1000 cifar10
x_sum = 461.9700
x_mean = 0.1504

# for 2 for 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125


def check_x_sum_x_mean(x, x_sum, x_mean):
    assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
    assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"


check_x_sum_x_mean(x, x_sum, x_mean)
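Reading aid (not part of the original file): the two classes above re-implement the predictor-corrector sampler for the variance-exploding SDE from the score-SDE formulation (Song et al., 2021). With the geometric noise schedule built by `discrete_sigmas`,

$$
\sigma_i = \sigma_{\min}\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)^{i/(N-1)}, \qquad
G_i = \sqrt{\sigma_i^2 - \sigma_{i-1}^2},
$$

one predictor step in `NewReverseDiffusionPredictor.update_fn` amounts to

$$
x_{i-1} = x_i + G_i^2 \, s_\theta(x_i, \sigma_i) + G_i \, z, \qquad z \sim \mathcal{N}(0, I),
$$

which is exactly `x_mean = x - rev_f` followed by adding `G * z` (the score term is halved when `probability_flow=True`). The Langevin corrector then nudges each iterate along the score with a step size chosen to hit the target signal-to-noise ratio `snr`.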
@@ -1,113 +0,0 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel
|
||||
from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel
|
||||
from transformers import CLIPTextConfig, GPT2Tokenizer
|
||||
|
||||
|
||||
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
|
||||
state_dict = torch.load("base.pt", map_location="cpu")
|
||||
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
|
||||
|
||||
### Convert the text encoder
|
||||
|
||||
config = CLIPTextConfig(
|
||||
vocab_size=50257,
|
||||
max_position_embeddings=128,
|
||||
hidden_size=512,
|
||||
intermediate_size=2048,
|
||||
num_hidden_layers=16,
|
||||
num_attention_heads=8,
|
||||
use_padding_embeddings=True,
|
||||
)
|
||||
model = CLIPTextModel(config).eval()
|
||||
tokenizer = GPT2Tokenizer(
|
||||
"./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
|
||||
)
|
||||
|
||||
hf_encoder = model.text_model
|
||||
|
||||
hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
|
||||
hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
|
||||
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]
|
||||
|
||||
hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
|
||||
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
|
||||
|
||||
for layer_idx in range(config.num_hidden_layers):
|
||||
hf_layer = hf_encoder.encoder.layers[layer_idx]
|
||||
hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
|
||||
hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]
|
||||
|
||||
hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
|
||||
hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]
|
||||
|
||||
hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
|
||||
hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
|
||||
hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
|
||||
hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]
|
||||
|
||||
hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
|
||||
hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
|
||||
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
|
||||
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
|
||||
|
||||
### Convert the Text-to-Image UNet
|
||||
|
||||
text2im_model = GlideTextToImageUNetModel(
|
||||
in_channels=3,
|
||||
model_channels=192,
|
||||
out_channels=6,
|
||||
num_res_blocks=3,
|
||||
attention_resolutions=(2, 4, 8),
|
||||
dropout=0.1,
|
||||
channel_mult=(1, 2, 3, 4),
|
||||
num_heads=1,
|
||||
num_head_channels=64,
|
||||
num_heads_upsample=1,
|
||||
use_scale_shift_norm=True,
|
||||
resblock_updown=True,
|
||||
transformer_dim=512,
|
||||
)
|
||||
|
||||
text2im_model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
|
||||
|
||||
### Convert the Super-Resolution UNet
|
||||
|
||||
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
|
||||
ups_state_dict = torch.load("upsample.pt", map_location="cpu")
|
||||
|
||||
superres_model = GlideSuperResUNetModel(
|
||||
in_channels=6,
|
||||
model_channels=192,
|
||||
out_channels=6,
|
||||
num_res_blocks=2,
|
||||
attention_resolutions=(8, 16, 32),
|
||||
dropout=0.1,
|
||||
channel_mult=(1, 1, 2, 2, 4, 4),
|
||||
num_heads=1,
|
||||
num_head_channels=64,
|
||||
num_heads_upsample=1,
|
||||
use_scale_shift_norm=True,
|
||||
resblock_updown=True,
|
||||
)
|
||||
|
||||
superres_model.load_state_dict(ups_state_dict, strict=False)
|
||||
|
||||
upscale_scheduler = DDIMScheduler(
|
||||
timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
|
||||
)
|
||||
|
||||
glide = Glide(
|
||||
text_unet=text2im_model,
|
||||
text_noise_scheduler=text_scheduler,
|
||||
text_encoder=model,
|
||||
tokenizer=tokenizer,
|
||||
upscale_unet=superres_model,
|
||||
upscale_noise_scheduler=upscale_scheduler,
|
||||
)
|
||||
|
||||
glide.save_pretrained("./glide-base")
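A hedged sketch of how the pipeline saved above could be reloaded with this pre-release API; it assumes `Glide` exposes a `from_pretrained` counterpart to the `save_pretrained` call used here, and it omits the generation call because its signature is not shown in this file:

    # sketch only: load the converted GLIDE pipeline back from disk
    glide = Glide.from_pretrained("./glide-base")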
|
|
@@ -7,36 +7,13 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
|
|||
__version__ = "0.0.4"
|
||||
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models import (
|
||||
AutoencoderKL,
|
||||
NCSNpp,
|
||||
UNetConditionalModel,
|
||||
UNetLDMModel,
|
||||
UNetModel,
|
||||
UNetUnconditionalModel,
|
||||
VQModel,
|
||||
)
|
||||
from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .pipelines import (
|
||||
DDIMPipeline,
|
||||
DDPMPipeline,
|
||||
LatentDiffusionUncondPipeline,
|
||||
PNDMPipeline,
|
||||
ScoreSdeVePipeline,
|
||||
ScoreSdeVpPipeline,
|
||||
)
|
||||
from .schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
ScoreSdeVpScheduler,
|
||||
)
|
||||
from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
|
||||
from .pipelines import GlidePipeline, LatentDiffusionPipeline
|
||||
from .pipelines import LatentDiffusionPipeline
|
||||
else:
|
||||
from .utils.dummy_transformers_objects import *
|
||||
|
|
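For completeness, a quick smoke test of the flattened public API introduced by this hunk; every name below is copied verbatim from the new import lines above, and it is only a sketch since the package is mid-refactor:

    from diffusers import (
        AutoencoderKL,
        DDIMScheduler,
        ScoreSdeVePipeline,
        ScoreSdeVeScheduler,
        UNetUnconditionalModel,
    )

    # importing without error is the whole test here
    print(UNetUnconditionalModel, ScoreSdeVeScheduler, ScoreSdeVePipeline)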
|
@@ -16,10 +16,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .unet import UNetModel
|
||||
from .unet_conditional import UNetConditionalModel
|
||||
from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
|
||||
from .unet_ldm import UNetLDMModel
|
||||
from .unet_sde_score_estimation import NCSNpp
|
||||
from .unet_unconditional import UNetUnconditionalModel
|
||||
from .vae import AutoencoderKL, VQModel
|
||||
|
|
|
@@ -54,6 +54,43 @@ def get_timestep_embedding(
|
|||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, channel, time_embed_dim, act_fn="silu"):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(channel, time_embed_dim)
|
||||
self.act = None
|
||||
if act_fn == "silu":
|
||||
self.act = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
)
|
||||
return t_emb
|
||||
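The two added modules are typically chained: `Timesteps` produces the sinusoidal projection and `TimestepEmbedding` maps it into the model's embedding width. A small usage sketch (the channel sizes below are illustrative, not taken from this diff):

    import torch

    time_proj = Timesteps(num_channels=128, flip_sin_to_cos=True, downscale_freq_shift=0)
    time_embedding = TimestepEmbedding(channel=128, time_embed_dim=512)

    t = torch.tensor([1, 250, 999])
    emb = time_embedding(time_proj(t))  # sinusoidal features (3, 128) -> embedding (3, 512)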
|
||||
|
||||
class GaussianFourierProjection(nn.Module):
|
||||
"""Gaussian Fourier embeddings for noise levels."""
|
||||
|
||||
|
|
|
@@ -1,195 +0,0 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
# helpers functions
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2D
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class UNetModel(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=(1, 1, 2, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
attn_resolutions=(16,),
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels=3,
|
||||
resolution=256,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_to_config(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolutions,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
)
|
||||
ch_mult = tuple(ch_mult)
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch * 4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(self.ch, self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
||||
]
|
||||
)
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock2D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
|
||||
self.mid.block_2 = ResnetBlock2D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
||||
self.mid_new.resnets[0] = self.mid.block_1
|
||||
self.mid_new.attentions[0] = self.mid.attn_1
|
||||
self.mid_new.resnets[1] = self.mid.block_2
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
skip_in = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch * in_ch_mult[i_level]
|
||||
block.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=block_in + skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, sample, timesteps):
|
||||
x = sample
|
||||
assert x.shape[2] == x.shape[3] == self.resolution
|
||||
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
|
||||
|
||||
# timestep embedding
|
||||
temb = get_timestep_embedding(timesteps, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = self.mid_new(hs[-1], temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
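
A minimal smoke-test sketch for the DDPM-style UNetModel removed above. The tiny channel counts and resolution are illustrative only so it can run on CPU, and it assumes the supporting blocks imported at the top of the file (ResnetBlock2D, Downsample2D, Upsample2D, UNetMidBlock2D) behave as elsewhere in the codebase:

    import torch

    model = UNetModel(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=(), in_channels=3, resolution=32)
    x = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        out = model(x, timesteps=torch.tensor([10]))
    print(out.shape)  # expected: torch.Size([1, 3, 32, 32])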
|
|
@@ -1,74 +1,12 @@
|
|||
import functools
|
||||
import math
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock, SpatialTransformer
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
|
||||
|
||||
|
||||
class Combine(nn.Module):
|
||||
"""Combine information from skip connections."""
|
||||
|
||||
def __init__(self, dim1, dim2, method="cat"):
|
||||
super().__init__()
|
||||
# 1x1 convolution with DDPM initialization.
|
||||
self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
|
||||
self.method = method
|
||||
|
||||
|
||||
# def forward(self, x, y):
|
||||
# h = self.Conv_0(x)
|
||||
# if self.method == "cat":
|
||||
# return torch.cat([h, y], dim=1)
|
||||
# elif self.method == "sum":
|
||||
# return h + y
|
||||
# else:
|
||||
# raise ValueError(f"Method {self.method} not recognized.")
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, channel, time_embed_dim, act_fn="silu"):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(channel, time_embed_dim)
|
||||
self.act = None
|
||||
if act_fn == "silu":
|
||||
self.act = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
)
|
||||
return t_emb
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
|
||||
|
||||
|
||||
class UNetConditionalModel(ModelMixin, ConfigMixin):
|
||||
|
@@ -124,38 +62,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
|
|||
downscale_freq_shift=0,
|
||||
mid_block_scale_factor=1,
|
||||
center_input_sample=False,
|
||||
# TODO(PVP) - to delete later at release
|
||||
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
|
||||
# ======================================
|
||||
# LDM
|
||||
attention_resolutions=(4, 2, 1),
|
||||
# DDPM
|
||||
out_ch=None,
|
||||
resolution=None,
|
||||
attn_resolutions=None,
|
||||
resamp_with_conv=None,
|
||||
ch_mult=None,
|
||||
ch=None,
|
||||
ddpm=False,
|
||||
# SDE
|
||||
sde=False,
|
||||
nf=None,
|
||||
fir=None,
|
||||
progressive=None,
|
||||
progressive_combine=None,
|
||||
scale_by_sigma=None,
|
||||
skip_rescale=None,
|
||||
num_channels=None,
|
||||
centered=False,
|
||||
conditional=True,
|
||||
conv_size=3,
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
fourier_scale=16,
|
||||
init_scale=0.0,
|
||||
progressive_input="input_skip",
|
||||
resnet_num_groups=32,
|
||||
continuous=True,
|
||||
ldm=False,
|
||||
resnet_num_groups=30,
|
||||
):
|
||||
super().__init__()
|
||||
# register all __init__ params to be accessible via `self.config.<...>`
|
||||
|
@@ -175,21 +82,13 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
|
|||
num_head_channels=num_head_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
downscale_freq_shift=downscale_freq_shift,
|
||||
attention_resolutions=attention_resolutions,
|
||||
attn_resolutions=attn_resolutions,
|
||||
mid_block_scale_factor=mid_block_scale_factor,
|
||||
resnet_num_groups=resnet_num_groups,
|
||||
center_input_sample=center_input_sample,
|
||||
)
|
||||
|
||||
self.ldm = ldm
|
||||
|
||||
# TODO(PVP) - to delete later at release
|
||||
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
|
||||
# ======================================
|
||||
self.image_size = image_size
|
||||
time_embed_dim = block_channels[0] * 4
|
||||
# ======================================
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
@@ -264,57 +163,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
|
|||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=resnet_num_groups, eps=resnet_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
# ======================== Out ====================
|
||||
|
||||
# =========== TO DELETE AFTER CONVERSION ==========
|
||||
# TODO(PVP) - to delete later at release
|
||||
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
|
||||
# ======================================
|
||||
self.is_overwritten = False
|
||||
if ldm:
|
||||
num_heads = 8
|
||||
num_head_channels = -1
|
||||
transformer_depth = 1
|
||||
use_spatial_transformer = True
|
||||
context_dim = 1280
|
||||
legacy = False
|
||||
model_channels = block_channels[0]
|
||||
channel_mult = tuple([x // model_channels for x in block_channels])
|
||||
self.init_for_ldm(
|
||||
in_channels,
|
||||
model_channels,
|
||||
channel_mult,
|
||||
num_res_blocks,
|
||||
dropout,
|
||||
time_embed_dim,
|
||||
attention_resolutions,
|
||||
num_head_channels,
|
||||
num_heads,
|
||||
legacy,
|
||||
False,
|
||||
transformer_depth,
|
||||
context_dim,
|
||||
conv_resample,
|
||||
out_channels,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
) -> Dict[str, torch.FloatTensor]:
|
||||
# TODO(PVP) - to delete later at release
|
||||
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
|
||||
# ======================================
|
||||
if not self.is_overwritten:
|
||||
self.set_weights()
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
|
@@ -329,7 +189,6 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
|
|||
emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
|
@@ -349,7 +208,6 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
|
|||
sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
# 5. up
|
||||
skip_sample = None
|
||||
for upsample_block in self.upsample_blocks:
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
|
@@ -374,259 +232,3 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
|
|||
output = {"sample": sample}
|
||||
|
||||
return output
|
||||
|
||||
# !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
|
||||
# =================================================================================================================================================
|
||||
|
||||
def set_weights(self):
|
||||
self.is_overwritten = True
|
||||
if self.ldm:
|
||||
self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
|
||||
self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
|
||||
self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
|
||||
self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data
|
||||
|
||||
self.conv_in.weight.data = self.input_blocks[0][0].weight.data
|
||||
self.conv_in.bias.data = self.input_blocks[0][0].bias.data
|
||||
|
||||
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
|
||||
for i, input_layer in enumerate(self.input_blocks[1:]):
|
||||
block_id = i // (self.config.num_res_blocks + 1)
|
||||
layer_in_block_id = i % (self.config.num_res_blocks + 1)
|
||||
|
||||
if layer_in_block_id == 2:
|
||||
self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
|
||||
self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
|
||||
elif len(input_layer) > 1:
|
||||
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
else:
|
||||
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
|
||||
self.mid.resnets[0].set_weight(self.middle_block[0])
|
||||
self.mid.resnets[1].set_weight(self.middle_block[2])
|
||||
self.mid.attentions[0].set_weight(self.middle_block[1])
|
||||
|
||||
for i, input_layer in enumerate(self.output_blocks):
|
||||
block_id = i // (self.config.num_res_blocks + 1)
|
||||
layer_in_block_id = i % (self.config.num_res_blocks + 1)
|
||||
|
||||
if len(input_layer) > 2:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
|
||||
elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
|
||||
elif len(input_layer) > 1:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
else:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
|
||||
self.conv_norm_out.weight.data = self.out[0].weight.data
|
||||
self.conv_norm_out.bias.data = self.out[0].bias.data
|
||||
self.conv_out.weight.data = self.out[2].weight.data
|
||||
self.conv_out.bias.data = self.out[2].bias.data
|
||||
|
||||
self.remove_ldm()
|
||||
|
||||
def init_for_ldm(
|
||||
self,
|
||||
in_channels,
|
||||
model_channels,
|
||||
channel_mult,
|
||||
num_res_blocks,
|
||||
dropout,
|
||||
time_embed_dim,
|
||||
attention_resolutions,
|
||||
num_head_channels,
|
||||
num_heads,
|
||||
legacy,
|
||||
use_spatial_transformer,
|
||||
transformer_depth,
|
||||
context_dim,
|
||||
conv_resample,
|
||||
out_channels,
|
||||
):
|
||||
# TODO(PVP) - delete after weight conversion
|
||||
class TimestepEmbedSequential(nn.Sequential):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
# TODO(PVP) - delete after weight conversion
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
dims = 2
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
|
||||
)
|
||||
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=mult * model_channels,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = num_head_channels
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
),
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = num_head_channels
|
||||
|
||||
if dim_head < 0:
|
||||
dim_head = None
|
||||
|
||||
# TODO(Patrick) - delete after weight conversion
|
||||
# init to be able to overwrite `self.mid`
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=None,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
),
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=None,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ch + ich,
|
||||
out_channels=model_channels * mult,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = num_head_channels
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(model_channels, out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
def remove_ldm(self):
|
||||
del self.time_embed
|
||||
del self.input_blocks
|
||||
del self.middle_block
|
||||
del self.output_blocks
|
||||
del self.out
|
||||
|
|
|
@@ -1,554 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2D
|
||||
|
||||
|
||||
def convert_module_to_f16(l):
|
||||
"""
|
||||
Convert primitive modules to float16.
|
||||
"""
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||
l.weight.data = l.weight.data.half()
|
||||
if l.bias is not None:
|
||||
l.bias.data = l.bias.data.half()
|
||||
|
||||
|
||||
def convert_module_to_f32(l):
|
||||
"""
|
||||
Convert primitive modules to float32, undoing convert_module_to_f16().
|
||||
"""
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||
l.weight.data = l.weight.data.float()
|
||||
if l.bias is not None:
|
||||
l.bias.data = l.bias.data.float()
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
|
||||
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
|
||||
self.swish = swish
|
||||
|
||||
def forward(self, x):
|
||||
y = super().forward(x.float()).to(x.dtype)
|
||||
if self.swish == 1.0:
|
||||
y = F.silu(y)
|
||||
elif self.swish:
|
||||
y = y * F.sigmoid(y * float(self.swish))
|
||||
return y
|
||||
|
||||
|
||||
def normalization(channels, swish=0.0):
|
||||
"""
|
||||
Make a standard normalization layer, with an optional swish activation.
|
||||
|
||||
:param channels: number of input channels.
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, encoder_out=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, ResnetBlock2D) or isinstance(layer, TimestepEmbedSequential):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, AttentionBlock):
|
||||
x = layer(x, encoder_out)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
|
||||
:param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param
|
||||
out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
|
||||
downsampling, attention will be used.
|
||||
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
|
||||
conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
|
||||
model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
|
||||
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
|
||||
for up/downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
resolution=64,
|
||||
model_channels=192,
|
||||
out_channels=6,
|
||||
num_res_blocks=3,
|
||||
attention_resolutions=(2, 4, 8),
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
transformer_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.resolution = resolution
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.use_checkpoint = use_checkpoint
|
||||
# self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
ch = input_ch = int(channel_mult[0] * model_channels)
|
||||
self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
|
||||
self._feature_size = ch
|
||||
input_block_chans = [ch]
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=mult * model_channels,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
)
|
||||
]
|
||||
ch = int(mult * model_channels)
|
||||
if ds in attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=out_ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
self.mid = UNetMidBlock2D(
|
||||
in_channels=ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=1e-5,
|
||||
resnet_act_fn="silu",
|
||||
resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
|
||||
attn_num_heads=num_heads,
|
||||
attn_num_head_channels=num_head_channels,
|
||||
attn_encoder_channels=transformer_dim,
|
||||
)
|
||||
|
||||
# TODO(Patrick) - delete after weight conversion
|
||||
# init to be able to overwrite `self.mid`
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
),
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
),
|
||||
)
|
||||
self.mid.resnets[0] = self.middle_block[0]
|
||||
self.mid.attentions[0] = self.middle_block[1]
|
||||
self.mid.resnets[1] = self.middle_block[2]
|
||||
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ch + ich,
|
||||
out_channels=model_channels * mult,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
),
|
||||
]
|
||||
ch = int(model_channels * mult)
|
||||
if ds in attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=out_ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
|
||||
overwrite_for_glide=True,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch, swish=1.0),
|
||||
nn.Identity(),
|
||||
zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
|
||||
)
|
||||
self.use_fp16 = use_fp16
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
Convert the torso of the model to float16.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
self.output_blocks.apply(convert_module_to_f16)
|
||||
|
||||
def convert_to_fp32(self):
|
||||
"""
|
||||
Convert the torso of the model to float32.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f32)
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
self.output_blocks.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
|
||||
hs = []
|
||||
emb = self.time_embed(
|
||||
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
hs.append(h)
|
||||
h = self.mid(h, emb)
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb)
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class GlideTextToImageUNetModel(GlideUNetModel):
|
||||
"""
|
||||
A UNetModel that generates images conditioned on text-encoder outputs.

Expects an extra argument `transformer_out` carrying the token states from the text encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
resolution=64,
|
||||
model_channels=192,
|
||||
out_channels=6,
|
||||
num_res_blocks=3,
|
||||
attention_resolutions=(2, 4, 8),
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
transformer_dim=512,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
model_channels=model_channels,
|
||||
out_channels=out_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attention_resolutions=attention_resolutions,
|
||||
dropout=dropout,
|
||||
channel_mult=channel_mult,
|
||||
conv_resample=conv_resample,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_fp16=use_fp16,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
num_heads_upsample=num_heads_upsample,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
resblock_updown=resblock_updown,
|
||||
transformer_dim=transformer_dim,
|
||||
)
|
||||
self.register_to_config(
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
model_channels=model_channels,
|
||||
out_channels=out_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attention_resolutions=attention_resolutions,
|
||||
dropout=dropout,
|
||||
channel_mult=channel_mult,
|
||||
conv_resample=conv_resample,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_fp16=use_fp16,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
num_heads_upsample=num_heads_upsample,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
resblock_updown=resblock_updown,
|
||||
transformer_dim=transformer_dim,
|
||||
)
|
||||
|
||||
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
|
||||
|
||||
def forward(self, sample, timestep, transformer_out=None):
|
||||
timesteps = timestep
|
||||
x = sample
|
||||
hs = []
|
||||
emb = self.time_embed(
|
||||
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
)
|
||||
|
||||
# project the last token
|
||||
transformer_proj = self.transformer_proj(transformer_out[:, -1])
|
||||
transformer_out = transformer_out.permute(0, 2, 1) # NLC -> NCL
|
||||
|
||||
emb = emb + transformer_proj.to(emb)
|
||||
|
||||
h = x
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, transformer_out)
|
||||
hs.append(h)
|
||||
h = self.mid(h, emb, transformer_out)
|
||||
for module in self.output_blocks:
|
||||
other = hs.pop()
|
||||
h = torch.cat([h, other], dim=1)
|
||||
h = module(h, emb, transformer_out)
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class GlideSuperResUNetModel(GlideUNetModel):
|
||||
"""
|
||||
A UNetModel that performs super-resolution.
|
||||
|
||||
Expects an extra kwarg `low_res` to condition on a low-resolution image.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
resolution=256,
|
||||
model_channels=192,
|
||||
out_channels=6,
|
||||
num_res_blocks=3,
|
||||
attention_resolutions=(2, 4, 8),
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
model_channels=model_channels,
|
||||
out_channels=out_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attention_resolutions=attention_resolutions,
|
||||
dropout=dropout,
|
||||
channel_mult=channel_mult,
|
||||
conv_resample=conv_resample,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_fp16=use_fp16,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
num_heads_upsample=num_heads_upsample,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
resblock_updown=resblock_updown,
|
||||
)
|
||||
self.register_to_config(
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
model_channels=model_channels,
|
||||
out_channels=out_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attention_resolutions=attention_resolutions,
|
||||
dropout=dropout,
|
||||
channel_mult=channel_mult,
|
||||
conv_resample=conv_resample,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_fp16=use_fp16,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
num_heads_upsample=num_heads_upsample,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
resblock_updown=resblock_updown,
|
||||
)
|
||||
|
||||
def forward(self, sample, timestep, low_res=None):
|
||||
timesteps = timestep
|
||||
x = sample
|
||||
_, _, new_height, new_width = x.shape
|
||||
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
|
||||
x = torch.cat([x, upsampled], dim=1)
|
||||
|
||||
hs = []
|
||||
emb = self.time_embed(
|
||||
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
)
|
||||
|
||||
h = x
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
hs.append(h)
|
||||
h = self.mid(h, emb)
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb)
|
||||
|
||||
return self.out(h)
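
# Editor's sketch (not part of the original file): the super-resolution UNet upsamples the
# `low_res` conditioning image to the target size and concatenates it to the noisy sample
# along the channel dim, so `in_channels` must count both (an assumption here: 3 + 3 = 6).
def _example_super_res_unet_call():
    model = GlideSuperResUNetModel(in_channels=6, resolution=256)
    sample = torch.randn(1, 3, 256, 256)   # noisy high-resolution sample x_t
    low_res = torch.randn(1, 3, 64, 64)    # low-resolution image to condition on
    timestep = torch.tensor([10])
    return model(sample, timestep, low_res=low_res)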
|
|
@@ -1,627 +0,0 @@
|
|||
import math
|
||||
from inspect import isfunction
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2D
|
||||
|
||||
|
||||
# from .resnet import ResBlock
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return {el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
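
# Editor's note (sketch): GEGLU projects to twice the output width, then splits the result
# into a value half and a gate half; the value is modulated by GELU(gate).
def _example_geglu():
    layer = GEGLU(dim_in=8, dim_out=16)
    x = torch.randn(2, 8)
    return layer(x)  # -> (2, 16)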
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
def convert_module_to_f16(l):
|
||||
"""
|
||||
Convert primitive modules to float16.
|
||||
"""
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||
l.weight.data = l.weight.data.half()
|
||||
if l.bias is not None:
|
||||
l.bias.data = l.bias.data.half()
|
||||
|
||||
|
||||
def convert_module_to_f32(l):
|
||||
"""
|
||||
Convert primitive modules to float32, undoing convert_module_to_f16().
|
||||
"""
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||
l.weight.data = l.weight.data.float()
|
||||
if l.bias is not None:
|
||||
l.bias.data = l.bias.data.float()
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
|
||||
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
|
||||
self.swish = swish
|
||||
|
||||
def forward(self, x):
|
||||
y = super().forward(x.float()).to(x.dtype)
|
||||
if self.swish == 1.0:
|
||||
y = F.silu(y)
|
||||
elif self.swish:
|
||||
y = y * F.sigmoid(y * float(self.swish))
|
||||
return y
|
||||
|
||||
|
||||
def normalization(channels, swish=0.0):
|
||||
"""
|
||||
Make a standard normalization layer, with an optional swish activation.
|
||||
|
||||
:param channels: number of input channels. :return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
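
# Editor's sketch: `normalization(channels, swish=1.0)` yields a 32-group GroupNorm whose
# output is passed through SiLU, as implemented by GroupNorm32 above.
def _example_normalization():
    norm = normalization(channels=64, swish=1.0)
    return norm(torch.randn(1, 64, 8, 8))  # same shape, normalized and activated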
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, context=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, ResnetBlock2D) or isinstance(layer, TimestepEmbedSequential):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
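
# Editor's sketch: TimestepEmbedSequential dispatches per layer type -- ResnetBlock2D layers
# receive (x, emb), SpatialTransformer layers receive (x, context), everything else just x.
# The toy block below contains only "plain" layers, so emb is accepted but unused.
def _example_timestep_embed_sequential():
    block = TimestepEmbedSequential(nn.Conv2d(3, 8, 3, padding=1), nn.SiLU())
    x = torch.randn(1, 3, 32, 32)
    emb = torch.randn(1, 16)
    return block(x, emb)  # -> (1, 8, 32, 32)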
|
||||
|
||||
|
||||
def count_flops_attn(model, _x, y):
|
||||
"""
|
||||
A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
|
||||
macs, params = thop.profile(
|
||||
model, inputs=(inputs, timestamps), custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||
)
|
||||
"""
|
||||
b, c, *spatial = y[0].shape
|
||||
num_spatial = int(np.prod(spatial))
|
||||
# We perform two matmuls with the same number of ops.
|
||||
# The first computes the weight matrix, the second computes
|
||||
# the combination of the value vectors.
|
||||
matmul_ops = 2 * b * (num_spatial**2) * c
|
||||
model.total_ops += torch.DoubleTensor([matmul_ops])
|
||||
|
||||
|
||||
class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
|
||||
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
|
||||
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
|
||||
rates at which
|
||||
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
|
||||
downsampling, attention will be used.
|
||||
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
|
||||
conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
|
||||
model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
|
||||
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
|
||||
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
|
||||
increased efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# register all __init__ params with self.register
|
||||
self.register_to_config(
|
||||
image_size=image_size,
|
||||
in_channels=in_channels,
|
||||
model_channels=model_channels,
|
||||
out_channels=out_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attention_resolutions=attention_resolutions,
|
||||
dropout=dropout,
|
||||
channel_mult=channel_mult,
|
||||
conv_resample=conv_resample,
|
||||
dims=dims,
|
||||
num_classes=num_classes,
|
||||
use_fp16=use_fp16,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
num_heads_upsample=num_heads_upsample,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
resblock_updown=resblock_updown,
|
||||
use_spatial_transformer=use_spatial_transformer,
|
||||
transformer_depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
n_embed=n_embed,
|
||||
legacy=legacy,
|
||||
)
|
||||
|
||||
if use_spatial_transformer:
|
||||
assert (
|
||||
context_dim is not None
|
||||
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
|
||||
|
||||
if context_dim is not None:
|
||||
assert (
|
||||
use_spatial_transformer
|
||||
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
|
||||
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.dtype_ = torch.float16 if use_fp16 else torch.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
|
||||
)
|
||||
|
||||
self.down_in_conv = self.input_blocks[0][0]
|
||||
self.downsample_blocks = nn.ModuleList([])
|
||||
self.upsample_blocks = nn.ModuleList([])
|
||||
|
||||
# ========================= Down (OLD) =================== #
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=mult * model_channels,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
input_channels = [model_channels * mult for mult in [1] + list(channel_mult[:-1])]
|
||||
output_channels = [model_channels * mult for mult in channel_mult]
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
|
||||
if dim_head < 0:
|
||||
dim_head = None
|
||||
|
||||
# ========================= MID (New) =================== #
|
||||
self.mid = UNetMidBlock2D(
|
||||
in_channels=ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=1e-5,
|
||||
resnet_act_fn="silu",
|
||||
resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
|
||||
attention_layer_type="self" if not use_spatial_transformer else "spatial",
|
||||
attn_num_heads=num_heads,
|
||||
attn_num_head_channels=dim_head,
|
||||
attn_depth=transformer_depth,
|
||||
attn_encoder_channels=context_dim,
|
||||
)
|
||||
|
||||
# TODO(Patrick) - delete after weight conversion
|
||||
# init to be able to overwrite `self.mid`
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=None,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=None,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
)
|
||||
self.mid.resnets[0] = self.middle_block[0]
|
||||
self.mid.attentions[0] = self.middle_block[1]
|
||||
self.mid.resnets[1] = self.middle_block[2]
|
||||
|
||||
self._feature_size += ch
|
||||
|
||||
# ========================= Up (Old) =================== #
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ch + ich,
|
||||
out_channels=model_channels * mult,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=dim_head,
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
|
||||
of timesteps. :param context: conditioning plugged in via crossattn :param y: an [N] Tensor of labels, if
|
||||
class-conditional. :return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
|
||||
t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape == (x.shape[0],)
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype_)
|
||||
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
|
||||
h = self.mid(h, emb, context)
|
||||
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
||||
standard transformer action. Finally, reshape to image
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
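
# Editor's sketch: minimal SpatialTransformer call. in_channels equals n_heads * d_head here,
# matching how the block above reshapes with the input channel count; `context` stands in for
# text-encoder hidden states (shapes are illustrative assumptions).
def _example_spatial_transformer():
    st = SpatialTransformer(in_channels=32, n_heads=4, d_head=8, context_dim=64)
    x = torch.randn(2, 32, 16, 16)
    context = torch.randn(2, 77, 64)
    return st(x, context=context)  # -> (2, 32, 16, 16), with a residual to the input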
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
return tensor
|
||||
|
||||
def reshape_batch_dim_to_heads(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
batch_size, sequence_length, dim = x.shape
|
||||
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q = self.reshape_heads_to_batch_dim(q)
|
||||
k = self.reshape_heads_to_batch_dim(k)
|
||||
v = self.reshape_heads_to_batch_dim(v)
|
||||
|
||||
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = mask.reshape(batch_size, -1)
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = mask[:, None, :].repeat(h, 1, 1)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
||||
out = self.reshape_batch_dim_to_heads(out)
|
||||
return self.to_out(out)
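
# Editor's sketch: minimal CrossAttention call. With `context` given it attends from image
# tokens to context tokens; without it, it degenerates to self-attention. Shapes are
# illustrative assumptions.
def _example_cross_attention():
    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
    x = torch.randn(2, 64, 320)        # (batch, query tokens, query_dim)
    context = torch.randn(2, 77, 768)  # (batch, context tokens, context_dim)
    return attn(x, context=context)    # -> (2, 64, 320)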
|
|
@@ -1,471 +0,0 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
# helpers functions
|
||||
|
||||
import functools
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2D
|
||||
|
||||
|
||||
class Combine(nn.Module):
|
||||
"""Combine information from skip connections."""
|
||||
|
||||
def __init__(self, dim1, dim2, method="cat"):
|
||||
super().__init__()
|
||||
# 1x1 convolution with DDPM initialization.
|
||||
self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
|
||||
self.method = method
|
||||
|
||||
def forward(self, x, y):
|
||||
h = self.Conv_0(x)
|
||||
if self.method == "cat":
|
||||
return torch.cat([h, y], dim=1)
|
||||
elif self.method == "sum":
|
||||
return h + y
|
||||
else:
|
||||
raise ValueError(f"Method {self.method} not recognized.")
|
||||
|
||||
|
||||
class NCSNpp(ModelMixin, ConfigMixin):
|
||||
"""NCSN++ model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size=1024,
|
||||
num_channels=3,
|
||||
centered=False,
|
||||
attn_resolutions=(16,),
|
||||
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
|
||||
conditional=True,
|
||||
conv_size=3,
|
||||
dropout=0.0,
|
||||
embedding_type="fourier",
|
||||
fir=True,
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
fourier_scale=16,
|
||||
init_scale=0.0,
|
||||
nf=16,
|
||||
num_res_blocks=1,
|
||||
progressive="output_skip",
|
||||
progressive_combine="sum",
|
||||
progressive_input="input_skip",
|
||||
resamp_with_conv=True,
|
||||
scale_by_sigma=True,
|
||||
skip_rescale=True,
|
||||
continuous=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_to_config(
|
||||
image_size=image_size,
|
||||
num_channels=num_channels,
|
||||
centered=centered,
|
||||
attn_resolutions=attn_resolutions,
|
||||
ch_mult=ch_mult,
|
||||
conditional=conditional,
|
||||
conv_size=conv_size,
|
||||
dropout=dropout,
|
||||
embedding_type=embedding_type,
|
||||
fir=fir,
|
||||
fir_kernel=fir_kernel,
|
||||
fourier_scale=fourier_scale,
|
||||
init_scale=init_scale,
|
||||
nf=nf,
|
||||
num_res_blocks=num_res_blocks,
|
||||
progressive=progressive,
|
||||
progressive_combine=progressive_combine,
|
||||
progressive_input=progressive_input,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
scale_by_sigma=scale_by_sigma,
|
||||
skip_rescale=skip_rescale,
|
||||
continuous=continuous,
|
||||
)
|
||||
self.act = nn.SiLU()
|
||||
self.nf = nf
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)]
|
||||
|
||||
self.conditional = conditional
|
||||
self.skip_rescale = skip_rescale
|
||||
self.progressive = progressive
|
||||
self.progressive_input = progressive_input
|
||||
self.embedding_type = embedding_type
|
||||
assert progressive in ["none", "output_skip", "residual"]
|
||||
assert progressive_input in ["none", "input_skip", "residual"]
|
||||
assert embedding_type in ["fourier", "positional"]
|
||||
combine_method = progressive_combine.lower()
|
||||
combiner = functools.partial(Combine, method=combine_method)
|
||||
|
||||
modules = []
|
||||
# timestep/noise_level embedding; only for continuous training
|
||||
if embedding_type == "fourier":
|
||||
# Gaussian Fourier features embeddings.
|
||||
modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale))
|
||||
embed_dim = 2 * nf
|
||||
|
||||
elif embedding_type == "positional":
|
||||
embed_dim = nf
|
||||
|
||||
else:
|
||||
raise ValueError(f"embedding type {embedding_type} unknown.")
|
||||
|
||||
modules.append(nn.Linear(embed_dim, nf * 4))
|
||||
modules.append(nn.Linear(nf * 4, nf * 4))
|
||||
|
||||
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
|
||||
|
||||
if self.fir:
|
||||
Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
|
||||
else:
|
||||
Up_sample = functools.partial(Upsample2D, name="Conv2d_0")
|
||||
|
||||
if progressive == "output_skip":
|
||||
self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
|
||||
elif progressive == "residual":
|
||||
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
|
||||
|
||||
if self.fir:
|
||||
Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
|
||||
else:
|
||||
Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")
|
||||
|
||||
if progressive_input == "input_skip":
|
||||
self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
|
||||
elif progressive_input == "residual":
|
||||
pyramid_downsample = functools.partial(Down_sample, use_conv=True)
|
||||
|
||||
channels = num_channels
|
||||
if progressive_input != "none":
|
||||
input_pyramid_ch = channels
|
||||
|
||||
modules.append(nn.Conv2d(channels, nf, kernel_size=3, padding=1))
|
||||
hs_c = [nf]
|
||||
|
||||
in_ch = nf
|
||||
for i_level in range(self.num_resolutions):
|
||||
# Residual blocks for this resolution
|
||||
for i_block in range(num_res_blocks):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
modules.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
in_ch = out_ch
|
||||
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
hs_c.append(in_ch)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
modules.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
down=True,
|
||||
kernel="fir" if self.fir else "sde_vp",
|
||||
use_nin_shortcut=True,
|
||||
)
|
||||
)
|
||||
|
||||
if progressive_input == "input_skip":
|
||||
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
|
||||
if combine_method == "cat":
|
||||
in_ch *= 2
|
||||
|
||||
elif progressive_input == "residual":
|
||||
modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
|
||||
input_pyramid_ch = in_ch
|
||||
|
||||
hs_c.append(in_ch)
|
||||
|
||||
# mid
|
||||
self.mid = UNetMidBlock2D(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=math.sqrt(2.0),
|
||||
resnet_act_fn="silu",
|
||||
resnet_groups=min(in_ch // 4, 32),
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
in_ch = hs_c[-1]
|
||||
modules.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
modules.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
# self.mid.resnets[0] = modules[len(modules) - 3]
|
||||
# self.mid.attentions[0] = modules[len(modules) - 2]
|
||||
# self.mid.resnets[1] = modules[len(modules) - 1]
|
||||
|
||||
pyramid_ch = 0
|
||||
# Upsampling block
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(num_res_blocks + 1):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
in_ch = in_ch + hs_c.pop()
|
||||
modules.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
in_ch = out_ch
|
||||
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
|
||||
if progressive != "none":
|
||||
if i_level == self.num_resolutions - 1:
|
||||
if progressive == "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
|
||||
pyramid_ch = channels
|
||||
# elif progressive == "residual":
|
||||
# modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
# modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
|
||||
# pyramid_ch = in_ch
|
||||
# else:
|
||||
# raise ValueError(f"{progressive} is not a valid name.")
|
||||
else:
|
||||
if progressive == "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1))
|
||||
pyramid_ch = channels
|
||||
# elif progressive == "residual":
|
||||
# modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
|
||||
# pyramid_ch = in_ch
|
||||
# else:
|
||||
# raise ValueError(f"{progressive} is not a valid name")
|
||||
|
||||
if i_level != 0:
|
||||
modules.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
up=True,
|
||||
kernel="fir" if self.fir else "sde_vp",
|
||||
use_nin_shortcut=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert not hs_c
|
||||
|
||||
if progressive != "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
|
||||
|
||||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
def forward(self, sample, timestep, sigmas=None):
|
||||
timesteps = timestep
|
||||
x = sample
|
||||
# timestep/noise_level embedding; only for continuous training
|
||||
modules = self.all_modules
|
||||
m_idx = 0
|
||||
if self.embedding_type == "fourier":
|
||||
# Gaussian Fourier features embeddings.
|
||||
used_sigmas = timesteps
|
||||
temb = modules[m_idx](used_sigmas)
|
||||
m_idx += 1
|
||||
|
||||
elif self.embedding_type == "positional":
|
||||
# Sinusoidal positional embeddings.
|
||||
timesteps = timesteps
|
||||
used_sigmas = sigmas
|
||||
temb = get_timestep_embedding(timesteps, self.nf)
|
||||
|
||||
else:
|
||||
raise ValueError(f"embedding type {self.embedding_type} unknown.")
|
||||
|
||||
if self.conditional:
|
||||
temb = modules[m_idx](temb)
|
||||
m_idx += 1
|
||||
temb = modules[m_idx](self.act(temb))
|
||||
m_idx += 1
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# If input data is in [0, 1]
|
||||
if not self.config.centered:
|
||||
x = 2 * x - 1.0
|
||||
|
||||
# Downsampling block
|
||||
input_pyramid = None
|
||||
if self.progressive_input != "none":
|
||||
input_pyramid = x
|
||||
|
||||
hs = [modules[m_idx](x)]
|
||||
m_idx += 1
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
# Residual blocks for this resolution
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = modules[m_idx](hs[-1], temb)
|
||||
m_idx += 1
|
||||
if h.shape[-1] in self.attn_resolutions:
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
|
||||
hs.append(h)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = modules[m_idx](hs[-1], temb)
|
||||
m_idx += 1
|
||||
|
||||
if self.progressive_input == "input_skip":
|
||||
input_pyramid = self.pyramid_downsample(input_pyramid)
|
||||
h = modules[m_idx](input_pyramid, h)
|
||||
m_idx += 1
|
||||
|
||||
elif self.progressive_input == "residual":
|
||||
input_pyramid = modules[m_idx](input_pyramid)
|
||||
m_idx += 1
|
||||
if self.skip_rescale:
|
||||
input_pyramid = (input_pyramid + h) / np.sqrt(2.0)
|
||||
else:
|
||||
input_pyramid = input_pyramid + h
|
||||
h = input_pyramid
|
||||
|
||||
hs.append(h)
|
||||
|
||||
h = hs[-1]
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
|
||||
pyramid = None
|
||||
|
||||
# Upsampling block
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
|
||||
m_idx += 1
|
||||
|
||||
if h.shape[-1] in self.attn_resolutions:
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
|
||||
if self.progressive != "none":
|
||||
if i_level == self.num_resolutions - 1:
|
||||
if self.progressive == "output_skip":
|
||||
pyramid = self.act(modules[m_idx](h))
|
||||
m_idx += 1
|
||||
pyramid = modules[m_idx](pyramid)
|
||||
m_idx += 1
|
||||
# elif self.progressive == "residual":
|
||||
# pyramid = self.act(modules[m_idx](h))
|
||||
# m_idx += 1
|
||||
# pyramid = modules[m_idx](pyramid)
|
||||
# m_idx += 1
|
||||
# else:
|
||||
# raise ValueError(f"{self.progressive} is not a valid name.")
|
||||
else:
|
||||
if self.progressive == "output_skip":
|
||||
pyramid_h = self.act(modules[m_idx](h))
|
||||
m_idx += 1
|
||||
pyramid_h = modules[m_idx](pyramid_h)
|
||||
m_idx += 1
|
||||
|
||||
skip_sample = self.pyramid_upsample(pyramid)
|
||||
pyramid = skip_sample + pyramid_h
|
||||
# elif self.progressive == "residual":
|
||||
# pyramid = modules[m_idx](pyramid)
|
||||
# m_idx += 1
|
||||
# if self.skip_rescale:
|
||||
# pyramid = (pyramid + h) / np.sqrt(2.0)
|
||||
# else:
|
||||
# pyramid = pyramid + h
|
||||
# h = pyramid
|
||||
# else:
|
||||
# raise ValueError(f"{self.progressive} is not a valid name")
|
||||
|
||||
if i_level != 0:
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
|
||||
assert not hs
|
||||
|
||||
if self.progressive == "output_skip":
|
||||
h = pyramid
|
||||
else:
|
||||
h = self.act(modules[m_idx](h))
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
|
||||
assert m_idx == len(modules)
|
||||
if self.config.scale_by_sigma:
|
||||
used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
|
||||
h = h / used_sigmas
|
||||
|
||||
return h
|
|
@@ -4,9 +4,7 @@ from .ddpm import DDPMPipeline
|
|||
from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
|
||||
from .pndm import PNDMPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
from .score_sde_vp import ScoreSdeVpPipeline
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from .glide import GlidePipeline
|
||||
from .latent_diffusion import LatentDiffusionPipeline
|
||||
|
|
|
@@ -1,5 +0,0 @@
|
|||
from ...utils import is_transformers_available
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from .pipeline_glide import CLIPTextModel, GlidePipeline
|
|
@@ -1,843 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch CLIP model."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
|
||||
from ...models import GlideSuperResUNetModel, GlideTextToImageUNetModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler, DDPMScheduler
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
#####################
|
||||
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
|
||||
#####################
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "fusing/glide-base"
|
||||
|
||||
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"fusing/glide-base",
|
||||
# See all CLIP models at https://huggingface.co/models?filter=clip
|
||||
]
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
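
# Editor's sketch: _expand_mask turns a [bsz, seq_len] 0/1 padding mask into an additive
# [bsz, 1, tgt_len, src_len] bias where masked positions hold the most negative value of `dtype`.
def _example_expand_mask():
    mask = torch.tensor([[1, 1, 0]])          # last token is padding
    bias = _expand_mask(mask, torch.float32)  # -> shape (1, 1, 3, 3)
    return bias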
|
||||
|
||||
|
||||
# contrastive loss function, adapted from
|
||||
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
|
||||
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
||||
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
|
||||
|
||||
|
||||
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
||||
caption_loss = contrastive_loss(similarity)
|
||||
image_loss = contrastive_loss(similarity.T)
|
||||
return (caption_loss + image_loss) / 2.0
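
# Editor's sketch: clip_loss averages the caption->image and image->caption cross-entropies,
# treating the diagonal of the similarity matrix as the matching pairs.
def _example_clip_loss():
    similarity = torch.randn(4, 4)  # logits for a batch of 4 image-text pairs
    return clip_loss(similarity)    # scalar tensor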
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPOutput(ModelOutput):
|
||||
"""
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
||||
Contrastive loss for image-text similarity.
|
||||
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
||||
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
||||
similarity scores.
|
||||
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
||||
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
||||
similarity scores.
|
||||
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
|
||||
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
||||
text_model_output(`BaseModelOutputWithPooling`):
|
||||
The output of the [`CLIPTextModel`].
|
||||
vision_model_output(`BaseModelOutputWithPooling`):
|
||||
The output of the [`CLIPVisionModel`].
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits_per_image: torch.FloatTensor = None
|
||||
logits_per_text: torch.FloatTensor = None
|
||||
text_embeds: torch.FloatTensor = None
|
||||
image_embeds: torch.FloatTensor = None
|
||||
text_model_output: BaseModelOutputWithPooling = None
|
||||
vision_model_output: BaseModelOutputWithPooling = None
|
||||
|
||||
def to_tuple(self) -> Tuple[Any]:
|
||||
return tuple(
|
||||
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
|
||||
for k in self.keys()
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
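
# Editor's sketch (assumed small config values): with image_size=32 and patch_size=8 the
# vision embedding produces 16 patch tokens plus one class token.
def _example_clip_vision_embeddings():
    config = CLIPVisionConfig(hidden_size=64, image_size=32, patch_size=8)
    embed = CLIPVisionEmbeddings(config)
    pixel_values = torch.randn(1, 3, 32, 32)
    return embed(pixel_values)  # -> (1, 17, 64)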
|
||||
|
||||
|
||||
class CLIPTextEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
||||
self.use_padding_embeddings = config.use_padding_embeddings
|
||||
if self.use_padding_embeddings:
|
||||
self.padding_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
if self.use_padding_embeddings and attention_mask is not None:
|
||||
padding_embeddings = self.padding_embedding(position_ids)
|
||||
embeddings = torch.where(attention_mask.bool().unsqueeze(-1), embeddings, padding_embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = 1 / math.sqrt(math.sqrt(self.head_dim))
|
||||
|
||||
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
|
||||
qkv_states = self.qkv_proj(hidden_states)
|
||||
qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1)
|
||||
query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1)
|
||||
|
||||
attn_weights = torch.einsum("bthc,bshc->bhts", query_states * self.scale, key_states * self.scale)
|
||||
|
||||
wdtype = attn_weights.dtype
|
||||
attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype)
|
||||
|
||||
attn_output = torch.einsum("bhts,bshc->bthc", attn_weights, value_states)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
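
# Editor's sketch (assumed small config): the GLIDE-modified attention above uses a fused
# qkv projection and scales both queries and keys by 1/sqrt(sqrt(head_dim)).
def _example_clip_attention():
    config = CLIPTextConfig(hidden_size=64, num_attention_heads=4)
    attn = CLIPAttention(config)
    hidden_states = torch.randn(2, 7, 64)
    attn_output, attn_weights = attn(hidden_states)
    return attn_output.shape, attn_weights.shape  # (2, 7, 64), (2, 4, 7, 7)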
|
||||
|
||||
|
||||
class CLIPMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
def __init__(self, config: CLIPConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = CLIPAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim)
|
||||
self.mlp = CLIPMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
causal_attention_mask: torch.Tensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
`(config.encoder_attention_heads,)`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class CLIPPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = CLIPConfig
|
||||
base_model_prefix = "clip"
|
||||
supports_gradient_checkpointing = True
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
if isinstance(module, CLIPTextEmbeddings):
|
||||
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
if hasattr(module, "padding_embedding"):
|
||||
module.padding_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
elif isinstance(module, CLIPVisionEmbeddings):
|
||||
factor = self.config.initializer_factor
|
||||
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
|
||||
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
|
||||
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
|
||||
elif isinstance(module, CLIPAttention):
|
||||
factor = self.config.initializer_factor
|
||||
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
||||
out_proj_std = (module.embed_dim**-0.5) * factor
|
||||
nn.init.normal_(module.qkv_proj.weight, std=in_proj_std)
|
||||
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
|
||||
elif isinstance(module, CLIPMLP):
|
||||
factor = self.config.initializer_factor
|
||||
in_proj_std = (
|
||||
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
||||
)
|
||||
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
|
||||
nn.init.normal_(module.fc1.weight, std=fc_std)
|
||||
nn.init.normal_(module.fc2.weight, std=in_proj_std)
|
||||
elif isinstance(module, CLIPModel):
|
||||
nn.init.normal_(
|
||||
module.text_projection.weight,
|
||||
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
||||
)
|
||||
nn.init.normal_(
|
||||
module.visual_projection.weight,
|
||||
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
||||
)
|
||||
|
||||
if isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, CLIPEncoder):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
CLIP_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

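# Illustrative sketch (not part of the original file): what the documented `input_ids` /
# `attention_mask` inputs look like for a padded batch. The checkpoint name matches the
# usage example that appears further down in this file.
#
# from transformers import CLIPTokenizer
#
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
# batch = tokenizer(["a corgi", "an oil painting of a corgi"], padding=True, return_tensors="pt")
# print(batch["input_ids"].shape)   # (batch_size, sequence_length)
# print(batch["attention_mask"])    # 1 for real tokens, 0 for padding positions
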
CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


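# Illustrative sketch (not part of the original file): the checkpointing pattern used in
# `CLIPEncoder.forward` above, applied to a stand-in layer. When gradient checkpointing is
# active, intermediate activations of the wrapped call are freed and recomputed during the
# backward pass, trading compute for memory. The layer and sizes are assumptions.
#
# import torch
# import torch.utils.checkpoint
# from torch import nn
#
# layer = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
# hidden_states = torch.randn(2, 16, 64, requires_grad=True)
#
# def create_custom_forward(module):
#     def custom_forward(*inputs):
#         return module(*inputs)
#
#     return custom_forward
#
# out = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), hidden_states)
# out.sum().backward()  # activations inside `layer` are recomputed here instead of stored

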
class CLIPTextTransformer(nn.Module):
    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask)

        bsz, seq_len = input_shape
        # CLIP's text model normally uses a causal mask; it is not applied here, and the original
        # construction is kept below only for reference.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        # causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        # NOTE: neither the expanded attention mask nor a causal mask is forwarded below; both
        # are passed to the encoder as `None`.
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=None,
            causal_attention_mask=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, bsz, seq_len):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(bsz, seq_len, seq_len)
        mask.fill_(torch.tensor(float("-inf")))
        mask.triu_(1)  # zero out the lower triangle
        mask = mask.unsqueeze(1)  # expand mask
        return mask


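# Illustrative sketch (not part of the original file): the pooling trick used in
# `CLIPTextTransformer.forward` above. `input_ids.argmax(dim=-1)` locates the end-of-text
# token because it has the highest id in CLIP's vocabulary, and the hidden state at that
# position becomes the pooled output. The token ids below are made up; 99 stands in for EOT.
#
# import torch
#
# last_hidden_state = torch.randn(2, 5, 8)                        # (batch, seq, width)
# input_ids = torch.tensor([[1, 7, 99, 0, 0], [1, 3, 4, 7, 99]])  # 99 = "EOT", 0 = padding
# eot_positions = input_ids.argmax(dim=-1)                        # tensor([2, 4])
# pooled = last_hidden_state[torch.arange(last_hidden_state.shape[0]), eot_positions]
# print(pooled.shape)  # torch.Size([2, 8])

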
class CLIPTextModel(CLIPPreTrainedModel):
    config_class = CLIPTextConfig

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CLIPTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import CLIPTokenizer, CLIPTextModel

        >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


#####################
# END OF THE CLIP MODEL COPY-PASTE
#####################


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)


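# Illustrative sketch (not part of the original file): typical use of `_extract_into_tensor`
# above -- pull one scalar per timestep out of a 1-D schedule and broadcast it over an image
# batch. The schedule values and shapes are assumptions chosen for the example.
#
# import numpy as np
# import torch
#
# betas = np.linspace(1e-4, 2e-2, 1000)   # 1-D schedule, one value per trained timestep
# timesteps = torch.tensor([0, 999])      # one timestep index per batch element
# x = torch.randn(2, 3, 64, 64)
#
# coeff = _extract_into_tensor(betas, timesteps, x.shape)
# print(coeff.shape)  # torch.Size([2, 3, 64, 64]); coeff[0] is filled with betas[0], coeff[1] with betas[999]

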
class GlidePipeline(DiffusionPipeline):
    def __init__(
        self,
        text_unet: GlideTextToImageUNetModel,
        text_scheduler: DDPMScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
        upscale_unet: GlideSuperResUNetModel,
        upscale_scheduler: DDIMScheduler,
    ):
        super().__init__()
        self.register_modules(
            text_unet=text_unet,
            text_scheduler=text_scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            upscale_unet=upscale_unet,
            upscale_scheduler=upscale_scheduler,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        generator=None,
        torch_device=None,
        num_inference_steps_upscale=50,
        guidance_scale=3.0,
        eta=0.0,
        upsample_temp=0.997,
    ):
        # only fall back to the default device if the caller did not pass one explicitly
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.text_unet.to(torch_device)
        self.text_encoder.to(torch_device)
        self.upscale_unet.to(torch_device)

        def text_model_fn(x_t, timesteps, transformer_out, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        # 1. Sample gaussian noise
        batch_size = 2  # second image is empty for classifier-free guidance
        image = torch.randn(
            (
                batch_size,
                self.text_unet.in_channels,
                self.text_unet.resolution,
                self.text_unet.resolution,
            ),
            generator=generator,
        ).to(torch_device)

        # 2. Encode tokens
        # an empty input is needed to guide the model away from it
        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch_device)
        attention_mask = inputs["attention_mask"].to(torch_device)
        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

        # 3. Run the text2image generation step
        num_prediction_steps = len(self.text_scheduler)
        for t in tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
            with torch.no_grad():
                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
                model_output = text_model_fn(image, time_input, transformer_out)
                noise_residual, model_var_values = torch.split(model_output, 3, dim=1)

            min_log = self.text_scheduler.get_variance(t, "fixed_small_log")
            max_log = self.text_scheduler.get_variance(t, "fixed_large_log")
            # The model_var_values is [-1, 1] for [min_var, max_var].
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log

            pred_prev_image = self.text_scheduler.step(noise_residual, image, t)
            noise = torch.randn(image.shape, generator=generator).to(torch_device)
            variance = torch.exp(0.5 * model_log_variance) * noise

            # set current image to prev_image: x_t -> x_t-1
            image = pred_prev_image + variance

        # 4. Run the upscaling step
        batch_size = 1
        image = image[:1]
        low_res = ((image + 1) * 127.5).round() / 127.5 - 1

        # Sample gaussian noise to begin loop
        image = torch.randn(
            (
                batch_size,
                self.upscale_unet.in_channels // 2,
                self.upscale_unet.resolution,
                self.upscale_unet.resolution,
            ),
            generator=generator,
        ).to(torch_device)
        image = image * upsample_temp

        num_trained_timesteps = self.upscale_scheduler.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)

        for t in tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
            # 1. predict noise residual
            with torch.no_grad():
                time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                model_output = self.upscale_unet(image, time_input, low_res)
                noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.upscale_scheduler.step(
                noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
            )

            # 3. optionally sample variance
            variance = 0
            if eta > 0:
                noise = torch.randn(image.shape, generator=generator).to(torch_device)
                variance = self.upscale_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise

            # 4. set current image to prev_image: x_t -> x_t-1
            image = pred_prev_image + variance

        image = image.clamp(-1, 1).permute(0, 2, 3, 1)

        return image

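# Illustrative sketch (not part of the original file): the two pieces of arithmetic at the
# heart of `GlidePipeline.__call__` above, on dummy tensors. Classifier-free guidance combines
# the conditional and unconditional noise predictions, and the learned variance channel
# interpolates in log space between the scheduler's "fixed_small" and "fixed_large" variances.
# Shapes and the stand-in log-variance values are assumptions.
#
# import torch
#
# guidance_scale = 3.0
# cond_eps = torch.randn(1, 3, 64, 64)     # epsilon predicted for the text prompt
# uncond_eps = torch.randn(1, 3, 64, 64)   # epsilon predicted for the empty prompt
# guided_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
#
# model_var_values = torch.rand(1, 3, 64, 64) * 2 - 1  # network output in [-1, 1]
# min_log, max_log = -8.0, -2.0                        # stand-ins for the scheduler's log variances
# frac = (model_var_values + 1) / 2
# model_log_variance = frac * max_log + (1 - frac) * min_log
# stddev = torch.exp(0.5 * model_log_variance)         # used to scale the sampled noise
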
@@ -1 +0,0 @@
from .pipeline_score_sde_vp import ScoreSdeVpPipeline

@@ -1,38 +0,0 @@
#!/usr/bin/env python3
import torch

from diffusers import DiffusionPipeline


# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVpPipeline(DiffusionPipeline):
    def __init__(self, model, scheduler):
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)

    def __call__(self, num_inference_steps=1000, generator=None):
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        img_size = self.model.config.image_size
        channels = self.model.config.num_channels
        shape = (1, channels, img_size, img_size)

        model = self.model.to(device)

        # forward the generator so sampling is reproducible
        x = torch.randn(*shape, generator=generator).to(device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            t = t * torch.ones(shape[0], device=device)
            scaled_t = t * (num_inference_steps - 1)

            # TODO add corrector
            with torch.no_grad():
                result = model(x, scaled_t)

            x, x_mean = self.scheduler.step_pred(result, x, t)

        x_mean = (x_mean + 1.0) / 2.0

        return x_mean

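# Illustrative sketch (not part of the original file): the continuous-to-discrete time mapping
# used in `ScoreSdeVpPipeline.__call__` above. The scheduler is assumed to hand out timesteps
# t in (0, 1]; `scaled_t = t * (num_inference_steps - 1)` maps them onto the step-index range
# the score model was trained with. Values below are assumptions for the example.
#
# import torch
#
# num_inference_steps = 1000
# timesteps = torch.linspace(1.0, 1e-3, num_inference_steps)  # stand-in for scheduler.timesteps
# t = timesteps[0] * torch.ones(1)                            # batch of size 1
# scaled_t = t * (num_inference_steps - 1)                    # tensor([999.]) for t = 1.0
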
@@ -255,8 +255,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "ch": 32,
            "ch_mult": (1, 2),
            "block_channels": (32, 64),
            "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),

@@ -264,8 +262,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
            "out_channels": 3,
            "in_channels": 3,
            "num_res_blocks": 2,
            "attn_resolutions": (16,),
            "resolution": 32,
            "image_size": 32,
        }
        inputs_dict = self.dummy_input

@@ -322,13 +318,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
            "in_channels": 4,
            "out_channels": 4,
            "num_res_blocks": 2,
            "attention_resolutions": (16,),
            "block_channels": (32, 64),
            "num_head_channels": 32,
            "conv_resample": True,
            "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
            "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
            "ldm": True,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

@@ -529,8 +523,8 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
            "ch": 64,
            "out_ch": 3,
            "num_res_blocks": 1,
            "attn_resolutions": [],
            "in_channels": 3,
            "attn_resolutions": [],
            "resolution": 32,
            "z_channels": 3,
            "n_embed": 256,

@@ -605,11 +599,11 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
            "ch_mult": (1,),
            "embed_dim": 4,
            "in_channels": 3,
            "attn_resolutions": [],
            "num_res_blocks": 1,
            "out_ch": 3,
            "resolution": 32,
            "z_channels": 4,
            "attn_resolutions": [],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

@@ -655,7 +649,6 @@ class PipelineTesterMixin(unittest.TestCase):
        model = UNetUnconditionalModel(
            block_channels=(32, 64),
            num_res_blocks=2,
            attn_resolutions=(16,),
            image_size=32,
            in_channels=3,
            out_channels=3,
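# Illustrative sketch (not part of the original diff): how the `prepare_init_args_and_inputs_for_common`
# pattern shown in these test hunks is typically consumed by a shared test -- the init dict builds the
# model, the inputs dict feeds a forward pass, and the output shape is checked. `DummyModel` and the
# tiny shapes are assumptions standing in for the real model classes.
#
# import torch
# from torch import nn
#
# class DummyModel(nn.Module):
#     def __init__(self, in_channels=3, out_channels=3):
#         super().__init__()
#         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
#
#     def forward(self, sample):
#         return self.conv(sample)
#
# init_dict = {"in_channels": 3, "out_channels": 3}
# inputs_dict = {"sample": torch.randn(1, 3, 32, 32)}
#
# model = DummyModel(**init_dict)
# output = model(**inputs_dict)
# assert output.shape == inputs_dict["sample"].shape  # common-test style shape check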