update
This commit is contained in:
parent
1468f754e0
commit
14bd3567b0
|
@ -1,6 +1,5 @@
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
||||||
import functools
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -160,211 +159,7 @@ class Downsample(nn.Module):
|
||||||
# return self.conv(x)
|
# return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
# RESNETS
|
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
|
||||||
# unet_score_estimation.py
|
|
||||||
class ResnetBlockBigGANppNew(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
act,
|
|
||||||
in_ch,
|
|
||||||
out_ch=None,
|
|
||||||
temb_dim=None,
|
|
||||||
up=False,
|
|
||||||
down=False,
|
|
||||||
dropout=0.1,
|
|
||||||
fir_kernel=(1, 3, 3, 1),
|
|
||||||
skip_rescale=True,
|
|
||||||
init_scale=0.0,
|
|
||||||
overwrite=True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
out_ch = out_ch if out_ch else in_ch
|
|
||||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
|
|
||||||
self.up = up
|
|
||||||
self.down = down
|
|
||||||
self.fir_kernel = fir_kernel
|
|
||||||
|
|
||||||
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
|
|
||||||
if temb_dim is not None:
|
|
||||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
|
||||||
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
|
|
||||||
nn.init.zeros_(self.Dense_0.bias)
|
|
||||||
|
|
||||||
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
|
|
||||||
self.Dropout_0 = nn.Dropout(dropout)
|
|
||||||
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
|
|
||||||
if in_ch != out_ch or up or down:
|
|
||||||
# 1x1 convolution with DDPM initialization.
|
|
||||||
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
|
|
||||||
|
|
||||||
self.skip_rescale = skip_rescale
|
|
||||||
self.act = act
|
|
||||||
self.in_ch = in_ch
|
|
||||||
self.out_ch = out_ch
|
|
||||||
|
|
||||||
self.is_overwritten = False
|
|
||||||
self.overwrite = overwrite
|
|
||||||
if overwrite:
|
|
||||||
self.output_scale_factor = np.sqrt(2.0)
|
|
||||||
self.in_channels = in_channels = in_ch
|
|
||||||
self.out_channels = out_channels = out_ch
|
|
||||||
groups = min(in_ch // 4, 32)
|
|
||||||
out_groups = min(out_ch // 4, 32)
|
|
||||||
eps = 1e-6
|
|
||||||
self.pre_norm = True
|
|
||||||
temb_channels = temb_dim
|
|
||||||
non_linearity = "silu"
|
|
||||||
self.time_embedding_norm = time_embedding_norm = "default"
|
|
||||||
|
|
||||||
if self.pre_norm:
|
|
||||||
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
|
|
||||||
else:
|
|
||||||
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
|
|
||||||
|
|
||||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
|
||||||
|
|
||||||
if time_embedding_norm == "default":
|
|
||||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
|
||||||
elif time_embedding_norm == "scale_shift":
|
|
||||||
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
|
|
||||||
|
|
||||||
self.norm2 = Normalize(out_channels, num_groups=out_groups, eps=eps)
|
|
||||||
self.dropout = torch.nn.Dropout(dropout)
|
|
||||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
|
||||||
|
|
||||||
if non_linearity == "swish":
|
|
||||||
self.nonlinearity = nonlinearity
|
|
||||||
elif non_linearity == "mish":
|
|
||||||
self.nonlinearity = Mish()
|
|
||||||
elif non_linearity == "silu":
|
|
||||||
self.nonlinearity = nn.SiLU()
|
|
||||||
|
|
||||||
if up:
|
|
||||||
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
|
|
||||||
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
|
|
||||||
elif down:
|
|
||||||
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
|
||||||
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
|
||||||
|
|
||||||
if self.in_channels != self.out_channels or self.up or self.down:
|
|
||||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
|
||||||
|
|
||||||
def set_weights(self):
|
|
||||||
self.conv1.weight.data = self.Conv_0.weight.data
|
|
||||||
self.conv1.bias.data = self.Conv_0.bias.data
|
|
||||||
self.norm1.weight.data = self.GroupNorm_0.weight.data
|
|
||||||
self.norm1.bias.data = self.GroupNorm_0.bias.data
|
|
||||||
|
|
||||||
self.conv2.weight.data = self.Conv_1.weight.data
|
|
||||||
self.conv2.bias.data = self.Conv_1.bias.data
|
|
||||||
self.norm2.weight.data = self.GroupNorm_1.weight.data
|
|
||||||
self.norm2.bias.data = self.GroupNorm_1.bias.data
|
|
||||||
|
|
||||||
self.temb_proj.weight.data = self.Dense_0.weight.data
|
|
||||||
self.temb_proj.bias.data = self.Dense_0.bias.data
|
|
||||||
|
|
||||||
if self.in_channels != self.out_channels or self.up or self.down:
|
|
||||||
self.nin_shortcut.weight.data = self.Conv_2.weight.data
|
|
||||||
self.nin_shortcut.bias.data = self.Conv_2.bias.data
|
|
||||||
|
|
||||||
def forward(self, x, temb=None):
|
|
||||||
if self.overwrite and not self.is_overwritten:
|
|
||||||
self.set_weights()
|
|
||||||
self.is_overwritten = True
|
|
||||||
|
|
||||||
orig_x = x
|
|
||||||
h = self.act(self.GroupNorm_0(x))
|
|
||||||
|
|
||||||
if self.up:
|
|
||||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
|
||||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
|
||||||
elif self.down:
|
|
||||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
|
||||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
|
||||||
|
|
||||||
h = self.Conv_0(h)
|
|
||||||
# Add bias to each feature map conditioned on the time embedding
|
|
||||||
if temb is not None:
|
|
||||||
h += self.Dense_0(self.act(temb))[:, :, None, None]
|
|
||||||
h = self.act(self.GroupNorm_1(h))
|
|
||||||
h = self.Dropout_0(h)
|
|
||||||
h = self.Conv_1(h)
|
|
||||||
|
|
||||||
if self.in_ch != self.out_ch or self.up or self.down:
|
|
||||||
x = self.Conv_2(x)
|
|
||||||
|
|
||||||
if not self.skip_rescale:
|
|
||||||
raise ValueError("Is this branch run?!")
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
result = x + h
|
|
||||||
else:
|
|
||||||
result = (x + h) / np.sqrt(2.0)
|
|
||||||
|
|
||||||
result_2 = self.forward_2(orig_x, temb)
|
|
||||||
|
|
||||||
return result_2
|
|
||||||
|
|
||||||
def forward_2(self, x, temb, mask=1.0):
|
|
||||||
h = x
|
|
||||||
h = h * mask
|
|
||||||
if self.pre_norm:
|
|
||||||
h = self.norm1(h)
|
|
||||||
h = self.nonlinearity(h)
|
|
||||||
|
|
||||||
# if self.up or self.down:
|
|
||||||
# x = self.x_upd(x)
|
|
||||||
# h = self.h_upd(h)
|
|
||||||
if self.up:
|
|
||||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
|
||||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
|
||||||
elif self.down:
|
|
||||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
|
||||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
|
||||||
|
|
||||||
h = self.conv1(h)
|
|
||||||
|
|
||||||
if not self.pre_norm:
|
|
||||||
h = self.norm1(h)
|
|
||||||
h = self.nonlinearity(h)
|
|
||||||
h = h * mask
|
|
||||||
|
|
||||||
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
|
|
||||||
|
|
||||||
if self.time_embedding_norm == "scale_shift":
|
|
||||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
|
||||||
|
|
||||||
h = self.norm2(h)
|
|
||||||
h = h + h * scale + shift
|
|
||||||
h = self.nonlinearity(h)
|
|
||||||
elif self.time_embedding_norm == "default":
|
|
||||||
h = h + temb
|
|
||||||
h = h * mask
|
|
||||||
if self.pre_norm:
|
|
||||||
h = self.norm2(h)
|
|
||||||
h = self.nonlinearity(h)
|
|
||||||
else:
|
|
||||||
raise ValueError("Nananan nanana - don't go here!")
|
|
||||||
|
|
||||||
h = self.dropout(h)
|
|
||||||
h = self.conv2(h)
|
|
||||||
|
|
||||||
if not self.pre_norm:
|
|
||||||
h = self.norm2(h)
|
|
||||||
h = self.nonlinearity(h)
|
|
||||||
h = h * mask
|
|
||||||
|
|
||||||
x = x * mask
|
|
||||||
# if self.in_channels != self.out_channels:
|
|
||||||
if self.in_channels != self.out_channels or self.up or self.down:
|
|
||||||
x = self.nin_shortcut(x)
|
|
||||||
|
|
||||||
result = x + h
|
|
||||||
|
|
||||||
return result / self.output_scale_factor
|
|
||||||
|
|
||||||
|
|
||||||
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
|
|
||||||
class ResnetBlock(nn.Module):
|
class ResnetBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -464,7 +464,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||||
groups_out=min(out_ch // 4, 32),
|
groups_out=min(out_ch // 4, 32),
|
||||||
overwrite_for_score_vde=True,
|
overwrite_for_score_vde=True,
|
||||||
up=True,
|
up=True,
|
||||||
kernel="fir",
|
kernel="fir", # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
|
||||||
use_nin_shortcut=True,
|
use_nin_shortcut=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue