get rid of ResnetBlockDDPMpp and related functions
parent 81e7144783
commit 8830af1168
@@ -637,62 +637,6 @@ class ResnetBlockBigGANpp(nn.Module):
             return (x + h) / np.sqrt(2.0)
 
 
-# unet_score_estimation.py
-class ResnetBlockDDPMpp(nn.Module):
-    """ResBlock adapted from DDPM."""
-
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        conv_shortcut=False,
-        dropout=0.1,
-        skip_rescale=False,
-        init_scale=0.0,
-    ):
-        super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch:
-            if conv_shortcut:
-                self.Conv_2 = conv3x3(in_ch, out_ch)
-            else:
-                self.NIN_0 = NIN(in_ch, out_ch)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.out_ch = out_ch
-        self.conv_shortcut = conv_shortcut
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-        h = self.Conv_0(h)
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-        if x.shape[1] != self.out_ch:
-            if self.conv_shortcut:
-                x = self.Conv_2(x)
-            else:
-                x = self.NIN_0(x)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-
-
 # unet_rl.py
 class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
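A side note on the skip_rescale branch, which survives in ResnetBlockBigGANpp and is deleted with the DDPM block above: dividing the residual sum by sqrt(2) keeps the output variance near one when x and h are roughly independent with unit variance, instead of letting it double at every block. A minimal standalone sketch of that effect (illustration only, not part of this diff):

import math

import torch

x = torch.randn(100_000)  # skip path, ~unit variance
h = torch.randn(100_000)  # residual path, ~unit variance

print((x + h).var().item())                     # ~2.0 without rescaling
print(((x + h) / math.sqrt(2.0)).var().item())  # ~1.0 with the sqrt(2) rescale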
@@ -957,18 +901,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
 
 
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
 def _setup_kernel(k):
     k = np.asarray(k, dtype=np.float32)
     if k.ndim == 1:
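The NIN ("network-in-network") layer deleted here acts as a 1x1 convolution in NCHW layout: it permutes to channels-last, contracts the channel axis against W, adds b, and permutes back. A hedged equivalence sketch; the weight copy into nn.Conv2d is my illustration, not code from the repo:

import torch
import torch.nn as nn

in_dim, num_units = 8, 16
x = torch.randn(2, in_dim, 4, 4)
W = torch.randn(in_dim, num_units)
b = torch.randn(num_units)

# NIN.forward, inlined: the channels-last matmul is the same contraction
# that contract_inner(x, W) performs.
y_nin = (x.permute(0, 2, 3, 1) @ W + b).permute(0, 3, 1, 2)

# The same linear map expressed as a 1x1 convolution.
conv = nn.Conv2d(in_dim, num_units, kernel_size=1)
conv.weight.data = W.t().reshape(num_units, in_dim, 1, 1)
conv.bias.data = b.clone()

print(torch.allclose(y_nin, conv(x), atol=1e-5))  # True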
@@ -28,7 +28,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
+from .resnet import ResnetBlockBigGANpp
 
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -305,32 +305,6 @@ def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1
     return conv
 
 
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
-
-
-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
 def get_act(nonlinearity):
     """Get activation functions from the config file."""
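As its docstring says, the deleted contract_inner is just tensordot(x, y, 1) spelled out through a dynamically built einsum string: the last axis of x is summed against the first axis of y. A quick sketch of the equivalence:

import torch

x = torch.randn(2, 4, 4, 8)
y = torch.randn(8, 16)

# contract_inner(x, y) would build the string "abcd,de->abce";
# torch.tensordot with dims=1 performs the identical contraction.
out = torch.tensordot(x, y, dims=1)
print(out.shape)  # torch.Size([2, 4, 4, 16])
print(torch.allclose(out, torch.einsum("abcd,de->abce", x, y)))  # True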
@@ -575,30 +549,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
         elif progressive_input == "residual":
             pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
 
-        if resblock_type == "ddpm":
-            ResnetBlock = functools.partial(
-                ResnetBlockDDPMpp,
-                act=act,
-                dropout=dropout,
-                init_scale=init_scale,
-                skip_rescale=skip_rescale,
-                temb_dim=nf * 4,
-            )
-
-        elif resblock_type == "biggan":
-            ResnetBlock = functools.partial(
-                ResnetBlockBigGANpp,
-                act=act,
-                dropout=dropout,
-                fir=fir,
-                fir_kernel=fir_kernel,
-                init_scale=init_scale,
-                skip_rescale=skip_rescale,
-                temb_dim=nf * 4,
-            )
-
-        else:
-            raise ValueError(f"resblock type {resblock_type} unrecognized.")
+        ResnetBlock = functools.partial(
+            ResnetBlockBigGANpp,
+            act=act,
+            dropout=dropout,
+            fir=fir,
+            fir_kernel=fir_kernel,
+            init_scale=init_scale,
+            skip_rescale=skip_rescale,
+            temb_dim=nf * 4,
+        )
 
         # Downsampling block
 
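The surviving construction leans on functools.partial: the keyword arguments shared by every block are bound once, and each call site below passes only what varies per block (in_ch, down=True, up=True, and so on). A minimal sketch of the pattern; ToyBlock is a hypothetical stand-in, not the real ResnetBlockBigGANpp:

import functools


class ToyBlock:
    def __init__(self, in_ch, out_ch=None, act=None, dropout=0.1, down=False, up=False):
        self.in_ch = in_ch
        self.out_ch = out_ch or in_ch
        self.down, self.up = down, up


# Bind the shared kwargs once, as the diff now does unconditionally.
ResnetBlock = functools.partial(ToyBlock, act="silu", dropout=0.1)

# Call sites supply only the per-block arguments.
block = ResnetBlock(down=True, in_ch=128)
print(block.in_ch, block.out_ch, block.down)  # 128 128 True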
@@ -622,10 +582,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs_c.append(in_ch)
 
             if i_level != self.num_resolutions - 1:
-                if resblock_type == "ddpm":
-                    modules.append(Downsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                modules.append(ResnetBlock(down=True, in_ch=in_ch))
 
                 if progressive_input == "input_skip":
                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@@ -678,10 +635,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     raise ValueError(f"{progressive} is not a valid name")
 
             if i_level != 0:
-                if resblock_type == "ddpm":
-                    modules.append(Upsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                modules.append(ResnetBlock(in_ch=in_ch, up=True))
 
         assert not hs_c
 
@@ -741,12 +695,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs.append(h)
 
             if i_level != self.num_resolutions - 1:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](hs[-1])
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](hs[-1], temb)
-                    m_idx += 1
+                h = modules[m_idx](hs[-1], temb)
+                m_idx += 1
 
                 if self.progressive_input == "input_skip":
                     input_pyramid = self.pyramid_downsample(input_pyramid)
@@ -818,12 +768,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     raise ValueError(f"{self.progressive} is not a valid name")
 
             if i_level != 0:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](h)
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](h, temb)
-                    m_idx += 1
+                h = modules[m_idx](h, temb)
+                m_idx += 1
 
         assert not hs
 
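With the ddpm branches gone, every entry appended for down/upsampling is a ResnetBlock, so the forward pass calls each one uniformly as modules[m_idx](h, temb) while m_idx walks the flat module list in the same order __init__ appended to it. A schematic sketch of that index-walking pattern with toy callables, not the real blocks:

# Every module now takes (h, temb), so no resblock_type branching is needed.
modules = [lambda h, temb: h + temb for _ in range(3)]

h, temb, m_idx = 0.0, 1.0, 0
for _ in range(len(modules)):
    h = modules[m_idx](h, temb)  # uniform call signature
    m_idx += 1
print(h)  # 3.0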