cleanup conv1x1 and conv3x3
This commit is contained in:
parent
b897008122
commit
ebe683432f
|
@ -593,17 +593,18 @@ class ResnetBlockBigGANpp(nn.Module):
|
|||
self.fir = fir
|
||||
self.fir_kernel = fir_kernel
|
||||
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch)
|
||||
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
|
||||
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
|
||||
nn.init.zeros_(self.Dense_0.bias)
|
||||
|
||||
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
|
||||
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
|
||||
if in_ch != out_ch or up or down:
|
||||
self.Conv_2 = conv1x1(in_ch, out_ch)
|
||||
#1x1 convolution with DDPM initialization.
|
||||
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
|
||||
|
||||
self.skip_rescale = skip_rescale
|
||||
self.act = act
|
||||
|
@ -754,32 +755,19 @@ class RearrangeDim(nn.Module):
|
|||
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
|
||||
"""1x1 convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
|
||||
"""3x3 convolution with DDPM initialization."""
|
||||
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
|
||||
"""nXn convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(
|
||||
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
|
||||
in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
|
||||
)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def default_init(scale=1.0):
|
||||
"""The same initialization used in DDPM."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
return variance_scaling(scale, "fan_avg", "uniform")
|
||||
|
||||
|
||||
def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
"""Ported from JAX."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
|
||||
def _compute_fans(shape, in_axis=1, out_axis=0):
|
||||
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
|
@ -789,21 +777,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
|
|||
|
||||
def init(shape, dtype=dtype, device=device):
|
||||
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
||||
if mode == "fan_in":
|
||||
denominator = fan_in
|
||||
elif mode == "fan_out":
|
||||
denominator = fan_out
|
||||
elif mode == "fan_avg":
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
else:
|
||||
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
variance = scale / denominator
|
||||
if distribution == "normal":
|
||||
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
|
||||
elif distribution == "uniform":
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
else:
|
||||
raise ValueError("invalid distribution for variance scaling initializer")
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
|
||||
return init
|
||||
|
||||
|
|
|
@ -287,20 +287,12 @@ def downsample_2d(x, k=None, factor=2, gain=1):
|
|||
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
|
||||
"""1x1 convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
|
||||
"""3x3 convolution with DDPM initialization."""
|
||||
def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
|
||||
"""nXn convolution with DDPM initialization."""
|
||||
conv = nn.Conv2d(
|
||||
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
|
||||
in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
|
||||
)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
@ -320,14 +312,9 @@ def get_act(nonlinearity):
|
|||
raise NotImplementedError("activation function does not exist!")
|
||||
|
||||
|
||||
def default_init(scale=1.0):
|
||||
"""The same initialization used in DDPM."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
return variance_scaling(scale, "fan_avg", "uniform")
|
||||
|
||||
|
||||
def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
"""Ported from JAX."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
|
||||
def _compute_fans(shape, in_axis=1, out_axis=0):
|
||||
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
|
@ -337,21 +324,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
|
|||
|
||||
def init(shape, dtype=dtype, device=device):
|
||||
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
||||
if mode == "fan_in":
|
||||
denominator = fan_in
|
||||
elif mode == "fan_out":
|
||||
denominator = fan_out
|
||||
elif mode == "fan_avg":
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
else:
|
||||
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
variance = scale / denominator
|
||||
if distribution == "normal":
|
||||
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
|
||||
elif distribution == "uniform":
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
else:
|
||||
raise ValueError("invalid distribution for variance scaling initializer")
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
|
||||
|
||||
return init
|
||||
|
||||
|
@ -361,7 +336,8 @@ class Combine(nn.Module):
|
|||
|
||||
def __init__(self, dim1, dim2, method="cat"):
|
||||
super().__init__()
|
||||
self.Conv_0 = conv1x1(dim1, dim2)
|
||||
#1x1 convolution with DDPM initialization.
|
||||
self.Conv_0 = conv2d(dim1, dim2, kernel_size=1, padding=0)
|
||||
self.method = method
|
||||
|
||||
def forward(self, x, y):
|
||||
|
@ -386,7 +362,7 @@ class Upsample(nn.Module):
|
|||
up=True,
|
||||
resample_kernel=fir_kernel,
|
||||
use_bias=True,
|
||||
kernel_init=default_init(),
|
||||
kernel_init=variance_scaling(),
|
||||
)
|
||||
self.fir = fir
|
||||
self.with_conv = with_conv
|
||||
|
@ -415,7 +391,7 @@ class Downsample(nn.Module):
|
|||
down=True,
|
||||
resample_kernel=fir_kernel,
|
||||
use_bias=True,
|
||||
kernel_init=default_init(),
|
||||
kernel_init=variance_scaling(),
|
||||
)
|
||||
self.fir = fir
|
||||
self.fir_kernel = fir_kernel
|
||||
|
@ -528,10 +504,10 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
|
||||
if conditional:
|
||||
modules.append(nn.Linear(embed_dim, nf * 4))
|
||||
modules[-1].weight.data = default_init()(modules[-1].weight.shape)
|
||||
modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape)
|
||||
nn.init.zeros_(modules[-1].bias)
|
||||
modules.append(nn.Linear(nf * 4, nf * 4))
|
||||
modules[-1].weight.data = default_init()(modules[-1].weight.shape)
|
||||
modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape)
|
||||
nn.init.zeros_(modules[-1].bias)
|
||||
|
||||
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
|
||||
|
@ -566,7 +542,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
if progressive_input != "none":
|
||||
input_pyramid_ch = channels
|
||||
|
||||
modules.append(conv3x3(channels, nf))
|
||||
modules.append(conv2d(channels, nf, kernel_size=3, padding=1))
|
||||
hs_c = [nf]
|
||||
|
||||
in_ch = nf
|
||||
|
@ -615,18 +591,18 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
if i_level == self.num_resolutions - 1:
|
||||
if progressive == "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
||||
modules.append(conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
|
||||
pyramid_ch = channels
|
||||
elif progressive == "residual":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, in_ch, bias=True))
|
||||
modules.append(conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
|
||||
pyramid_ch = in_ch
|
||||
else:
|
||||
raise ValueError(f"{progressive} is not a valid name.")
|
||||
else:
|
||||
if progressive == "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
|
||||
modules.append(conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1))
|
||||
pyramid_ch = channels
|
||||
elif progressive == "residual":
|
||||
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
|
||||
|
@ -641,7 +617,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
|
||||
if progressive != "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
||||
modules.append(conv2d(in_ch, channels, init_scale=init_scale))
|
||||
|
||||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
|
|
Loading…
Reference in New Issue