diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 6560d345..ca72669a 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -101,7 +101,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -109,18 +109,29 @@ class Downsample(nn.Module): self.dims = dims self.padding = padding stride = 2 if dims != 3 else (1, 2, 2) + self.name = name + if use_conv: - self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) + conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels - self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride) + conv = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + if name == "conv": + self.conv = conv + else: + self.op = conv def forward(self, x): assert x.shape[1] == self.channels if self.use_conv and self.padding == 0 and self.dims == 2: pad = (0, 1, 0, 1) x = F.pad(x, pad, mode="constant", value=0) - return self.down(x) + + if self.name == "conv": + return self.conv(x) + else: + return self.op(x) class UNetUpsample(nn.Module): diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index fe8802cc..c40f393e 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -31,7 +31,7 @@ from tqdm import tqdm from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding -from .resnet import Upsample +from .resnet import Downsample, Upsample def nonlinearity(x): @@ -43,24 +43,6 @@ def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() @@ -207,7 +189,7 @@ class UNetModel(ModelMixin, ConfigMixin): down.block = block down.attn = attn if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) + down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0) curr_res = curr_res // 2 self.down.append(down) diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index d357d0cc..6fe27959 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding -from .resnet import Upsample +from .resnet import Downsample, Upsample def convert_module_to_f16(l): @@ -124,33 +124,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): return x -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is - applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - class ResBlock(TimestepBlock): """ A residual block that can optionally change the number of channels. @@ -198,8 +171,8 @@ class ResBlock(TimestepBlock): self.h_upd = Upsample(channels, use_conv=False, dims=dims) self.x_upd = Upsample(channels, use_conv=False, dims=dims) elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) + self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") + self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") else: self.h_upd = self.x_upd = nn.Identity() @@ -450,7 +423,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin): down=True, ) if resblock_updown - else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + else Downsample( + ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op" + ) ) ) ch = out_ch diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 9c1974e6..4cccc837 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -1,9 +1,10 @@ import torch +from numpy import pad from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding -from .resnet import Upsample +from .resnet import Downsample, Upsample class Mish(torch.nn.Module): @@ -11,15 +12,6 @@ class Mish(torch.nn.Module): return x * torch.tanh(torch.nn.functional.softplus(x)) -class Downsample(torch.nn.Module): - def __init__(self, dim): - super(Downsample, self).__init__() - self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1) - - def forward(self, x): - return self.conv(x) - - class Rezero(torch.nn.Module): def __init__(self, fn): super(Rezero, self).__init__() @@ -141,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ResnetBlock(dim_in, dim_out, time_emb_dim=dim), ResnetBlock(dim_out, dim_out, time_emb_dim=dim), Residual(Rezero(LinearAttention(dim_out))), - Downsample(dim_out) if not is_last else torch.nn.Identity(), + Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(), ] ) ) diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index f8a8602d..24aec1cf 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding -from .resnet import Upsample +from .resnet import Downsample, Upsample def exists(val): @@ -392,32 +392,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): return x -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param - use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. - If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - class ResBlock(TimestepBlock): """ A residual block that can optionally change the number of channels. :param channels: the number of input channels. @@ -464,8 +438,8 @@ class ResBlock(TimestepBlock): self.h_upd = Upsample(channels, use_conv=False, dims=dims) self.x_upd = Upsample(channels, use_conv=False, dims=dims) elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) + self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") + self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") else: self.h_upd = self.x_upd = nn.Identity() @@ -820,7 +794,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin): down=True, ) if resblock_updown - else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + else Downsample( + ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op" + ) ) ) ch = out_ch @@ -1089,7 +1065,9 @@ class EncoderUNetModel(nn.Module): down=True, ) if resblock_updown - else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + else Downsample( + ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op" + ) ) ) ch = out_ch diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index cde7fc6b..86a7b883 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -22,7 +22,7 @@ import numpy as np import torch from diffusers.models.embeddings import get_timestep_embedding -from diffusers.models.resnet import Upsample +from diffusers.models.resnet import Downsample, Upsample from diffusers.testing_utils import floats_tensor, slow, torch_device @@ -164,3 +164,58 @@ class UpsampleBlockTests(unittest.TestCase): output_slice = upsampled[0, -1, -3:, -3:] expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046]) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + +class DownsampleBlockTests(unittest.TestCase): + def test_downsample_default(self): + torch.manual_seed(0) + sample = torch.randn(1, 32, 64, 64) + downsample = Downsample(channels=32, use_conv=False) + with torch.no_grad(): + downsampled = downsample(sample) + + assert downsampled.shape == (1, 32, 32, 32) + output_slice = downsampled[0, -1, -3:, -3:] + expected_slice = torch.tensor([-0.0513, -0.3889, 0.0640, 0.0836, -0.5460, -0.0341, -0.0169, -0.6967, 0.1179]) + max_diff = (output_slice.flatten() - expected_slice).abs().sum().item() + assert max_diff <= 1e-3 + # assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1) + + def test_downsample_with_conv(self): + torch.manual_seed(0) + sample = torch.randn(1, 32, 64, 64) + downsample = Downsample(channels=32, use_conv=True) + with torch.no_grad(): + downsampled = downsample(sample) + + assert downsampled.shape == (1, 32, 32, 32) + output_slice = downsampled[0, -1, -3:, -3:] + + expected_slice = torch.tensor( + [0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913], + ) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def test_downsample_with_conv_pad1(self): + torch.manual_seed(0) + sample = torch.randn(1, 32, 64, 64) + downsample = Downsample(channels=32, use_conv=True, padding=1) + with torch.no_grad(): + downsampled = downsample(sample) + + assert downsampled.shape == (1, 32, 32, 32) + output_slice = downsampled[0, -1, -3:, -3:] + expected_slice = torch.tensor([0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913]) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def test_downsample_with_conv_out_dim(self): + torch.manual_seed(0) + sample = torch.randn(1, 32, 64, 64) + downsample = Downsample(channels=32, use_conv=True, out_channels=16) + with torch.no_grad(): + downsampled = downsample(sample) + + assert downsampled.shape == (1, 16, 32, 32) + output_slice = downsampled[0, -1, -3:, -3:] + expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522]) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)