parent 6cabc599a2
commit d5acb4110a
@@ -7,7 +7,15 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
 __version__ = "0.0.4"

 from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, NCSNpp, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
+from .models import (
+    AutoencoderKL,
+    NCSNpp,
+    UNetConditionalModel,
+    UNetLDMModel,
+    UNetModel,
+    UNetUnconditionalModel,
+    VQModel,
+)
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import (
     DDIMPipeline,
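
A minimal import check for the widened public API above (illustrative, not part of the commit; assumes this development checkout of diffusers is installed):

    from diffusers import UNetConditionalModel, UNetUnconditionalModel  # both names resolve after this change
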
@@ -17,6 +17,7 @@
 # limitations under the License.

 from .unet import UNetModel
+from .unet_conditional import UNetConditionalModel
 from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
 from .unet_ldm import UNetLDMModel
 from .unet_sde_score_estimation import NCSNpp
@@ -42,7 +42,7 @@ class AttentionBlockNew(nn.Module):
         self.value = nn.Linear(channels, channels)

         self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        self.proj_attn = nn.Linear(channels, channels, 1)

     def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
         new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)

@@ -147,6 +147,8 @@ class SpatialTransformer(nn.Module):

     def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
         super().__init__()
+        self.n_heads = n_heads
+        self.d_head = d_head
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

@@ -160,7 +162,7 @@ class SpatialTransformer(nn.Module):
             ]
         )

-        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention

@@ -175,6 +177,12 @@ class SpatialTransformer(nn.Module):
         x = self.proj_out(x)
         return x + x_in

+    def set_weight(self, layer):
+        self.norm = layer.norm
+        self.proj_in = layer.proj_in
+        self.transformer_blocks = layer.transformer_blocks
+        self.proj_out = layer.proj_out
+

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):

@@ -270,14 +278,15 @@ class FeedForward(nn.Module):
         return self.net(x)


-# TODO(Patrick) - this can and should be removed
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)


 # TODO(Patrick) - remove once all weights have been converted -> not needed anymore then

@@ -298,17 +307,6 @@ def default(val, d):
     return d() if isfunction(d) else d


-# feedforward
-class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
-
-
 # the main attention block that is used for all models
 class AttentionBlock(nn.Module):
     """

@@ -348,7 +346,7 @@ class AttentionBlock(nn.Module):
         if encoder_channels is not None:
             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)

-        self.proj = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = nn.Conv1d(channels, channels, 1)

         self.overwrite_qkv = overwrite_qkv
         self.overwrite_linear = overwrite_linear

@@ -370,7 +368,7 @@ class AttentionBlock(nn.Module):

             self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
         else:
-            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+            self.proj_out = nn.Conv1d(channels, channels, 1)
             self.set_weights(self)

         self.is_overwritten = False

@@ -385,7 +383,7 @@ class AttentionBlock(nn.Module):
         self.qkv.weight.data = qkv_weight
         self.qkv.bias.data = qkv_bias

-        proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+        proj_out = nn.Conv1d(self.channels, self.channels, 1)
         proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
         proj_out.bias.data = module.proj_out.bias.data

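
For reference, the GEGLU feed-forward gate added above splits a single linear projection into a value half and a gate half. A self-contained restatement in plain PyTorch (no diffusers imports needed):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GEGLU(nn.Module):
        # proj(x) is split into (value, gate); output = value * GELU(gate)
        def __init__(self, dim_in, dim_out):
            super().__init__()
            self.proj = nn.Linear(dim_in, dim_out * 2)

        def forward(self, x):
            x, gate = self.proj(x).chunk(2, dim=-1)
            return x * F.gelu(gate)

    out = GEGLU(64, 128)(torch.randn(2, 10, 64))
    assert out.shape == (2, 10, 128)
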
@@ -0,0 +1,632 @@
+import functools
+import math
+from typing import Dict, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
+from .attention import AttentionBlock, SpatialTransformer
+from .embeddings import GaussianFourierProjection, get_timestep_embedding
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
+
+
+class Combine(nn.Module):
+    """Combine information from skip connections."""
+
+    def __init__(self, dim1, dim2, method="cat"):
+        super().__init__()
+        # 1x1 convolution with DDPM initialization.
+        self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
+        self.method = method
+
+    # def forward(self, x, y):
+    #     h = self.Conv_0(x)
+    #     if self.method == "cat":
+    #         return torch.cat([h, y], dim=1)
+    #     elif self.method == "sum":
+    #         return h + y
+    #     else:
+    #         raise ValueError(f"Method {self.method} not recognized.")
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(self, channel, time_embed_dim, act_fn="silu"):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(channel, time_embed_dim)
+        self.act = None
+        if act_fn == "silu":
+            self.act = nn.SiLU()
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+    def forward(self, sample):
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+        return sample
+
+
+class Timesteps(nn.Module):
+    def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
+        super().__init__()
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+
+    def forward(self, timesteps):
+        t_emb = get_timestep_embedding(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+        )
+        return t_emb
+
+
+class UNetConditionalModel(ModelMixin, ConfigMixin):
+    """
+    The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
+    model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
+    num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
+    rates at which
+        attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
+        downsampling, attention will be used.
+    :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
+    conv_resample: if True, use learned convolutions for upsampling and
+        downsampling.
+    :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
+    model will be
+        class-conditional with `num_classes` classes.
+    :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
+    heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
+        a fixed channel width per attention head.
+    :param num_heads_upsample: works with num_heads to set a different number
+        of heads for upsampling. Deprecated.
+    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
+    for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
+        increased efficiency.
+    """
+
+    def __init__(
+        self,
+        image_size=None,
+        in_channels=4,
+        out_channels=4,
+        num_res_blocks=2,
+        dropout=0,
+        block_channels=(320, 640, 1280, 1280),
+        down_blocks=(
+            "UNetResCrossAttnDownBlock2D",
+            "UNetResCrossAttnDownBlock2D",
+            "UNetResCrossAttnDownBlock2D",
+            "UNetResDownBlock2D",
+        ),
+        downsample_padding=1,
+        up_blocks=(
+            "UNetResUpBlock2D",
+            "UNetResCrossAttnUpBlock2D",
+            "UNetResCrossAttnUpBlock2D",
+            "UNetResCrossAttnUpBlock2D",
+        ),
+        resnet_act_fn="silu",
+        resnet_eps=1e-5,
+        conv_resample=True,
+        num_head_channels=8,
+        flip_sin_to_cos=True,
+        downscale_freq_shift=0,
+        mid_block_scale_factor=1,
+        center_input_sample=False,
+        # TODO(PVP) - to delete later at release
+        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
+        # ======================================
+        # LDM
+        attention_resolutions=(4, 2, 1),
+        # DDPM
+        out_ch=None,
+        resolution=None,
+        attn_resolutions=None,
+        resamp_with_conv=None,
+        ch_mult=None,
+        ch=None,
+        ddpm=False,
+        # SDE
+        sde=False,
+        nf=None,
+        fir=None,
+        progressive=None,
+        progressive_combine=None,
+        scale_by_sigma=None,
+        skip_rescale=None,
+        num_channels=None,
+        centered=False,
+        conditional=True,
+        conv_size=3,
+        fir_kernel=(1, 3, 3, 1),
+        fourier_scale=16,
+        init_scale=0.0,
+        progressive_input="input_skip",
+        resnet_num_groups=32,
+        continuous=True,
+        ldm=False,
+    ):
+        super().__init__()
+        # register all __init__ params to be accessible via `self.config.<...>`
+        # should probably be automated down the road as this is pure boiler plate code
+        self.register_to_config(
+            image_size=image_size,
+            in_channels=in_channels,
+            block_channels=block_channels,
+            downsample_padding=downsample_padding,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            down_blocks=down_blocks,
+            up_blocks=up_blocks,
+            dropout=dropout,
+            resnet_eps=resnet_eps,
+            conv_resample=conv_resample,
+            num_head_channels=num_head_channels,
+            flip_sin_to_cos=flip_sin_to_cos,
+            downscale_freq_shift=downscale_freq_shift,
+            attention_resolutions=attention_resolutions,
+            attn_resolutions=attn_resolutions,
+            mid_block_scale_factor=mid_block_scale_factor,
+            resnet_num_groups=resnet_num_groups,
+            center_input_sample=center_input_sample,
+        )
+
+        self.ldm = ldm
+
+        # TODO(PVP) - to delete later at release
+        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
+        # ======================================
+        self.image_size = image_size
+        time_embed_dim = block_channels[0] * 4
+        # ======================================
+
+        # input
+        self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
+
+        # time
+        self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
+        timestep_input_dim = block_channels[0]
+
+        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+        self.downsample_blocks = nn.ModuleList([])
+        self.mid = None
+        self.upsample_blocks = nn.ModuleList([])
+
+        # down
+        output_channel = block_channels[0]
+        for i, down_block_type in enumerate(down_blocks):
+            input_channel = output_channel
+            output_channel = block_channels[i]
+            is_final_block = i == len(block_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=num_res_blocks,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=resnet_eps,
+                resnet_act_fn=resnet_act_fn,
+                attn_num_head_channels=num_head_channels,
+                downsample_padding=downsample_padding,
+            )
+            self.downsample_blocks.append(down_block)
+
+        # mid
+        self.mid = UNetMidBlock2DCrossAttn(
+            in_channels=block_channels[-1],
+            dropout=dropout,
+            temb_channels=time_embed_dim,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            output_scale_factor=mid_block_scale_factor,
+            resnet_time_scale_shift="default",
+            attn_num_head_channels=num_head_channels,
+            resnet_groups=resnet_num_groups,
+        )
+
+        # up
+        reversed_block_channels = list(reversed(block_channels))
+        output_channel = reversed_block_channels[0]
+        for i, up_block_type in enumerate(up_blocks):
+            prev_output_channel = output_channel
+            output_channel = reversed_block_channels[i]
+            input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
+
+            is_final_block = i == len(block_channels) - 1
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=num_res_blocks + 1,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                prev_output_channel=prev_output_channel,
+                temb_channels=time_embed_dim,
+                add_upsample=not is_final_block,
+                resnet_eps=resnet_eps,
+                resnet_act_fn=resnet_act_fn,
+                attn_num_head_channels=num_head_channels,
+            )
+            self.upsample_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
+        self.conv_act = nn.SiLU()
+        self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
+
+        # ======================== Out ====================
+
+        # =========== TO DELETE AFTER CONVERSION ==========
+        # TODO(PVP) - to delete later at release
+        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
+        # ======================================
+        self.is_overwritten = False
+        if ldm:
+            num_heads = 8
+            num_head_channels = -1
+            transformer_depth = 1
+            use_spatial_transformer = True
+            context_dim = 1280
+            legacy = False
+            model_channels = block_channels[0]
+            channel_mult = tuple([x // model_channels for x in block_channels])
+            self.init_for_ldm(
+                in_channels,
+                model_channels,
+                channel_mult,
+                num_res_blocks,
+                dropout,
+                time_embed_dim,
+                attention_resolutions,
+                num_head_channels,
+                num_heads,
+                legacy,
+                False,
+                transformer_depth,
+                context_dim,
+                conv_resample,
+                out_channels,
+            )
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+    ) -> Dict[str, torch.FloatTensor]:
+        # TODO(PVP) - to delete later at release
+        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
+        # ======================================
+        if not self.is_overwritten:
+            self.set_weights()
+
+        if self.config.center_input_sample:
+            sample = 2 * sample - 1.0
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        t_emb = self.time_steps(timesteps)
+        emb = self.time_embedding(t_emb)
+
+        # 2. pre-process
+        skip_sample = sample
+        sample = self.conv_in(sample)
+
+        # 3. down
+        down_block_res_samples = (sample,)
+        for downsample_block in self.downsample_blocks:
+
+            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+            down_block_res_samples += res_samples
+
+        # 4. mid
+        sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+        # 5. up
+        skip_sample = None
+        for upsample_block in self.upsample_blocks:
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                )
+            else:
+                sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
+
+        # 6. post-process
+
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        output = {"sample": sample}
+
+        return output
+
+    # !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
+    # =================================================================================================================================================
+
+    def set_weights(self):
+        self.is_overwritten = True
+        if self.ldm:
+            self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
+            self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
+            self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
+            self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data
+
+            self.conv_in.weight.data = self.input_blocks[0][0].weight.data
+            self.conv_in.bias.data = self.input_blocks[0][0].bias.data
+
+            # ================ SET WEIGHTS OF ALL WEIGHTS ==================
+            for i, input_layer in enumerate(self.input_blocks[1:]):
+                block_id = i // (self.config.num_res_blocks + 1)
+                layer_in_block_id = i % (self.config.num_res_blocks + 1)
+
+                if layer_in_block_id == 2:
+                    self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
+                    self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
+                elif len(input_layer) > 1:
+                    self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
+                    self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
+                else:
+                    self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
+
+            self.mid.resnets[0].set_weight(self.middle_block[0])
+            self.mid.resnets[1].set_weight(self.middle_block[2])
+            self.mid.attentions[0].set_weight(self.middle_block[1])
+
+            for i, input_layer in enumerate(self.output_blocks):
+                block_id = i // (self.config.num_res_blocks + 1)
+                layer_in_block_id = i % (self.config.num_res_blocks + 1)
+
+                if len(input_layer) > 2:
+                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
+                    self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
+                    self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
+                    self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
+                elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
+                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
+                    self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
+                    self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
+                elif len(input_layer) > 1:
+                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
+                    self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
+                else:
+                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
+
+            self.conv_norm_out.weight.data = self.out[0].weight.data
+            self.conv_norm_out.bias.data = self.out[0].bias.data
+            self.conv_out.weight.data = self.out[2].weight.data
+            self.conv_out.bias.data = self.out[2].bias.data
+
+            self.remove_ldm()
+
+    def init_for_ldm(
+        self,
+        in_channels,
+        model_channels,
+        channel_mult,
+        num_res_blocks,
+        dropout,
+        time_embed_dim,
+        attention_resolutions,
+        num_head_channels,
+        num_heads,
+        legacy,
+        use_spatial_transformer,
+        transformer_depth,
+        context_dim,
+        conv_resample,
+        out_channels,
+    ):
+        # TODO(PVP) - delete after weight conversion
+        class TimestepEmbedSequential(nn.Sequential):
+            """
+            A sequential module that passes timestep embeddings to the children that support it as an extra input.
+            """
+
+            pass
+
+        # TODO(PVP) - delete after weight conversion
+        def conv_nd(dims, *args, **kwargs):
+            """
+            Create a 1D, 2D, or 3D convolution module.
+            """
+            if dims == 1:
+                return nn.Conv1d(*args, **kwargs)
+            elif dims == 2:
+                return nn.Conv2d(*args, **kwargs)
+            elif dims == 3:
+                return nn.Conv3d(*args, **kwargs)
+            raise ValueError(f"unsupported dimensions: {dims}")
+
+        self.time_embed = nn.Sequential(
+            nn.Linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
+
+        dims = 2
+        self.input_blocks = nn.ModuleList(
+            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
+        )
+
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for _ in range(num_res_blocks):
+                layers = [
+                    ResnetBlock2D(
+                        in_channels=ch,
+                        out_channels=mult * model_channels,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        # num_heads = 1
+                        dim_head = num_head_channels
+                    layers.append(
+                        SpatialTransformer(
+                            ch,
+                            num_heads,
+                            dim_head,
+                            depth=transformer_depth,
+                            context_dim=context_dim,
+                        ),
+                    )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            # num_heads = 1
+            dim_head = num_head_channels
+
+        if dim_head < 0:
+            dim_head = None
+
+        # TODO(Patrick) - delete after weight conversion
+        # init to be able to overwrite `self.mid`
+        self.middle_block = TimestepEmbedSequential(
+            ResnetBlock2D(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
+            ),
+            SpatialTransformer(
+                ch,
+                num_heads,
+                dim_head,
+                depth=transformer_depth,
+                context_dim=context_dim,
+            ),
+            ResnetBlock2D(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
+            ),
+        )
+        self._feature_size += ch
+
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(num_res_blocks + 1):
+                ich = input_block_chans.pop()
+                layers = [
+                    ResnetBlock2D(
+                        in_channels=ch + ich,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
+                ]
+                ch = model_channels * mult
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        # num_heads = 1
+                        dim_head = num_head_channels
+                    layers.append(
+                        SpatialTransformer(
+                            ch,
+                            num_heads,
+                            dim_head,
+                            depth=transformer_depth,
+                            context_dim=context_dim,
+                        )
+                    )
+                if level and i == num_res_blocks:
+                    out_ch = ch
+                    layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
+            nn.SiLU(),
+            nn.Conv2d(model_channels, out_channels, 3, padding=1),
+        )
+
+    def remove_ldm(self):
+        del self.time_embed
+        del self.input_blocks
+        del self.middle_block
+        del self.output_blocks
+        del self.out
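
The new model wires timestep conditioning through Timesteps -> TimestepEmbedding before the down/mid/up blocks. get_timestep_embedding itself is imported from .embeddings and is not part of this diff; the sketch below assumes it is the standard DDPM-style sinusoidal embedding and is illustrative only:

    import math
    import torch

    def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
        # half sine / half cosine features over log-spaced frequencies (assumed behaviour)
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
        args = timesteps[:, None].float() * freqs[None, :]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    emb = sinusoidal_timestep_embedding(torch.tensor([0, 10, 999]), 320)
    print(emb.shape)  # torch.Size([3, 320]) -> then fed into TimestepEmbedding (Linear/SiLU/Linear)
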
@@ -17,7 +17,7 @@ import numpy as np
 import torch
 from torch import nn

-from .attention import AttentionBlockNew
+from .attention import AttentionBlockNew, SpatialTransformer
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D


@@ -56,6 +56,18 @@ def get_down_block(
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
+    elif down_block_type == "UNetResCrossAttnDownBlock2D":
+        return UNetResCrossAttnDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
+            attn_num_head_channels=attn_num_head_channels,
+        )
     elif down_block_type == "UNetResSkipDownBlock2D":
         return UNetResSkipDownBlock2D(
             num_layers=num_layers,

@@ -104,6 +116,18 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
         )
+    elif up_block_type == "UNetResCrossAttnUpBlock2D":
+        return UNetResCrossAttnUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            attn_num_head_channels=attn_num_head_channels,
+        )
     elif up_block_type == "UNetResAttnUpBlock2D":
         return UNetResAttnUpBlock2D(
             num_layers=num_layers,

@@ -221,6 +245,83 @@ class UNetMidBlock2D(nn.Module):
         return hidden_states


+class UNetMidBlock2DCrossAttn(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        attention_type="default",
+        output_scale_factor=1.0,
+        cross_attention_dim=1280,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.attention_type = attention_type
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        for _ in range(num_layers):
+            attentions.append(
+                SpatialTransformer(
+                    in_channels,
+                    attn_num_head_channels,
+                    in_channels // attn_num_head_channels,
+                    depth=1,
+                    context_dim=cross_attention_dim,
+                )
+            )
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            hidden_states = attn(hidden_states, encoder_hidden_states)
+            hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
+
 class UNetResAttnDownBlock2D(nn.Module):
     def __init__(
         self,

@@ -302,6 +403,88 @@ class UNetResAttnDownBlock2D(nn.Module):
         return hidden_states, output_states


+class UNetResCrossAttnDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        attention_type="default",
+        output_scale_factor=1.0,
+        downsample_padding=1,
+        add_downsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.attention_type = attention_type
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    out_channels,
+                    attn_num_head_channels,
+                    out_channels // attn_num_head_channels,
+                    depth=1,
+                    context_dim=cross_attention_dim,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        output_states = ()
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states
+
+
 class UNetResDownBlock2D(nn.Module):
     def __init__(
         self,

@@ -618,6 +801,86 @@ class UNetResAttnUpBlock2D(nn.Module):
         return hidden_states


+class UNetResCrossAttnUpBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        attention_type="default",
+        output_scale_factor=1.0,
+        downsample_padding=1,
+        add_upsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.attention_type = attention_type
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    out_channels,
+                    attn_num_head_channels,
+                    out_channels // attn_num_head_channels,
+                    depth=1,
+                    context_dim=cross_attention_dim,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+        for resnet, attn in zip(self.resnets, self.attentions):
+
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
 class UNetResUpBlock2D(nn.Module):
     def __init__(
         self,

@@ -765,8 +1028,6 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
         self.act = None

     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
-
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]

@@ -864,8 +1125,6 @@ class UNetResSkipUpBlock2D(nn.Module):
         self.act = None

     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
-
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]

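
The UNetResCrossAttn*Block2D and UNetMidBlock2DCrossAttn classes above delegate text conditioning to SpatialTransformer, whose core operation is cross-attention from image tokens to the encoder context. A minimal single-head sketch (illustrative only; the real block adds GroupNorm, multiple heads, a GEGLU feed-forward and residual connections):

    import torch
    import torch.nn as nn

    b, hw, c, seq, ctx_dim = 2, 64, 320, 77, 1280
    image_tokens = torch.randn(b, hw, c)      # flattened feature map (queries)
    context = torch.randn(b, seq, ctx_dim)    # e.g. text-encoder hidden states (keys/values)

    to_q = nn.Linear(c, c, bias=False)
    to_k = nn.Linear(ctx_dim, c, bias=False)
    to_v = nn.Linear(ctx_dim, c, bias=False)

    q, k, v = to_q(image_tokens), to_k(context), to_v(context)
    attn = torch.softmax(q @ k.transpose(1, 2) / c**0.5, dim=-1)  # (b, hw, seq)
    out = attn @ v                                                # (b, hw, c), fed back into the UNet

    print(out.shape)
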
@@ -1,6 +1,5 @@
 from typing import Optional, Tuple, Union

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint

@@ -15,6 +14,69 @@ from transformers.utils import logging
 from ...pipeline_utils import DiffusionPipeline


+class LatentDiffusionPipeline(DiffusionPipeline):
+    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
+        # eta corresponds to η in paper and should be between [0, 1]
+
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.unet.to(torch_device)
+        self.vqvae.to(torch_device)
+        self.bert.to(torch_device)
+
+        # get unconditional embeddings for classifier free guidence
+        if guidance_scale != 1.0:
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
+                torch_device
+            )
+            uncond_embeddings = self.bert(uncond_input.input_ids)
+
+        # get text embedding
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
+        text_embedding = self.bert(text_input.input_ids)
+
+        image = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            generator=generator,
+        ).to(torch_device)
+
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        for t in tqdm.tqdm(self.scheduler.timesteps):
+            # 1. predict noise residual
+            pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
+
+            if isinstance(pred_noise_t, dict):
+                pred_noise_t = pred_noise_t["sample"]
+
+            # 2. predict previous mean of image x_t-1 and add variance depending on eta
+            # do x_t -> x_t-1
+            image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
+
+        # scale and decode image with vae
+        image = 1 / 0.18215 * image
+        image = self.vqvae.decode(image)
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+
+        return image
+
+
 ################################################################################
 # Code for the text transformer model
 ################################################################################

@@ -541,101 +603,4 @@ class LDMBertModel(LDMBertPreTrainedModel):
             return_dict=return_dict,
         )
         sequence_output = outputs[0]
         return sequence_output
-
-
-class LatentDiffusionPipeline(DiffusionPipeline):
-    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
-        super().__init__()
-        scheduler = scheduler.set_format("pt")
-        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-
-    @torch.no_grad()
-    def __call__(
-        self,
-        prompt,
-        batch_size=1,
-        generator=None,
-        torch_device=None,
-        eta=0.0,
-        guidance_scale=1.0,
-        num_inference_steps=50,
-    ):
-        # eta corresponds to η in paper and should be between [0, 1]
-
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
-        self.vqvae.to(torch_device)
-        self.bert.to(torch_device)
-
-        # get unconditional embeddings for classifier free guidence
-        if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
-                torch_device
-            )
-            uncond_embeddings = self.bert(uncond_input.input_ids)
-
-        # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-        text_embedding = self.bert(text_input.input_ids)
-
-        num_trained_timesteps = self.scheduler.config.timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
-            generator=generator,
-        ).to(torch_device)
-
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
-        for t in tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-            # guidance_scale of 1 means no guidance
-            if guidance_scale == 1.0:
-                image_in = image
-                context = text_embedding
-                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-            else:
-                # for classifier free guidance, we need to do two forward passes
-                # here we concanate embedding and unconditioned embedding in a single batch
-                # to avoid doing two forward passes
-                image_in = torch.cat([image] * 2)
-                context = torch.cat([uncond_embeddings, text_embedding])
-                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-
-            # 1. predict noise residual
-            pred_noise_t = self.unet(image_in, timesteps, context=context)
-
-            # perform guidance
-            if guidance_scale != 1.0:
-                pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
-                pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
-
-            # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
-
-            # 3. optionally sample variance
-            variance = 0
-            if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
-                variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
-
-            # 4. set current image to prev_image: x_t -> x_t-1
-            image = pred_prev_image + variance
-
-        # scale and decode image with vae
-        image = 1 / 0.18215 * image
-        image = self.vqvae.decode(image)
-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
-
-        return image

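
The removed __call__ above applied classifier-free guidance explicitly when guidance_scale != 1.0; the combination rule it used, restated as a self-contained helper:

    import torch

    def apply_cfg(noise_uncond, noise_text, guidance_scale):
        # guidance_scale == 1.0 reduces to the purely conditional prediction
        return noise_uncond + guidance_scale * (noise_text - noise_uncond)

    eps_u, eps_c = torch.randn(1, 4, 32, 32), torch.randn(1, 4, 32, 32)
    print(apply_cfg(eps_u, eps_c, 7.5).shape)  # torch.Size([1, 4, 32, 32])
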
@@ -40,14 +40,17 @@ from diffusers import (
     ScoreSdeVeScheduler,
     ScoreSdeVpPipeline,
     ScoreSdeVpScheduler,
+    UNetConditionalModel,
     UNetLDMModel,
     UNetUnconditionalModel,
     VQModel,
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel
+from transformers import BertTokenizer


 torch.backends.cuda.matmul.allow_tf32 = False

@@ -827,7 +830,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


-class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL

     @property

@@ -1026,10 +1029,8 @@ class PipelineTesterMixin(unittest.TestCase):
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

     @slow
-    @unittest.skip("Skipping for now as it takes too long")
     def test_ldm_text2img(self):
-        model_id = "fusing/latent-diffusion-text2im-large"
-        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")

         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)

@@ -1043,8 +1044,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ldm_text2img_fast(self):
-        model_id = "fusing/latent-diffusion-text2im-large"
-        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")

         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)

@@ -1074,6 +1074,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_score_sde_ve_pipeline(self):
         model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
+        model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")

         torch.manual_seed(0)
         if torch.cuda.is_available():