parent 6cabc599a2
commit d5acb4110a

@@ -7,7 +7,15 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
 __version__ = "0.0.4"

 from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, NCSNpp, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
+from .models import (
+    AutoencoderKL,
+    NCSNpp,
+    UNetConditionalModel,
+    UNetLDMModel,
+    UNetModel,
+    UNetUnconditionalModel,
+    VQModel,
+)
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import (
     DDIMPipeline,
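The hunk above widens the package-root export list so the new conditional UNet can be imported directly. A minimal sketch of what that enables, assuming a build of this branch (version 0.0.4) is installed; all __init__ arguments of the class have defaults, so a bare constructor call should work:

from diffusers import UNetConditionalModel

model = UNetConditionalModel()
print(model.config.block_channels)  # (320, 640, 1280, 1280) with the default config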
@@ -17,6 +17,7 @@
 # limitations under the License.

 from .unet import UNetModel
+from .unet_conditional import UNetConditionalModel
 from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
 from .unet_ldm import UNetLDMModel
 from .unet_sde_score_estimation import NCSNpp
@@ -42,7 +42,7 @@ class AttentionBlockNew(nn.Module):
         self.value = nn.Linear(channels, channels)

         self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        self.proj_attn = nn.Linear(channels, channels, 1)

     def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
         new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
@@ -147,6 +147,8 @@ class SpatialTransformer(nn.Module):

     def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
         super().__init__()
+        self.n_heads = n_heads
+        self.d_head = d_head
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

@@ -160,7 +162,7 @@ class SpatialTransformer(nn.Module):
             ]
         )

-        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention

@@ -175,6 +177,12 @@ class SpatialTransformer(nn.Module):
         x = self.proj_out(x)
         return x + x_in

+    def set_weight(self, layer):
+        self.norm = layer.norm
+        self.proj_in = layer.proj_in
+        self.transformer_blocks = layer.transformer_blocks
+        self.proj_out = layer.proj_out
+

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@@ -270,14 +278,15 @@ class FeedForward(nn.Module):
         return self.net(x)


-# TODO(Patrick) - this can and should be removed
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)


 # TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
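The GEGLU block moved up in this hunk implements gated GELU: the projection doubles the width, one half is the value and the other half the gate. A small standalone sketch of the same computation:

import torch
import torch.nn.functional as F

dim_in, dim_out = 8, 16
proj = torch.nn.Linear(dim_in, dim_out * 2)   # as in GEGLU.__init__
x = torch.randn(2, dim_in)
value, gate = proj(x).chunk(2, dim=-1)        # split the doubled projection
out = value * F.gelu(gate)                    # shape (2, 16)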
@@ -298,17 +307,6 @@ def default(val, d):
     return d() if isfunction(d) else d


-# feedforward
-class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
-
-
 # the main attention block that is used for all models
 class AttentionBlock(nn.Module):
     """
@@ -348,7 +346,7 @@ class AttentionBlock(nn.Module):
         if encoder_channels is not None:
             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)

-        self.proj = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = nn.Conv1d(channels, channels, 1)

         self.overwrite_qkv = overwrite_qkv
         self.overwrite_linear = overwrite_linear
@@ -370,7 +368,7 @@ class AttentionBlock(nn.Module):

             self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
         else:
-            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+            self.proj_out = nn.Conv1d(channels, channels, 1)
             self.set_weights(self)

         self.is_overwritten = False
@@ -385,7 +383,7 @@ class AttentionBlock(nn.Module):
             self.qkv.weight.data = qkv_weight
             self.qkv.bias.data = qkv_bias

-            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+            proj_out = nn.Conv1d(self.channels, self.channels, 1)
             proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
             proj_out.bias.data = module.proj_out.bias.data

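These hunks drop the zero_module(...) wrapper around the output projections (and the helper itself earlier in the file), so those projections now start from the default nn initialization instead of all zeros. If the old behavior were still wanted, an explicit equivalent of what the removed helper did would look roughly like this (illustration only):

import torch.nn as nn

channels = 512                   # illustrative width
proj = nn.Conv1d(channels, channels, 1)
nn.init.zeros_(proj.weight)      # zero_module() zeroed every parameter of the module
nn.init.zeros_(proj.bias)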
@@ -0,0 +1,632 @@
|
|||
import functools
|
||||
import math
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock, SpatialTransformer
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
|
||||
|
||||
|
||||
class Combine(nn.Module):
|
||||
"""Combine information from skip connections."""
|
||||
|
||||
def __init__(self, dim1, dim2, method="cat"):
|
||||
super().__init__()
|
||||
# 1x1 convolution with DDPM initialization.
|
||||
self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
|
||||
self.method = method
|
||||
|
||||
|
||||
# def forward(self, x, y):
|
||||
# h = self.Conv_0(x)
|
||||
# if self.method == "cat":
|
||||
# return torch.cat([h, y], dim=1)
|
||||
# elif self.method == "sum":
|
||||
# return h + y
|
||||
# else:
|
||||
# raise ValueError(f"Method {self.method} not recognized.")
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, channel, time_embed_dim, act_fn="silu"):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(channel, time_embed_dim)
|
||||
self.act = None
|
||||
if act_fn == "silu":
|
||||
self.act = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
)
|
||||
return t_emb
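The Timesteps module above is a thin wrapper around get_timestep_embedding. A rough sketch of the usual sinusoidal formula it is assumed to follow (the real helper additionally handles flip_sin_to_cos and downscale_freq_shift):

import math
import torch

def sinusoidal_embedding(timesteps, dim):
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0, 10, 500]), 320)   # shape (3, 320)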
|
||||
|
||||
|
||||
class UNetConditionalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
|
||||
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
|
||||
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
|
||||
rates at which
|
||||
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
|
||||
downsampling, attention will be used.
|
||||
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
|
||||
conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
|
||||
model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
|
||||
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
|
||||
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
|
||||
increased efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size=None,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_res_blocks=2,
|
||||
dropout=0,
|
||||
block_channels=(320, 640, 1280, 1280),
|
||||
down_blocks=(
|
||||
"UNetResCrossAttnDownBlock2D",
|
||||
"UNetResCrossAttnDownBlock2D",
|
||||
"UNetResCrossAttnDownBlock2D",
|
||||
"UNetResDownBlock2D",
|
||||
),
|
||||
downsample_padding=1,
|
||||
up_blocks=(
|
||||
"UNetResUpBlock2D",
|
||||
"UNetResCrossAttnUpBlock2D",
|
||||
"UNetResCrossAttnUpBlock2D",
|
||||
"UNetResCrossAttnUpBlock2D",
|
||||
),
|
||||
resnet_act_fn="silu",
|
||||
resnet_eps=1e-5,
|
||||
conv_resample=True,
|
||||
num_head_channels=8,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0,
|
||||
mid_block_scale_factor=1,
|
||||
center_input_sample=False,
|
||||
# TODO(PVP) - to delete later at release
|
||||
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
|
||||
# ======================================
|
||||
# LDM
|
||||
attention_resolutions=(4, 2, 1),
|
||||
# DDPM
|
||||
out_ch=None,
|
||||
resolution=None,
|
||||
attn_resolutions=None,
|
||||
resamp_with_conv=None,
|
||||
ch_mult=None,
|
||||
ch=None,
|
||||
ddpm=False,
|
||||
# SDE
|
||||
sde=False,
|
||||
nf=None,
|
||||
fir=None,
|
||||
progressive=None,
|
||||
progressive_combine=None,
|
||||
scale_by_sigma=None,
|
||||
skip_rescale=None,
|
||||
num_channels=None,
|
||||
centered=False,
|
||||
conditional=True,
|
||||
conv_size=3,
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
fourier_scale=16,
|
||||
init_scale=0.0,
|
||||
progressive_input="input_skip",
|
||||
resnet_num_groups=32,
|
||||
continuous=True,
|
||||
ldm=False,
|
||||
):
|
||||
super().__init__()
|
||||
# register all __init__ params to be accessible via `self.config.<...>`
|
||||
# should probably be automated down the road as this is pure boilerplate code
|
||||
self.register_to_config(
|
||||
image_size=image_size,
|
||||
in_channels=in_channels,
|
||||
block_channels=block_channels,
|
||||
downsample_padding=downsample_padding,
|
||||
out_channels=out_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
down_blocks=down_blocks,
|
||||
up_blocks=up_blocks,
|
||||
dropout=dropout,
|
||||
resnet_eps=resnet_eps,
|
||||
conv_resample=conv_resample,
|
||||
num_head_channels=num_head_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
downscale_freq_shift=downscale_freq_shift,
|
||||
attention_resolutions=attention_resolutions,
|
||||
attn_resolutions=attn_resolutions,
|
||||
mid_block_scale_factor=mid_block_scale_factor,
|
||||
resnet_num_groups=resnet_num_groups,
|
||||
center_input_sample=center_input_sample,
|
||||
)
|
||||
|
||||
self.ldm = ldm
|
||||
|
||||
# TODO(PVP) - to delete later at release
|
||||
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
|
||||
# ======================================
|
||||
self.image_size = image_size
|
||||
time_embed_dim = block_channels[0] * 4
|
||||
# ======================================
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# time
|
||||
self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
|
||||
timestep_input_dim = block_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
self.downsample_blocks = nn.ModuleList([])
|
||||
self.mid = None
|
||||
self.upsample_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_channels[0]
|
||||
for i, down_block_type in enumerate(down_blocks):
|
||||
input_channel = output_channel
|
||||
output_channel = block_channels[i]
|
||||
is_final_block = i == len(block_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=num_res_blocks,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
attn_num_head_channels=num_head_channels,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
self.downsample_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_channels[-1],
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
attn_num_head_channels=num_head_channels,
|
||||
resnet_groups=resnet_num_groups,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_channels = list(reversed(block_channels))
|
||||
output_channel = reversed_block_channels[0]
|
||||
for i, up_block_type in enumerate(up_blocks):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_channels[i]
|
||||
input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=num_res_blocks + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
attn_num_head_channels=num_head_channels,
|
||||
)
|
||||
self.upsample_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
# ======================== Out ====================
|
||||
|
||||
# =========== TO DELETE AFTER CONVERSION ==========
|
||||
# TODO(PVP) - to delete later at release
|
||||
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
|
||||
# ======================================
|
||||
self.is_overwritten = False
|
||||
if ldm:
|
||||
num_heads = 8
|
||||
num_head_channels = -1
|
||||
transformer_depth = 1
|
||||
use_spatial_transformer = True
|
||||
context_dim = 1280
|
||||
legacy = False
|
||||
model_channels = block_channels[0]
|
||||
channel_mult = tuple([x // model_channels for x in block_channels])
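# e.g. block_channels=(320, 640, 1280, 1280) with model_channels=320 gives channel_mult=(1, 2, 4, 4)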
|
||||
self.init_for_ldm(
|
||||
in_channels,
|
||||
model_channels,
|
||||
channel_mult,
|
||||
num_res_blocks,
|
||||
dropout,
|
||||
time_embed_dim,
|
||||
attention_resolutions,
|
||||
num_head_channels,
|
||||
num_heads,
|
||||
legacy,
|
||||
False,
|
||||
transformer_depth,
|
||||
context_dim,
|
||||
conv_resample,
|
||||
out_channels,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
) -> Dict[str, torch.FloatTensor]:
|
||||
# TODO(PVP) - to delete later at release
|
||||
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
|
||||
# ======================================
|
||||
if not self.is_overwritten:
|
||||
self.set_weights()
|
||||
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
t_emb = self.time_steps(timesteps)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.downsample_blocks:
|
||||
|
||||
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
# 5. up
|
||||
skip_sample = None
|
||||
for upsample_block in self.upsample_blocks:
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
|
||||
|
||||
# 6. post-process
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
output = {"sample": sample}
|
||||
|
||||
return output
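A rough smoke test of the forward pass defined above, assuming the default configuration (random weights, ldm=False); the 1280-dim context matches the default cross_attention_dim of the attention blocks, and the 77-token length is only an illustrative choice:

import torch

model = UNetConditionalModel(image_size=64)
sample = torch.randn(1, 4, 64, 64)                 # (batch, in_channels, H, W)
context = torch.randn(1, 77, 1280)                 # e.g. text encoder hidden states
out = model(sample, timestep=10, encoder_hidden_states=context)["sample"]
assert out.shape == sample.shape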
|
||||
|
||||
# !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
|
||||
# =================================================================================================================================================
|
||||
|
||||
def set_weights(self):
|
||||
self.is_overwritten = True
|
||||
if self.ldm:
|
||||
self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
|
||||
self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
|
||||
self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
|
||||
self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data
|
||||
|
||||
self.conv_in.weight.data = self.input_blocks[0][0].weight.data
|
||||
self.conv_in.bias.data = self.input_blocks[0][0].bias.data
|
||||
|
||||
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
|
||||
for i, input_layer in enumerate(self.input_blocks[1:]):
|
||||
block_id = i // (self.config.num_res_blocks + 1)
|
||||
layer_in_block_id = i % (self.config.num_res_blocks + 1)
|
||||
|
||||
if layer_in_block_id == 2:
|
||||
self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
|
||||
self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
|
||||
elif len(input_layer) > 1:
|
||||
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
else:
|
||||
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
|
||||
self.mid.resnets[0].set_weight(self.middle_block[0])
|
||||
self.mid.resnets[1].set_weight(self.middle_block[2])
|
||||
self.mid.attentions[0].set_weight(self.middle_block[1])
|
||||
|
||||
for i, input_layer in enumerate(self.output_blocks):
|
||||
block_id = i // (self.config.num_res_blocks + 1)
|
||||
layer_in_block_id = i % (self.config.num_res_blocks + 1)
|
||||
|
||||
if len(input_layer) > 2:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
|
||||
elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
|
||||
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
|
||||
elif len(input_layer) > 1:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
|
||||
else:
|
||||
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
|
||||
|
||||
self.conv_norm_out.weight.data = self.out[0].weight.data
|
||||
self.conv_norm_out.bias.data = self.out[0].bias.data
|
||||
self.conv_out.weight.data = self.out[2].weight.data
|
||||
self.conv_out.bias.data = self.out[2].bias.data
|
||||
|
||||
self.remove_ldm()
|
||||
|
||||
def init_for_ldm(
|
||||
self,
|
||||
in_channels,
|
||||
model_channels,
|
||||
channel_mult,
|
||||
num_res_blocks,
|
||||
dropout,
|
||||
time_embed_dim,
|
||||
attention_resolutions,
|
||||
num_head_channels,
|
||||
num_heads,
|
||||
legacy,
|
||||
use_spatial_transformer,
|
||||
transformer_depth,
|
||||
context_dim,
|
||||
conv_resample,
|
||||
out_channels,
|
||||
):
|
||||
# TODO(PVP) - delete after weight conversion
|
||||
class TimestepEmbedSequential(nn.Sequential):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
# TODO(PVP) - delete after weight conversion
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
dims = 2
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
|
||||
)
|
||||
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=mult * model_channels,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = num_head_channels
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
),
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = num_head_channels
|
||||
|
||||
if dim_head < 0:
|
||||
dim_head = None
|
||||
|
||||
# TODO(Patrick) - delete after weight conversion
|
||||
# init to be able to overwrite `self.mid`
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=None,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
),
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
out_channels=None,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ch + ich,
|
||||
out_channels=model_channels * mult,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
eps=1e-5,
|
||||
non_linearity="silu",
|
||||
overwrite_for_ldm=True,
|
||||
),
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = num_head_channels
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(model_channels, out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
def remove_ldm(self):
|
||||
del self.time_embed
|
||||
del self.input_blocks
|
||||
del self.middle_block
|
||||
del self.output_blocks
|
||||
del self.out
|
|
@@ -17,7 +17,7 @@ import numpy as np
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import AttentionBlockNew
|
||||
from .attention import AttentionBlockNew, SpatialTransformer
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
|
||||
|
||||
|
||||
|
@@ -56,6 +56,18 @@ def get_down_block(
|
|||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
elif down_block_type == "UNetResCrossAttnDownBlock2D":
|
||||
return UNetResCrossAttnDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
elif down_block_type == "UNetResSkipDownBlock2D":
|
||||
return UNetResSkipDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
|
@@ -104,6 +116,18 @@ def get_up_block(
|
|||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
)
|
||||
elif up_block_type == "UNetResCrossAttnUpBlock2D":
|
||||
return UNetResCrossAttnUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
elif up_block_type == "UNetResAttnUpBlock2D":
|
||||
return UNetResAttnUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
|
@@ -221,6 +245,83 @@ class UNetMidBlock2D(nn.Module):
|
|||
return hidden_states
|
||||
|
||||
|
||||
class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_type = attention_type
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
SpatialTransformer(
|
||||
in_channels,
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
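In the mid block's forward above, the first resnet runs before the loop, so with the default num_layers=1 the effective order is (sketched as comments):

# h = resnets[0](h, temb)
# h = attentions[0](h, encoder_hidden_states)   # cross-attention over the conditioning context
# h = resnets[1](h, temb)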
|
||||
|
||||
|
||||
class UNetResAttnDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@@ -302,6 +403,88 @@ class UNetResAttnDownBlock2D(nn.Module):
|
|||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UNetResCrossAttnDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
SpatialTransformer(
|
||||
out_channels,
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
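A minimal instantiation sketch of the cross-attention down block defined above. Sizes are illustrative assumptions; num_layers=2 mirrors how the conditional UNet builds it, and the supporting ResnetBlock/SpatialTransformer are assumed to come from this same package:

import torch

block = UNetResCrossAttnDownBlock2D(in_channels=320, out_channels=640, temb_channels=1280, num_layers=2)
hidden = torch.randn(1, 320, 32, 32)
temb = torch.randn(1, 1280)
context = torch.randn(1, 77, 1280)
hidden, skips = block(hidden, temb=temb, encoder_hidden_states=context)   # skips feed the matching up block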
|
||||
|
||||
|
||||
class UNetResDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@@ -618,6 +801,86 @@ class UNetResAttnUpBlock2D(nn.Module):
|
|||
return hidden_states
|
||||
|
||||
|
||||
class UNetResCrossAttnUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
SpatialTransformer(
|
||||
out_channels,
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UNetResUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@@ -765,8 +1028,6 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
|
|||
self.act = None
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
|
@@ -864,8 +1125,6 @@ class UNetResSkipUpBlock2D(nn.Module):
|
|||
self.act = None
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
|
|
|
@@ -1,6 +1,5 @@
|
|||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
@@ -15,6 +14,69 @@ from transformers.utils import logging
|
|||
from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
class LatentDiffusionPipeline(DiffusionPipeline):
|
||||
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
batch_size=1,
|
||||
generator=None,
|
||||
torch_device=None,
|
||||
eta=0.0,
|
||||
guidance_scale=1.0,
|
||||
num_inference_steps=50,
|
||||
):
|
||||
# eta corresponds to η in the DDIM paper and should be between [0, 1]
|
||||
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.unet.to(torch_device)
|
||||
self.vqvae.to(torch_device)
|
||||
self.bert.to(torch_device)
|
||||
|
||||
# get unconditional embeddings for classifier-free guidance
|
||||
if guidance_scale != 1.0:
|
||||
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
uncond_embeddings = self.bert(uncond_input.input_ids)
|
||||
|
||||
# get text embedding
|
||||
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
|
||||
text_embedding = self.bert(text_input.input_ids)
|
||||
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
|
||||
generator=generator,
|
||||
).to(torch_device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for t in tqdm.tqdm(self.scheduler.timesteps):
|
||||
# 1. predict noise residual
|
||||
pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
|
||||
|
||||
if isinstance(pred_noise_t, dict):
|
||||
pred_noise_t = pred_noise_t["sample"]
|
||||
|
||||
# 2. predict previous mean of image x_t-1 and add variance depending on eta
|
||||
# do x_t -> x_t-1
|
||||
image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
|
||||
|
||||
# scale and decode image with vae
|
||||
image = 1 / 0.18215 * image
|
||||
image = self.vqvae.decode(image)
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
return image
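A hypothetical end-to-end call of the pipeline above. The checkpoint id is the one used by the tests in this commit, and the class is assumed to be importable from the package root as those tests do; the call returns an image tensor clamped to [0, 1] per the decode step at the end of __call__:

import torch
from diffusers import LatentDiffusionPipeline

ldm = LatentDiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")
generator = torch.manual_seed(0)
images = ldm(["A painting of a squirrel eating a burger"], generator=generator, num_inference_steps=50)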
|
||||
|
||||
|
||||
################################################################################
|
||||
# Code for the text transformer model
|
||||
################################################################################
|
||||
|
@@ -542,100 +604,3 @@ class LDMBertModel(LDMBertPreTrainedModel):
|
|||
)
|
||||
sequence_output = outputs[0]
|
||||
return sequence_output
|
||||
|
||||
|
||||
class LatentDiffusionPipeline(DiffusionPipeline):
|
||||
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
batch_size=1,
|
||||
generator=None,
|
||||
torch_device=None,
|
||||
eta=0.0,
|
||||
guidance_scale=1.0,
|
||||
num_inference_steps=50,
|
||||
):
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.unet.to(torch_device)
|
||||
self.vqvae.to(torch_device)
|
||||
self.bert.to(torch_device)
|
||||
|
||||
# get unconditional embeddings for classifier-free guidance
|
||||
if guidance_scale != 1.0:
|
||||
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
uncond_embeddings = self.bert(uncond_input.input_ids)
|
||||
|
||||
# get text embedding
|
||||
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
|
||||
text_embedding = self.bert(text_input.input_ids)
|
||||
|
||||
num_trained_timesteps = self.scheduler.config.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
|
||||
generator=generator,
|
||||
).to(torch_device)
|
||||
|
||||
# See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper in detail for a full understanding.

# Notation (<variable name> -> <name in paper>)
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointing to x_t"
# - pred_prev_image -> "x_t-1"
|
||||
for t in tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
|
||||
# guidance_scale of 1 means no guidance
|
||||
if guidance_scale == 1.0:
|
||||
image_in = image
|
||||
context = text_embedding
|
||||
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
|
||||
else:
|
||||
# for classifier-free guidance, we need to do two forward passes
# here we concatenate the conditional and unconditional embeddings into a single batch
# to avoid doing two forward passes
|
||||
image_in = torch.cat([image] * 2)
|
||||
context = torch.cat([uncond_embeddings, text_embedding])
|
||||
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
|
||||
|
||||
# 1. predict noise residual
|
||||
pred_noise_t = self.unet(image_in, timesteps, context=context)
|
||||
|
||||
# perform guidance
|
||||
if guidance_scale != 1.0:
|
||||
pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
|
||||
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
|
||||
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
|
||||
|
||||
# 3. optionally sample variance
|
||||
variance = 0
|
||||
if eta > 0:
|
||||
noise = torch.randn(image.shape, generator=generator).to(image.device)
|
||||
variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
|
||||
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
|
||||
# scale and decode image with vae
|
||||
image = 1 / 0.18215 * image
|
||||
image = self.vqvae.decode(image)
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
return image
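The guidance combine in the removed loop above is the standard classifier-free guidance extrapolation; a tiny numeric illustration:

import torch

pred_uncond = torch.tensor([0.0, 0.0])
pred_cond = torch.tensor([1.0, 2.0])
guidance_scale = 1.5
pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)   # tensor([1.5, 3.0]): pushed past the conditional estimate
# guidance_scale == 1.0 recovers the purely conditional prediction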
|
||||
|
|
|
@@ -40,14 +40,17 @@ from diffusers import (
     ScoreSdeVeScheduler,
     ScoreSdeVpPipeline,
     ScoreSdeVpScheduler,
+    UNetConditionalModel,
     UNetLDMModel,
     UNetUnconditionalModel,
     VQModel,
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel
+from transformers import BertTokenizer


 torch.backends.cuda.matmul.allow_tf32 = False
@@ -827,7 +830,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


-class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL

     @property
@@ -1026,10 +1029,8 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
@slow
|
||||
@unittest.skip("Skipping for now as it takes too long")
|
||||
def test_ldm_text2img(self):
|
||||
model_id = "fusing/latent-diffusion-text2im-large"
|
||||
ldm = LatentDiffusionPipeline.from_pretrained(model_id)
|
||||
ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.manual_seed(0)
|
||||
|
@@ -1043,8 +1044,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_ldm_text2img_fast(self):
|
||||
model_id = "fusing/latent-diffusion-text2im-large"
|
||||
ldm = LatentDiffusionPipeline.from_pretrained(model_id)
|
||||
ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.manual_seed(0)
|
||||
|
@@ -1074,6 +1074,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
@slow
|
||||
def test_score_sde_ve_pipeline(self):
|
||||
model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
|
||||
model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
|
|