Finalize ldm (#96)

* upload

* make checkpoint work

* finalize
This commit is contained in:
Patrick von Platen 2022-07-19 02:02:23 +02:00 committed by GitHub
parent 6cabc599a2
commit d5acb4110a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 999 additions and 135 deletions

View File

@ -7,7 +7,15 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, NCSNpp, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
from .models import (
AutoencoderKL,
NCSNpp,
UNetConditionalModel,
UNetLDMModel,
UNetModel,
UNetUnconditionalModel,
VQModel,
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import (
DDIMPipeline,

View File

@ -17,6 +17,7 @@
# limitations under the License.
from .unet import UNetModel
from .unet_conditional import UNetConditionalModel
from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .unet_ldm import UNetLDMModel
from .unet_sde_score_estimation import NCSNpp

View File

@ -42,7 +42,7 @@ class AttentionBlockNew(nn.Module):
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
self.proj_attn = nn.Linear(channels, channels, 1)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
@ -147,6 +147,8 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@ -160,7 +162,7 @@ class SpatialTransformer(nn.Module):
]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
@ -175,6 +177,12 @@ class SpatialTransformer(nn.Module):
x = self.proj_out(x)
return x + x_in
def set_weight(self, layer):
self.norm = layer.norm
self.proj_in = layer.proj_in
self.transformer_blocks = layer.transformer_blocks
self.proj_out = layer.proj_out
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@ -270,14 +278,15 @@ class FeedForward(nn.Module):
return self.net(x)
# TODO(Patrick) - this can and should be removed
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
@ -298,17 +307,6 @@ def default(val, d):
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
@ -348,7 +346,7 @@ class AttentionBlock(nn.Module):
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
self.proj = nn.Conv1d(channels, channels, 1)
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
@ -370,7 +368,7 @@ class AttentionBlock(nn.Module):
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.proj_out = nn.Conv1d(channels, channels, 1)
self.set_weights(self)
self.is_overwritten = False
@ -385,7 +383,7 @@ class AttentionBlock(nn.Module):
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out = nn.Conv1d(self.channels, self.channels, 1)
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data

View File

@ -0,0 +1,632 @@
import functools
import math
from typing import Dict, Union
import numpy as np
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock, SpatialTransformer
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
class Combine(nn.Module):
"""Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"):
super().__init__()
# 1x1 convolution with DDPM initialization.
self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.method = method
# def forward(self, x, y):
# h = self.Conv_0(x)
# if self.method == "cat":
# return torch.cat([h, y], dim=1)
# elif self.method == "sum":
# return h + y
# else:
# raise ValueError(f"Method {self.method} not recognized.")
class TimestepEmbedding(nn.Module):
def __init__(self, channel, time_embed_dim, act_fn="silu"):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def forward(self, sample):
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class UNetConditionalModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
rates at which
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
downsampling, attention will be used.
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__(
self,
image_size=None,
in_channels=4,
out_channels=4,
num_res_blocks=2,
dropout=0,
block_channels=(320, 640, 1280, 1280),
down_blocks=(
"UNetResCrossAttnDownBlock2D",
"UNetResCrossAttnDownBlock2D",
"UNetResCrossAttnDownBlock2D",
"UNetResDownBlock2D",
),
downsample_padding=1,
up_blocks=(
"UNetResUpBlock2D",
"UNetResCrossAttnUpBlock2D",
"UNetResCrossAttnUpBlock2D",
"UNetResCrossAttnUpBlock2D",
),
resnet_act_fn="silu",
resnet_eps=1e-5,
conv_resample=True,
num_head_channels=8,
flip_sin_to_cos=True,
downscale_freq_shift=0,
mid_block_scale_factor=1,
center_input_sample=False,
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
# LDM
attention_resolutions=(4, 2, 1),
# DDPM
out_ch=None,
resolution=None,
attn_resolutions=None,
resamp_with_conv=None,
ch_mult=None,
ch=None,
ddpm=False,
# SDE
sde=False,
nf=None,
fir=None,
progressive=None,
progressive_combine=None,
scale_by_sigma=None,
skip_rescale=None,
num_channels=None,
centered=False,
conditional=True,
conv_size=3,
fir_kernel=(1, 3, 3, 1),
fourier_scale=16,
init_scale=0.0,
progressive_input="input_skip",
resnet_num_groups=32,
continuous=True,
ldm=False,
):
super().__init__()
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
self.register_to_config(
image_size=image_size,
in_channels=in_channels,
block_channels=block_channels,
downsample_padding=downsample_padding,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
down_blocks=down_blocks,
up_blocks=up_blocks,
dropout=dropout,
resnet_eps=resnet_eps,
conv_resample=conv_resample,
num_head_channels=num_head_channels,
flip_sin_to_cos=flip_sin_to_cos,
downscale_freq_shift=downscale_freq_shift,
attention_resolutions=attention_resolutions,
attn_resolutions=attn_resolutions,
mid_block_scale_factor=mid_block_scale_factor,
resnet_num_groups=resnet_num_groups,
center_input_sample=center_input_sample,
)
self.ldm = ldm
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
self.image_size = image_size
time_embed_dim = block_channels[0] * 4
# ======================================
# input
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
# time
self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
timestep_input_dim = block_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.downsample_blocks = nn.ModuleList([])
self.mid = None
self.upsample_blocks = nn.ModuleList([])
# down
output_channel = block_channels[0]
for i, down_block_type in enumerate(down_blocks):
input_channel = output_channel
output_channel = block_channels[i]
is_final_block = i == len(block_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=num_res_blocks,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
downsample_padding=downsample_padding,
)
self.downsample_blocks.append(down_block)
# mid
self.mid = UNetMidBlock2DCrossAttn(
in_channels=block_channels[-1],
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
resnet_groups=resnet_num_groups,
)
# up
reversed_block_channels = list(reversed(block_channels))
output_channel = reversed_block_channels[0]
for i, up_block_type in enumerate(up_blocks):
prev_output_channel = output_channel
output_channel = reversed_block_channels[i]
input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
is_final_block = i == len(block_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=num_res_blocks + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
)
self.upsample_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
# ======================== Out ====================
# =========== TO DELETE AFTER CONVERSION ==========
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
self.is_overwritten = False
if ldm:
num_heads = 8
num_head_channels = -1
transformer_depth = 1
use_spatial_transformer = True
context_dim = 1280
legacy = False
model_channels = block_channels[0]
channel_mult = tuple([x // model_channels for x in block_channels])
self.init_for_ldm(
in_channels,
model_channels,
channel_mult,
num_res_blocks,
dropout,
time_embed_dim,
attention_resolutions,
num_head_channels,
num_heads,
legacy,
False,
transformer_depth,
context_dim,
conv_resample,
out_channels,
)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
) -> Dict[str, torch.FloatTensor]:
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
# ======================================
if not self.is_overwritten:
self.set_weights()
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_steps(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.downsample_blocks:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. up
skip_sample = None
for upsample_block in self.upsample_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
output = {"sample": sample}
return output
# !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
# =================================================================================================================================================
def set_weights(self):
self.is_overwritten = True
if self.ldm:
self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data
self.conv_in.weight.data = self.input_blocks[0][0].weight.data
self.conv_in.bias.data = self.input_blocks[0][0].bias.data
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
for i, input_layer in enumerate(self.input_blocks[1:]):
block_id = i // (self.config.num_res_blocks + 1)
layer_in_block_id = i % (self.config.num_res_blocks + 1)
if layer_in_block_id == 2:
self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
elif len(input_layer) > 1:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.mid.resnets[0].set_weight(self.middle_block[0])
self.mid.resnets[1].set_weight(self.middle_block[2])
self.mid.attentions[0].set_weight(self.middle_block[1])
for i, input_layer in enumerate(self.output_blocks):
block_id = i // (self.config.num_res_blocks + 1)
layer_in_block_id = i % (self.config.num_res_blocks + 1)
if len(input_layer) > 2:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
elif len(input_layer) > 1:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.conv_norm_out.weight.data = self.out[0].weight.data
self.conv_norm_out.bias.data = self.out[0].bias.data
self.conv_out.weight.data = self.out[2].weight.data
self.conv_out.bias.data = self.out[2].bias.data
self.remove_ldm()
def init_for_ldm(
self,
in_channels,
model_channels,
channel_mult,
num_res_blocks,
dropout,
time_embed_dim,
attention_resolutions,
num_head_channels,
num_heads,
legacy,
use_spatial_transformer,
transformer_depth,
context_dim,
conv_resample,
out_channels,
):
# TODO(PVP) - delete after weight conversion
class TimestepEmbedSequential(nn.Sequential):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
pass
# TODO(PVP) - delete after weight conversion
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
dims = 2
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResnetBlock2D(
in_channels=ch,
out_channels=mult * model_channels,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
),
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
if dim_head < 0:
dim_head = None
# TODO(Patrick) - delete after weight conversion
# init to be able to overwrite `self.mid`
self.middle_block = TimestepEmbedSequential(
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
),
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResnetBlock2D(
in_channels=ch + ich,
out_channels=model_channels * mult,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
nn.SiLU(),
nn.Conv2d(model_channels, out_channels, 3, padding=1),
)
def remove_ldm(self):
del self.time_embed
del self.input_blocks
del self.middle_block
del self.output_blocks
del self.out

View File

@ -17,7 +17,7 @@ import numpy as np
import torch
from torch import nn
from .attention import AttentionBlockNew
from .attention import AttentionBlockNew, SpatialTransformer
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
@ -56,6 +56,18 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResCrossAttnDownBlock2D":
return UNetResCrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResSkipDownBlock2D":
return UNetResSkipDownBlock2D(
num_layers=num_layers,
@ -104,6 +116,18 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "UNetResCrossAttnUpBlock2D":
return UNetResCrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UNetResAttnUpBlock2D":
return UNetResAttnUpBlock2D(
num_layers=num_layers,
@ -221,6 +245,83 @@ class UNetMidBlock2D(nn.Module):
return hidden_states
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
cross_attention_dim=1280,
**kwargs,
):
super().__init__()
self.attention_type = attention_type
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
ResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
SpatialTransformer(
in_channels,
attn_num_head_channels,
in_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
)
)
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class UNetResAttnDownBlock2D(nn.Module):
def __init__(
self,
@ -302,6 +403,88 @@ class UNetResAttnDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResCrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
):
super().__init__()
resnets = []
attentions = []
self.attention_type = attention_type
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class UNetResDownBlock2D(nn.Module):
def __init__(
self,
@ -618,6 +801,86 @@ class UNetResAttnUpBlock2D(nn.Module):
return hidden_states
class UNetResCrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_upsample=True,
):
super().__init__()
resnets = []
attentions = []
self.attention_type = attention_type
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetResUpBlock2D(nn.Module):
def __init__(
self,
@ -765,8 +1028,6 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@ -864,8 +1125,6 @@ class UNetResSkipUpBlock2D(nn.Module):
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]

View File

@ -1,6 +1,5 @@
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
@ -15,6 +14,69 @@ from transformers.utils import logging
from ...pipeline_utils import DiffusionPipeline
class LatentDiffusionPipeline(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
prompt,
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
guidance_scale=1.0,
num_inference_steps=50,
):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device)
self.vqvae.to(torch_device)
self.bert.to(torch_device)
# get unconditional embeddings for classifier free guidence
if guidance_scale != 1.0:
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
torch_device
)
uncond_embeddings = self.bert(uncond_input.input_ids)
# get text embedding
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator,
).to(torch_device)
self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm.tqdm(self.scheduler.timesteps):
# 1. predict noise residual
pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
if isinstance(pred_noise_t, dict):
pred_noise_t = pred_noise_t["sample"]
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
# scale and decode image with vae
image = 1 / 0.18215 * image
image = self.vqvae.decode(image)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
return image
################################################################################
# Code for the text transformer model
################################################################################
@ -542,100 +604,3 @@ class LDMBertModel(LDMBertPreTrainedModel):
)
sequence_output = outputs[0]
return sequence_output
class LatentDiffusionPipeline(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
prompt,
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
guidance_scale=1.0,
num_inference_steps=50,
):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device)
self.vqvae.to(torch_device)
self.bert.to(torch_device)
# get unconditional embeddings for classifier free guidence
if guidance_scale != 1.0:
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
torch_device
)
uncond_embeddings = self.bert(uncond_input.input_ids)
# get text embedding
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)
num_trained_timesteps = self.scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator,
).to(torch_device)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for t in tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# guidance_scale of 1 means no guidance
if guidance_scale == 1.0:
image_in = image
context = text_embedding
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
else:
# for classifier free guidance, we need to do two forward passes
# here we concanate embedding and unconditioned embedding in a single batch
# to avoid doing two forward passes
image_in = torch.cat([image] * 2)
context = torch.cat([uncond_embeddings, text_embedding])
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
# 1. predict noise residual
pred_noise_t = self.unet(image_in, timesteps, context=context)
# perform guidance
if guidance_scale != 1.0:
pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
# 2. predict previous mean of image x_t-1
pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
# scale and decode image with vae
image = 1 / 0.18215 * image
image = self.vqvae.decode(image)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
return image

View File

@ -40,14 +40,17 @@ from diffusers import (
ScoreSdeVeScheduler,
ScoreSdeVpPipeline,
ScoreSdeVpScheduler,
UNetConditionalModel,
UNetLDMModel,
UNetUnconditionalModel,
VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel
from transformers import BertTokenizer
torch.backends.cuda.matmul.allow_tf32 = False
@ -827,7 +830,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
@property
@ -1026,10 +1029,8 @@ class PipelineTesterMixin(unittest.TestCase):
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
@unittest.skip("Skipping for now as it takes too long")
def test_ldm_text2img(self):
model_id = "fusing/latent-diffusion-text2im-large"
ldm = LatentDiffusionPipeline.from_pretrained(model_id)
ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
@ -1043,8 +1044,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ldm_text2img_fast(self):
model_id = "fusing/latent-diffusion-text2im-large"
ldm = LatentDiffusionPipeline.from_pretrained(model_id)
ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
@ -1074,6 +1074,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_score_sde_ve_pipeline(self):
model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")
torch.manual_seed(0)
if torch.cuda.is_available():