From ebe683432f1cd4703947ad5691dbc0aaededa7d2 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Thu, 30 Jun 2022 12:20:49 +0200
Subject: [PATCH] cleanup conv1x1 and conv3x3

---
 src/diffusers/models/resnet.py                | 50 ++++-----------
 .../models/unet_sde_score_estimation.py       | 62 ++++++-------------
 2 files changed, 32 insertions(+), 80 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 50a3e453..c206859b 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -593,17 +593,18 @@ class ResnetBlockBigGANpp(nn.Module):
         self.fir = fir
         self.fir_kernel = fir_kernel
 
-        self.Conv_0 = conv3x3(in_ch, out_ch)
+        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
             nn.init.zeros_(self.Dense_0.bias)
 
         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
         self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
         if in_ch != out_ch or up or down:
-            self.Conv_2 = conv1x1(in_ch, out_ch)
+            # 1x1 convolution with DDPM initialization.
+            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
 
         self.skip_rescale = skip_rescale
         self.act = act
@@ -754,32 +755,19 @@ class RearrangeDim(nn.Module):
             raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
 
 
-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
+def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """n x n convolution with DDPM initialization."""
     conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
+        in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
     )
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv
 
 
-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale
 
     def _compute_fans(shape, in_axis=1, out_axis=0):
         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
@@ -789,21 +777,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
 
     def init(shape, dtype=dtype, device=device):
         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
         variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
 
     return init
 
diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py
index 890678f1..30db3493 100644
--- a/src/diffusers/models/unet_sde_score_estimation.py
+++ b/src/diffusers/models/unet_sde_score_estimation.py
@@ -287,20 +287,12 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
 
 
-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
+def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """n x n convolution with DDPM initialization."""
     conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
+        in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
     )
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv
 
@@ -320,14 +312,9 @@ def get_act(nonlinearity):
     raise NotImplementedError("activation function does not exist!")
 
 
-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale
 
     def _compute_fans(shape, in_axis=1, out_axis=0):
         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
@@ -337,21 +324,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
 
     def init(shape, dtype=dtype, device=device):
         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
         variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
ValueError("invalid distribution for variance scaling initializer") + return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance) return init @@ -361,7 +336,8 @@ class Combine(nn.Module): def __init__(self, dim1, dim2, method="cat"): super().__init__() - self.Conv_0 = conv1x1(dim1, dim2) + #1x1 convolution with DDPM initialization. + self.Conv_0 = conv2d(dim1, dim2, kernel_size=1, padding=0) self.method = method def forward(self, x, y): @@ -386,7 +362,7 @@ class Upsample(nn.Module): up=True, resample_kernel=fir_kernel, use_bias=True, - kernel_init=default_init(), + kernel_init=variance_scaling(), ) self.fir = fir self.with_conv = with_conv @@ -415,7 +391,7 @@ class Downsample(nn.Module): down=True, resample_kernel=fir_kernel, use_bias=True, - kernel_init=default_init(), + kernel_init=variance_scaling(), ) self.fir = fir self.fir_kernel = fir_kernel @@ -528,10 +504,10 @@ class NCSNpp(ModelMixin, ConfigMixin): if conditional: modules.append(nn.Linear(embed_dim, nf * 4)) - modules[-1].weight.data = default_init()(modules[-1].weight.shape) + modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape) nn.init.zeros_(modules[-1].bias) modules.append(nn.Linear(nf * 4, nf * 4)) - modules[-1].weight.data = default_init()(modules[-1].weight.shape) + modules[-1].weight.data = variance_scaling()(modules[-1].weight.shape) nn.init.zeros_(modules[-1].bias) AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) @@ -566,7 +542,7 @@ class NCSNpp(ModelMixin, ConfigMixin): if progressive_input != "none": input_pyramid_ch = channels - modules.append(conv3x3(channels, nf)) + modules.append(conv2d(channels, nf, kernel_size=3, padding=1)) hs_c = [nf] in_ch = nf @@ -615,18 +591,18 @@ class NCSNpp(ModelMixin, ConfigMixin): if i_level == self.num_resolutions - 1: if progressive == "output_skip": modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) + modules.append(conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1)) pyramid_ch = channels elif progressive == "residual": modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(conv3x3(in_ch, in_ch, bias=True)) + modules.append(conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1)) pyramid_ch = in_ch else: raise ValueError(f"{progressive} is not a valid name.") else: if progressive == "output_skip": modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale)) + modules.append(conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)) pyramid_ch = channels elif progressive == "residual": modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) @@ -641,7 +617,7 @@ class NCSNpp(ModelMixin, ConfigMixin): if progressive != "output_skip": modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) + modules.append(conv2d(in_ch, channels, init_scale=init_scale)) self.all_modules = nn.ModuleList(modules)