From 4d1536bb2e141722a6a6fa53294b5a61d5f0ade7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:38:27 +0200 Subject: [PATCH 1/9] add vae model --- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 1 + src/diffusers/models/vae.py | 669 +++++++++++++++++++++++++++++++ 3 files changed, 671 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/models/vae.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 99e08dee..b62139cd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode __version__ = "0.0.4" from .modeling_utils import ModelMixin -from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel +from .models import AutoencoderKL, NCSNpp, TemporalUNet, UNetLDMModel, UNetModel, VQModel from .pipeline_utils import DiffusionPipeline from .pipelines import ( BDDMPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 71e321e1..3bba1133 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -22,3 +22,4 @@ from .unet_grad_tts import UNetGradTTSModel from .unet_ldm import UNetLDMModel from .unet_rl import TemporalUNet from .unet_sde_score_estimation import NCSNpp +from .vae import AutoencoderKL, VQModel diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py new file mode 100644 index 00000000..28244620 --- /dev/null +++ b/src/diffusers/models/vae.py @@ -0,0 +1,669 @@ +import math + +import numpy as np +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin +from ..modeling_utils import ModelMixin +from .attention import AttentionBlock +from .resnet import Downsample, Upsample + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal + embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section + 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +# class AttnBlock(nn.Module): +# def __init__(self, in_channels): +# super().__init__() +# self.in_channels = in_channels + +# self.norm = Normalize(in_channels) +# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) +# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) +# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) +# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + +# def forward(self, x): +# h_ = x +# h_ = self.norm(h_) +# q = self.q(h_) +# k = self.k(h_) +# v = self.v(h_) + +# # compute attention +# b, c, h, w = q.shape +# q = q.reshape(b, c, h * w) +# q = q.permute(0, 2, 1) # b,hw,c +# k = k.reshape(b, c, h * w) # b,c,hw +# w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] +# w_ = w_ * (int(c) ** (-0.5)) +# w_ = torch.nn.functional.softmax(w_, dim=2) + +# # attend to values +# v = v.reshape(b, c, h * w) +# w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) +# h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] +# h_ = h_.reshape(b, c, h, w) + +# h_ = self.proj_out(h_) + +# return x + h_ + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttentionBlock(block_in, overwrite_qkv=True)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttentionBlock(block_in, overwrite_qkv=True)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, use_conv=resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +class VQModel(ModelMixin, ConfigMixin): + def __init__( + self, + ch, + out_ch, + num_res_blocks, + attn_resolutions, + in_channels, + resolution, + z_channels, + n_embed, + embed_dim, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ch_mult=(1, 2, 4, 8), + dropout=0.0, + double_z=True, + resamp_with_conv=True, + give_pre_end=False, + ): + super().__init__() + + # register all __init__ params with self.register + self.register_to_config( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + n_embed=n_embed, + embed_dim=embed_dim, + remap=remap, + sane_index_shape=sane_index_shape, + ch_mult=ch_mult, + dropout=dropout, + double_z=double_z, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + # pass init params to Encoder + self.encoder = Encoder( + ch=ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + double_z=double_z, + give_pre_end=give_pre_end, + ) + + self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(ModelMixin, ConfigMixin): + def __init__( + self, + ch, + out_ch, + num_res_blocks, + attn_resolutions, + in_channels, + resolution, + z_channels, + embed_dim, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ch_mult=(1, 2, 4, 8), + dropout=0.0, + double_z=True, + resamp_with_conv=True, + give_pre_end=False, + ): + super().__init__() + + # register all __init__ params with self.register + self.register_to_config( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + embed_dim=embed_dim, + remap=remap, + sane_index_shape=sane_index_shape, + ch_mult=ch_mult, + dropout=dropout, + double_z=double_z, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + # pass init params to Encoder + self.encoder = Encoder( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + double_z=double_z, + give_pre_end=give_pre_end, + ) + + # pass init params to Decoder + self.decoder = Decoder( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior From 36e1893c6f3d171bfff288ef5fa3f66fba2bc177 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:38:38 +0200 Subject: [PATCH 2/9] remove vae from ldm pipeline --- .../pipeline_latent_diffusion.py | 849 ------------------ 1 file changed, 849 deletions(-) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index f7dc8ea3..6a4eb162 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -1,4 +1,3 @@ -import math from typing import Optional, Tuple, Union import numpy as np @@ -13,8 +12,6 @@ from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging -from ...configuration_utils import ConfigMixin -from ...modeling_utils import ModelMixin from ...pipeline_utils import DiffusionPipeline @@ -547,852 +544,6 @@ class LDMBertModel(LDMBertPreTrainedModel): return sequence_output -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal - embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section - 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x + h_ - - -class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - ): - super().__init__() - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, x, t=None): - # assert x.shape[2] == x.shape[3] == self.resolution - - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - **ignore_kwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) - - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - **ignorekwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix - multiplications and allows for post-hoc remapping of indices. - """ - - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. - def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - print( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z, temp=None, rescale_logits=False, return_logits=False): - assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" - assert rescale_logits == False, "Only for interface compatible with Gumbel" - assert return_logits == False, "Only for interface compatible with Gumbel" - # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, "b c h w -> b h w c").contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) - ) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class VQModel(ModelMixin, ConfigMixin): - def __init__( - self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - n_embed, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, - ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - n_embed=n_embed, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - # pass init params to Encoder - self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, - ) - - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) - - # pass init params to Decoder - self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) - - def mode(self): - return self.mean - - -class AutoencoderKL(ModelMixin, ConfigMixin): - def __init__( - self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, - ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - # pass init params to Encoder - self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, - ) - - # pass init params to Decoder - self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior - - class LatentDiffusionPipeline(DiffusionPipeline): def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler): super().__init__() From 17e5b4921a2dae2b9f33b3cb603d328f94e2d36f Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:38:48 +0200 Subject: [PATCH 3/9] remove vae from ldm uncond pipe --- .../pipeline_latent_diffusion_uncond.py | 561 ------------------ 1 file changed, 561 deletions(-) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index 269dfb39..873683b0 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -1,571 +1,10 @@ -import math - -import numpy as np import torch -import torch.nn as nn import tqdm -from ...configuration_utils import ConfigMixin -from ...modeling_utils import ModelMixin from ...pipeline_utils import DiffusionPipeline -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal - embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section - 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x + h_ - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - **ignore_kwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) - - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - **ignorekwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - - # compute in_ch_mult, block_in and curr_res at lowest res - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix - multiplications and allows for post-hoc remapping of indices. - """ - - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. - def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - print( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z): - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()) - ) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class VQModel(ModelMixin, ConfigMixin): - def __init__( - self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - n_embed, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, - ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - n_embed=n_embed, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - # pass init params to Encoder - self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, - ) - - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) - - # pass init params to Decoder - self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - class LatentDiffusionUncondPipeline(DiffusionPipeline): def __init__(self, vqvae, unet, noise_scheduler): super().__init__() From 2ac9b026097d44d117887c61761e61ddbc681909 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:43:04 +0200 Subject: [PATCH 4/9] remove AutoencoderKL from pipe __init__ --- src/diffusers/pipelines/latent_diffusion/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/latent_diffusion/__init__.py b/src/diffusers/pipelines/latent_diffusion/__init__.py index 614c3db4..5ce717ca 100644 --- a/src/diffusers/pipelines/latent_diffusion/__init__.py +++ b/src/diffusers/pipelines/latent_diffusion/__init__.py @@ -2,4 +2,4 @@ from ...utils import is_transformers_available if is_transformers_available(): - from .pipeline_latent_diffusion import AutoencoderKL, LatentDiffusionPipeline, LDMBertModel + from .pipeline_latent_diffusion import LatentDiffusionPipeline, LDMBertModel From 99568c5a39cb729591963c8a4bbfb16b7f76d253 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:53:58 +0200 Subject: [PATCH 5/9] cleanup vae file --- src/diffusers/models/vae.py | 38 ------------------------------------- 1 file changed, 38 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 28244620..282cb14a 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -82,44 +82,6 @@ class ResnetBlock(nn.Module): return x + h -# class AttnBlock(nn.Module): -# def __init__(self, in_channels): -# super().__init__() -# self.in_channels = in_channels - -# self.norm = Normalize(in_channels) -# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) -# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) -# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) -# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - -# def forward(self, x): -# h_ = x -# h_ = self.norm(h_) -# q = self.q(h_) -# k = self.k(h_) -# v = self.v(h_) - -# # compute attention -# b, c, h, w = q.shape -# q = q.reshape(b, c, h * w) -# q = q.permute(0, 2, 1) # b,hw,c -# k = k.reshape(b, c, h * w) # b,c,hw -# w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] -# w_ = w_ * (int(c) ** (-0.5)) -# w_ = torch.nn.functional.softmax(w_, dim=2) - -# # attend to values -# v = v.reshape(b, c, h * w) -# w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) -# h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] -# h_ = h_.reshape(b, c, h, w) - -# h_ = self.proj_out(h_) - -# return x + h_ - - class Encoder(nn.Module): def __init__( self, From 0b7daa6de99b759f4c6c0b723716b70e4107a2c8 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:56:19 +0200 Subject: [PATCH 6/9] add forward for vq model --- src/diffusers/models/vae.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 282cb14a..1a4c8904 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -534,6 +534,11 @@ class VQModel(ModelMixin, ConfigMixin): quant = self.post_quant_conv(quant) dec = self.decoder(quant) return dec + + def forward(self, x): + h = self.encode(x) + dec = self.decode(h) + return dec class AutoencoderKL(ModelMixin, ConfigMixin): From bae04ea9d891e22de9c66247d2429afebfa45af1 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 12:34:24 +0200 Subject: [PATCH 7/9] add test for VQModel --- tests/test_modeling_utils.py | 78 ++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 2e0be1d3..f38e629d 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -22,6 +22,7 @@ import numpy as np import torch from diffusers import ( + AutoencoderKL, BDDMPipeline, DDIMPipeline, DDIMScheduler, @@ -44,6 +45,7 @@ from diffusers import ( UNetGradTTSModel, UNetLDMModel, UNetModel, + VQModel, ) from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline @@ -805,6 +807,82 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) +class VQModelTests(ModelTesterMixin, unittest.TestCase): + model_class = VQModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"x": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 64, + "out_ch": 3, + "num_res_blocks": 1, + "attn_resolutions": [], + "in_channels": 3, + "resolution": 32, + "z_channels": 3, + "n_embed": 256, + "embed_dim": 3, + "sane_index_shape": False, + "ch_mult": (1,), + "dropout": 0.0, + "double_z": False, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + def test_from_pretrained_hub(self): + model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = VQModel.from_pretrained("fusing/vqgan-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + with torch.no_grad(): + output = model(image) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, + -0.4218]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + class PipelineTesterMixin(unittest.TestCase): def test_from_pretrained_save_pretrained(self): # 1. Load models From 976173a4bfab7ba5a091a1901e7ebb9d69d2e32d Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 12:34:28 +0200 Subject: [PATCH 8/9] style --- src/diffusers/models/vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 1a4c8904..bfc4a96e 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -534,7 +534,7 @@ class VQModel(ModelMixin, ConfigMixin): quant = self.post_quant_conv(quant) dec = self.decoder(quant) return dec - + def forward(self, x): h = self.encode(x) dec = self.decode(h) From 333a8da678882930a3c981df09128c7d8ffd0212 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 13:52:04 +0200 Subject: [PATCH 9/9] add tests for AutoencoderKL --- src/diffusers/models/vae.py | 6 +-- tests/test_modeling_utils.py | 73 ++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index bfc4a96e..6bd9a070 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -626,11 +626,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin): dec = self.decoder(z) return dec - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) + def forward(self, x, sample_posterior=False): + posterior = self.encode(x) if sample_posterior: z = posterior.sample() else: z = posterior.mode() dec = self.decode(z) - return dec, posterior + return dec diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index f38e629d..c567fdbb 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -46,6 +46,7 @@ from diffusers import ( UNetLDMModel, UNetModel, VQModel, + AutoencoderKL, ) from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline @@ -883,6 +884,78 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) +class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase): + model_class = AutoencoderKL + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"x": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 64, + "ch_mult": (1,), + "embed_dim": 4, + "in_channels": 3, + "num_res_blocks": 1, + "out_ch": 3, + "resolution": 32, + "z_channels": 4, + "attn_resolutions": [] + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + def test_from_pretrained_hub(self): + model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + with torch.no_grad(): + output = model(image, sample_posterior=True) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, + 0.1750]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + class PipelineTesterMixin(unittest.TestCase): def test_from_pretrained_save_pretrained(self): # 1. Load models