get rid ResnetBlockDDPMpp and related functions

This commit is contained in:
patil-suraj 2022-06-30 11:54:32 +02:00
parent 81e7144783
commit 8830af1168
2 changed files with 17 additions and 139 deletions

View File

@ -637,62 +637,6 @@ class ResnetBlockBigGANpp(nn.Module):
return (x + h) / np.sqrt(2.0) return (x + h) / np.sqrt(2.0)
# unet_score_estimation.py
class ResnetBlockDDPMpp(nn.Module):
"""ResBlock adapted from DDPM."""
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
conv_shortcut=False,
dropout=0.1,
skip_rescale=False,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.Conv_0 = conv3x3(in_ch, out_ch)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
if in_ch != out_ch:
if conv_shortcut:
self.Conv_2 = conv3x3(in_ch, out_ch)
else:
self.NIN_0 = NIN(in_ch, out_ch)
self.skip_rescale = skip_rescale
self.act = act
self.out_ch = out_ch
self.conv_shortcut = conv_shortcut
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
h = self.Conv_0(h)
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if x.shape[1] != self.out_ch:
if self.conv_shortcut:
x = self.Conv_2(x)
else:
x = self.NIN_0(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
# unet_rl.py # unet_rl.py
class ResidualTemporalBlock(nn.Module): class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
@ -957,18 +901,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
class NIN(nn.Module):
def __init__(self, in_dim, num_units, init_scale=0.1):
super().__init__()
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
y = contract_inner(x, self.W) + self.b
return y.permute(0, 3, 1, 2)
def _setup_kernel(k): def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32) k = np.asarray(k, dtype=np.float32)
if k.ndim == 1: if k.ndim == 1:

View File

@ -28,7 +28,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp from .resnet import ResnetBlockBigGANpp
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@ -305,32 +305,6 @@ def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1
return conv return conv
def _einsum(a, b, c, x, y):
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y)
def contract_inner(x, y):
"""tensordot(x, y, 1)."""
x_chars = list(string.ascii_lowercase[: len(x.shape)])
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
out_chars = x_chars[:-1] + y_chars[1:]
return _einsum(x_chars, y_chars, out_chars, x, y)
class NIN(nn.Module):
def __init__(self, in_dim, num_units, init_scale=0.1):
super().__init__()
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
y = contract_inner(x, self.W) + self.b
return y.permute(0, 3, 1, 2)
def get_act(nonlinearity): def get_act(nonlinearity):
"""Get activation functions from the config file.""" """Get activation functions from the config file."""
@ -575,30 +549,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
elif progressive_input == "residual": elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True) pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
if resblock_type == "ddpm": ResnetBlock = functools.partial(
ResnetBlock = functools.partial( ResnetBlockBigGANpp,
ResnetBlockDDPMpp, act=act,
act=act, dropout=dropout,
dropout=dropout, fir=fir,
init_scale=init_scale, fir_kernel=fir_kernel,
skip_rescale=skip_rescale, init_scale=init_scale,
temb_dim=nf * 4, skip_rescale=skip_rescale,
) temb_dim=nf * 4,
)
elif resblock_type == "biggan":
ResnetBlock = functools.partial(
ResnetBlockBigGANpp,
act=act,
dropout=dropout,
fir=fir,
fir_kernel=fir_kernel,
init_scale=init_scale,
skip_rescale=skip_rescale,
temb_dim=nf * 4,
)
else:
raise ValueError(f"resblock type {resblock_type} unrecognized.")
# Downsampling block # Downsampling block
@ -622,10 +582,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c.append(in_ch) hs_c.append(in_ch)
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
if resblock_type == "ddpm": modules.append(ResnetBlock(down=True, in_ch=in_ch))
modules.append(Downsample(in_ch=in_ch))
else:
modules.append(ResnetBlock(down=True, in_ch=in_ch))
if progressive_input == "input_skip": if progressive_input == "input_skip":
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@ -678,10 +635,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
raise ValueError(f"{progressive} is not a valid name") raise ValueError(f"{progressive} is not a valid name")
if i_level != 0: if i_level != 0:
if resblock_type == "ddpm": modules.append(ResnetBlock(in_ch=in_ch, up=True))
modules.append(Upsample(in_ch=in_ch))
else:
modules.append(ResnetBlock(in_ch=in_ch, up=True))
assert not hs_c assert not hs_c
@ -741,12 +695,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs.append(h) hs.append(h)
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
if self.resblock_type == "ddpm": h = modules[m_idx](hs[-1], temb)
h = modules[m_idx](hs[-1]) m_idx += 1
m_idx += 1
else:
h = modules[m_idx](hs[-1], temb)
m_idx += 1
if self.progressive_input == "input_skip": if self.progressive_input == "input_skip":
input_pyramid = self.pyramid_downsample(input_pyramid) input_pyramid = self.pyramid_downsample(input_pyramid)
@ -818,12 +768,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
raise ValueError(f"{self.progressive} is not a valid name") raise ValueError(f"{self.progressive} is not a valid name")
if i_level != 0: if i_level != 0:
if self.resblock_type == "ddpm": h = modules[m_idx](h, temb)
h = modules[m_idx](h) m_idx += 1
m_idx += 1
else:
h = modules[m_idx](h, temb)
m_idx += 1
assert not hs assert not hs