cleanup conv1x1 and conv3x3

Author: patil-suraj
Date: 2022-06-30 12:20:49 +02:00
Commit: ebe683432f (parent: b897008122)

2 changed files with 32 additions and 80 deletions

File 1 of 2

@@ -593,17 +593,18 @@ class ResnetBlockBigGANpp(nn.Module):
         self.fir = fir
         self.fir_kernel = fir_kernel

-        self.Conv_0 = conv3x3(in_ch, out_ch)
+        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
             nn.init.zeros_(self.Dense_0.bias)

         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
         self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
         if in_ch != out_ch or up or down:
-            self.Conv_2 = conv1x1(in_ch, out_ch)
+            # 1x1 convolution with DDPM initialization.
+            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)

         self.skip_rescale = skip_rescale
         self.act = act
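Note: the whole commit rests on two call-site equivalences, since conv2d defaults to the old conv3x3 behavior and takes explicit arguments for the 1x1 case. A minimal sketch, not part of the commit, assuming the conv2d helper defined later in this diff is in scope:

    import torch

    # conv3x3(in_ch, out_ch)  ->  conv2d(in_ch, out_ch, kernel_size=3, padding=1)
    # conv1x1(in_ch, out_ch)  ->  conv2d(in_ch, out_ch, kernel_size=1, padding=0)
    # Both settings preserve spatial resolution, so the residual branches still line up.
    x = torch.randn(1, 4, 8, 8)
    assert conv2d(4, 8, kernel_size=3, padding=1)(x).shape == (1, 8, 8, 8)
    assert conv2d(4, 8, kernel_size=1, padding=0)(x).shape == (1, 8, 8, 8)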
@@ -754,32 +755,19 @@ class RearrangeDim(nn.Module):
         raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
-
-
-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv


-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
+def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
     conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
+        in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
     )
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv


-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale

     def _compute_fans(shape, in_axis=1, out_axis=0):
         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
@@ -789,21 +777,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):

     def init(shape, dtype=dtype, device=device):
         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
         variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

     return init
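Note: removing the mode and distribution arguments hard-wires the "fan_avg" / "uniform" choice that default_init used to make. A quick sketch, not part of the commit, assuming the new variance_scaling above is in scope; it checks the statistics for a 3x3 conv weight:

    # For shape (out_ch, in_ch, kh, kw), _compute_fans yields
    # fan_in = in_ch * kh * kw and fan_out = out_ch * kh * kw.
    # Uniform samples on [-sqrt(3 * v), sqrt(3 * v)] have variance exactly
    # v = scale / ((fan_in + fan_out) / 2).
    w = variance_scaling(scale=1.0)((128, 64, 3, 3))
    fan_in, fan_out = 64 * 3 * 3, 128 * 3 * 3
    print(w.var().item(), 1.0 / ((fan_in + fan_out) / 2))  # should agree to ~1e-5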

File 2 of 2

@@ -287,20 +287,12 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
-
-
-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv


-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
+def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
     conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
+        in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
     )
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv
@@ -320,14 +312,9 @@ def get_act(nonlinearity):
         raise NotImplementedError("activation function does not exist!")


-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale

     def _compute_fans(shape, in_axis=1, out_axis=0):
         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
@@ -337,21 +324,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):

     def init(shape, dtype=dtype, device=device):
         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
         variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

     return init
@@ -361,7 +336,8 @@ class Combine(nn.Module):
     def __init__(self, dim1, dim2, method="cat"):
         super().__init__()
-        self.Conv_0 = conv1x1(dim1, dim2)
+        # 1x1 convolution with DDPM initialization.
+        self.Conv_0 = conv2d(dim1, dim2, kernel_size=1, padding=0)
         self.method = method

     def forward(self, x, y):
@@ -386,7 +362,7 @@ class Upsample(nn.Module):
                 up=True,
                 resample_kernel=fir_kernel,
                 use_bias=True,
-                kernel_init=default_init(),
+                kernel_init=variance_scaling(),
             )
         self.fir = fir
         self.with_conv = with_conv
@@ -415,7 +391,7 @@ class Downsample(nn.Module):
                 down=True,
                 resample_kernel=fir_kernel,
                 use_bias=True,
-                kernel_init=default_init(),
+                kernel_init=variance_scaling(),
             )
         self.fir = fir
         self.fir_kernel = fir_kernel
@@ -528,10 +504,10 @@ class NCSNpp(ModelMixin, ConfigMixin):
         if conditional:
             modules.append(nn.Linear(embed_dim, nf * 4))
-            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
+            modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape)
             nn.init.zeros_(modules[-1].bias)
             modules.append(nn.Linear(nf * 4, nf * 4))
-            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
+            modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape)
             nn.init.zeros_(modules[-1].bias)

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
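Note: variance_scaling's defaults (in_axis=1, out_axis=0) already match nn.Linear's (out_features, in_features) weight layout, which is why the dense init above needs no extra arguments. A minimal sketch of the pattern, not part of the commit:

    import torch.nn as nn

    nf, embed_dim = 128, 256
    dense = nn.Linear(embed_dim, nf * 4)
    # weight.shape == (512, 256) -> fan_in = 256, fan_out = 512, fan_avg = 384
    dense.weight.data = variance_scaling()(dense.weight.shape)
    nn.init.zeros_(dense.bias)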
@@ -566,7 +542,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         if progressive_input != "none":
             input_pyramid_ch = channels

-        modules.append(conv3x3(channels, nf))
+        modules.append(conv2d(channels, nf, kernel_size=3, padding=1))
         hs_c = [nf]

         in_ch = nf
@@ -615,18 +591,18 @@ class NCSNpp(ModelMixin, ConfigMixin):
             if i_level == self.num_resolutions - 1:
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                    modules.append(conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
                     pyramid_ch = channels
                 elif progressive == "residual":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, in_ch, bias=True))
+                    modules.append(conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
                     pyramid_ch = in_ch
                 else:
                     raise ValueError(f"{progressive} is not a valid name.")
             else:
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                    modules.append(conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1))
                     pyramid_ch = channels
                 elif progressive == "residual":
                     modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
@@ -641,7 +617,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         if progressive != "output_skip":
             modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+            modules.append(conv2d(in_ch, channels, init_scale=init_scale))

         self.all_modules = nn.ModuleList(modules)
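Note: this final call omits kernel_size and padding, relying on conv2d's defaults (kernel_size=3, padding=1) to match the removed conv3x3. A quick check, not part of the commit:

    a = conv2d(8, 3)
    b = conv2d(8, 3, kernel_size=3, padding=1)
    assert a.kernel_size == b.kernel_size == (3, 3)
    assert a.padding == b.padding == (1, 1)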