Patrick von Platen 2022-06-27 17:20:20 +00:00
commit a2b72faff7
6 changed files with 91 additions and 98 deletions

View File

@@ -101,7 +101,7 @@ class Downsample(nn.Module):
        downsampling occurs in the inner-two dimensions.
    """
    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
    def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
@@ -109,18 +109,29 @@ class Downsample(nn.Module):
        self.dims = dims
        self.padding = padding
        stride = 2 if dims != 3 else (1, 2, 2)
        self.name = name
        if use_conv:
            self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
            conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride)
            conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        if name == "conv":
            self.conv = conv
        else:
            self.op = conv
    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv and self.padding == 0 and self.dims == 2:
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)
        return self.down(x)
        if self.name == "conv":
            return self.conv(x)
        else:
            return self.op(x)
class UNetUpsample(nn.Module):
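As a minimal usage sketch of the unified block (import path taken from the new tests at the bottom of this diff; output shapes follow those tests), the new name argument only controls whether the module is registered as self.conv or self.op, presumably so existing checkpoints keep their parameter names ("op" in the GLIDE/LDM models below, "conv" elsewhere):

import torch
from diffusers.models.resnet import Downsample

x = torch.randn(1, 32, 64, 64)
pool = Downsample(32, use_conv=False, name="op")  # average pooling, registered as `op`
conv = Downsample(32, use_conv=True)              # 3x3 stride-2 convolution, registered as the default `conv`
with torch.no_grad():
    print(pool(x).shape, conv(x).shape)  # both expected: torch.Size([1, 32, 32, 32])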

View File

@@ -31,7 +31,7 @@ from tqdm import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
from .resnet import Downsample, Upsample
def nonlinearity(x):
@@ -43,24 +43,6 @@ def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
@@ -207,7 +189,7 @@ class UNetModel(ModelMixin, ConfigMixin):
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                curr_res = curr_res // 2
            self.down.append(down)
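The padding=0 call above relies on the asymmetric-padding branch of the unified forward(): with use_conv=True, padding=0 and dims=2 it pads (0, 1, 0, 1) before the stride-2 convolution, reproducing the Downsample class removed from this file. A rough sketch (shapes inferred from the block's code above, not from this diff's tests):

import torch
from diffusers.models.resnet import Downsample

down = Downsample(32, use_conv=True, padding=0)  # same arguments as down.downsample above
x = torch.randn(1, 32, 64, 64)
with torch.no_grad():
    y = down(x)  # pads to 65x65, then applies the 3x3 conv with stride 2
print(y.shape)  # expected: torch.Size([1, 32, 32, 32])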

View File

@@ -8,7 +8,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
from .resnet import Downsample, Upsample
def convert_module_to_f16(l):
@@ -124,33 +124,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        return x
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
    applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """
    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
@@ -198,8 +171,8 @@ class ResBlock(TimestepBlock):
            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
        else:
            self.h_upd = self.x_upd = nn.Identity()
@@ -450,7 +423,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                        else Downsample(
                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
                        )
                    )
                )
                ch = out_ch

View File

@@ -1,9 +1,10 @@
import torch
from numpy import pad
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
from .resnet import Downsample, Upsample
class Mish(torch.nn.Module):
@@ -11,15 +12,6 @@ class Mish(torch.nn.Module):
        return x * torch.tanh(torch.nn.functional.softplus(x))
class Downsample(torch.nn.Module):
    def __init__(self, dim):
        super(Downsample, self).__init__()
        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
    def forward(self, x):
        return self.conv(x)
class Rezero(torch.nn.Module):
    def __init__(self, fn):
        super(Rezero, self).__init__()
@@ -141,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else torch.nn.Identity(),
                        Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                    ]
                )
            )
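Here the default name="conv" registers the strided convolution as self.conv, which appears to match the attribute name of the Downsample class removed above (torch.nn.Conv2d(dim, dim, 3, 2, 1)), so the parameter keys should stay the same. A small sketch with an arbitrary dim of 64:

import torch
from diffusers.models.resnet import Downsample

block = Downsample(64, use_conv=True, padding=1)  # default name="conv"
# expected keys: ['conv.weight', 'conv.bias'], as with the removed grad-tts Downsample
print(list(block.state_dict().keys()))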

View File

@@ -10,7 +10,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
from .resnet import Downsample, Upsample
def exists(val):
@@ -392,32 +392,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        return x
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
    use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
    If 3D, then
        downsampling occurs in the inner-two dimensions.
    """
    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels. :param channels: the number of input channels.
@@ -464,8 +438,8 @@ class ResBlock(TimestepBlock):
            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
        else:
            self.h_upd = self.x_upd = nn.Identity()
@@ -820,7 +794,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                        else Downsample(
                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
                        )
                    )
                )
                ch = out_ch
@@ -1089,7 +1065,9 @@ class EncoderUNetModel(nn.Module):
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                        else Downsample(
                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
                        )
                    )
                )
                ch = out_ch

View File

@@ -22,7 +22,7 @@ import numpy as np
import torch
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Upsample
from diffusers.models.resnet import Downsample, Upsample
from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -164,3 +164,58 @@ class UpsampleBlockTests(unittest.TestCase):
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class DownsampleBlockTests(unittest.TestCase):
    def test_downsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample(channels=32, use_conv=False)
        with torch.no_grad():
            downsampled = downsample(sample)
        assert downsampled.shape == (1, 32, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.0513, -0.3889, 0.0640, 0.0836, -0.5460, -0.0341, -0.0169, -0.6967, 0.1179])
        max_diff = (output_slice.flatten() - expected_slice).abs().sum().item()
        assert max_diff <= 1e-3
        # assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
    def test_downsample_with_conv(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample(channels=32, use_conv=True)
        with torch.no_grad():
            downsampled = downsample(sample)
        assert downsampled.shape == (1, 32, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913],
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
    def test_downsample_with_conv_pad1(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample(channels=32, use_conv=True, padding=1)
        with torch.no_grad():
            downsampled = downsample(sample)
        assert downsampled.shape == (1, 32, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
    def test_downsample_with_conv_out_dim(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample(channels=32, use_conv=True, out_channels=16)
        with torch.no_grad():
            downsampled = downsample(sample)
        assert downsampled.shape == (1, 16, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)