From b9de7172baf5cf755ab334ed48348bb6b3a1f61e Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Mon, 27 Jun 2022 18:03:41 +0200
Subject: [PATCH 1/3] add Downsample

---
 src/diffusers/models/resnet.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 8d877869..34963251 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -103,7 +103,7 @@ class Downsample(nn.Module):
         downsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+    def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -111,18 +111,29 @@ class Downsample(nn.Module):
         self.dims = dims
         self.padding = padding
         stride = 2 if dims != 3 else (1, 2, 2)
+        self.name = name
+
         if use_conv:
-            self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
         else:
             assert self.channels == self.out_channels
-            self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+            conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.op = conv
 
     def forward(self, x):
         assert x.shape[1] == self.channels
         if self.use_conv and self.padding == 0 and self.dims == 2:
             pad = (0, 1, 0, 1)
             x = F.pad(x, pad, mode="constant", value=0)
-        return self.down(x)
+
+        if self.name == "conv":
+            return self.conv(x)
+        else:
+            return self.op(x)
 
 
 # TODO (patil-suraj): needs test

From 7b9b946cb2539487a85a6dcca8a0aa52da4a8b61 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Mon, 27 Jun 2022 18:03:51 +0200
Subject: [PATCH 2/3] add tests for downsample block

---
 tests/test_layers_utils.py | 57 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 56 insertions(+), 1 deletion(-)

diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py
index cde7fc6b..86a7b883 100755
--- a/tests/test_layers_utils.py
+++ b/tests/test_layers_utils.py
@@ -22,7 +22,7 @@ import numpy as np
 import torch
 
 from diffusers.models.embeddings import get_timestep_embedding
-from diffusers.models.resnet import Upsample
+from diffusers.models.resnet import Downsample, Upsample
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 
 
@@ -164,3 +164,58 @@ class UpsampleBlockTests(unittest.TestCase):
         output_slice = upsampled[0, -1, -3:, -3:]
         expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class DownsampleBlockTests(unittest.TestCase):
+    def test_downsample_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        downsample = Downsample(channels=32, use_conv=False)
+        with torch.no_grad():
+            downsampled = downsample(sample)
+
+        assert downsampled.shape == (1, 32, 32, 32)
+        output_slice = downsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([-0.0513, -0.3889, 0.0640, 0.0836, -0.5460, -0.0341, -0.0169, -0.6967, 0.1179])
+        max_diff = (output_slice.flatten() - expected_slice).abs().sum().item()
+        assert max_diff <= 1e-3
+        # assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
+
+    def test_downsample_with_conv(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        downsample = Downsample(channels=32, use_conv=True)
+        with torch.no_grad():
+            downsampled = downsample(sample)
+
+        assert downsampled.shape == (1, 32, 32, 32)
+        output_slice = downsampled[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor(
+            [0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913],
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_downsample_with_conv_pad1(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        downsample = Downsample(channels=32, use_conv=True, padding=1)
+        with torch.no_grad():
+            downsampled = downsample(sample)
+
+        assert downsampled.shape == (1, 32, 32, 32)
+        output_slice = downsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_downsample_with_conv_out_dim(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        downsample = Downsample(channels=32, use_conv=True, out_channels=16)
+        with torch.no_grad():
+            downsampled = downsample(sample)
+
+        assert downsampled.shape == (1, 16, 32, 32)
+        output_slice = downsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

From d1fb309381c7bdfa539838d5c94b699c9790a246 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Mon, 27 Jun 2022 18:03:59 +0200
Subject: [PATCH 3/3] consolidate downsample

---
 src/diffusers/models/unet.py          | 22 ++------------
 src/diffusers/models/unet_glide.py    | 38 ++++---------------
 src/diffusers/models/unet_grad_tts.py | 14 ++-------
 src/diffusers/models/unet_ldm.py      | 41 ++++++---------------
 4 files changed, 20 insertions(+), 95 deletions(-)

diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py
index fe8802cc..c40f393e 100644
--- a/src/diffusers/models/unet.py
+++ b/src/diffusers/models/unet.py
@@ -31,7 +31,7 @@ from tqdm import tqdm
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample
 
 
 def nonlinearity(x):
@@ -43,24 +43,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
 
-class Downsample(nn.Module):
-    def __init__(self, in_channels, with_conv):
-        super().__init__()
-        self.with_conv = with_conv
-        if self.with_conv:
-            # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
-
-    def forward(self, x):
-        if self.with_conv:
-            pad = (0, 1, 0, 1)
-            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
-            x = self.conv(x)
-        else:
-            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
-        return x
-
-
 class ResnetBlock(nn.Module):
     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
         super().__init__()
@@ -207,7 +189,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, resamp_with_conv)
+                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
             self.down.append(down)
 
diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py
index 9a50b9cb..da753ce0 100644
--- a/src/diffusers/models/unet_glide.py
+++ b/src/diffusers/models/unet_glide.py
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample
 
 
 def convert_module_to_f16(l):
@@ -126,34 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x
 
 
-class Downsample(nn.Module):
-    """
-    A downsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        downsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        stride = 2 if dims != 3 else (1, 2, 2)
-        if use_conv:
-            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
-        else:
-            assert self.channels == self.out_channels
-            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        return self.op(x)
-
-
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
@@ -205,8 +177,8 @@ class ResBlock(TimestepBlock):
             self.h_upd = Upsample(channels, use_conv=False, dims=dims)
             self.x_upd = Upsample(channels, use_conv=False, dims=dims)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
+            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
         else:
             self.h_upd = self.x_upd = nn.Identity()
 
@@ -463,7 +435,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             down=True,
                         )
                         if resblock_updown
-                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Downsample(
+                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
+                        )
                     )
                 )
                 ch = out_ch
diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py
index e9666f74..84ec622b 100644
--- a/src/diffusers/models/unet_grad_tts.py
+++ b/src/diffusers/models/unet_grad_tts.py
@@ -1,4 +1,5 @@
 import torch
+from numpy import pad
 
 
 try:
@@ -10,7 +11,7 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample
 
 
 class Mish(torch.nn.Module):
@@ -18,15 +19,6 @@ class Mish(torch.nn.Module):
         return x * torch.tanh(torch.nn.functional.softplus(x))
 
 
-class Downsample(torch.nn.Module):
-    def __init__(self, dim):
-        super(Downsample, self).__init__()
-        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 class Rezero(torch.nn.Module):
     def __init__(self, fn):
         super(Rezero, self).__init__()
@@ -141,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                         ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                         ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                         Residual(Rezero(LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else torch.nn.Identity(),
+                        Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                     ]
                 )
             )
diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py
index 26aab775..7ebd2f48 100644
--- a/src/diffusers/models/unet_ldm.py
+++ b/src/diffusers/models/unet_ldm.py
@@ -17,7 +17,7 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample
 
 
 def exists(val):
@@ -380,33 +380,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x
 
 
-class Downsample(nn.Module):
-    """
-    A downsampling layer with an optional convolution.
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        downsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        stride = 2 if dims != 3 else (1, 2, 2)
-        if use_conv:
-            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
-        else:
-            assert self.channels == self.out_channels
-            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        return self.op(x)
-
-
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
@@ -457,8 +430,8 @@ class ResBlock(TimestepBlock):
             self.h_upd = Upsample(channels, use_conv=False, dims=dims)
             self.x_upd = Upsample(channels, use_conv=False, dims=dims)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
+            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
         else:
             self.h_upd = self.x_upd = nn.Identity()
 
@@ -825,7 +798,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                             down=True,
                         )
                         if resblock_updown
-                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Downsample(
+                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
+                        )
                     )
                 )
                 ch = out_ch
@@ -1098,7 +1073,9 @@ class EncoderUNetModel(nn.Module):
                             down=True,
                         )
                         if resblock_updown
-                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Downsample(
+                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
+                        )
                     )
                 )
                 ch = out_ch
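
A minimal usage sketch of the consolidated Downsample block (assuming the signature added in PATCH 1/3 is applied; the tensor shapes mirror the new tests in PATCH 2/3, and the name="op" option appears intended only to keep the attribute, and hence checkpoint parameter names, that the GLIDE/LDM UNets used before this consolidation):

# Usage sketch for the consolidated Downsample from PATCH 1/3 (not part of the patches).
import torch

from diffusers.models.resnet import Downsample

sample = torch.randn(1, 32, 64, 64)

# use_conv=False: average pooling, so in/out channels must match (name defaults to "conv").
pool_down = Downsample(channels=32, use_conv=False)

# use_conv=True: strided 3x3 convolution, optionally changing the channel count.
conv_down = Downsample(channels=32, use_conv=True, out_channels=16, padding=1)

# name="op" stores the layer as `self.op` instead of `self.conv`, matching the
# attribute name previously used by the GLIDE/LDM UNets.
op_down = Downsample(channels=32, use_conv=True, padding=1, name="op")

with torch.no_grad():
    print(pool_down(sample).shape)  # torch.Size([1, 32, 32, 32])
    print(conv_down(sample).shape)  # torch.Size([1, 16, 32, 32])
    print(op_down(sample).shape)    # torch.Size([1, 32, 32, 32])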