Merge branch 'main' of https://github.com/huggingface/diffusers into main
This commit is contained in:
commit
a2b72faff7
|
@ -101,7 +101,7 @@ class Downsample(nn.Module):
|
|||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
|
@ -109,18 +109,29 @@ class Downsample(nn.Module):
|
|||
self.dims = dims
|
||||
self.padding = padding
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.op = conv
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.use_conv and self.padding == 0 and self.dims == 2:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
return self.down(x)
|
||||
|
||||
if self.name == "conv":
|
||||
return self.conv(x)
|
||||
else:
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class UNetUpsample(nn.Module):
|
||||
|
|
|
@ -31,7 +31,7 @@ from tqdm import tqdm
|
|||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Upsample
|
||||
from .resnet import Downsample, Upsample
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
|
@ -43,24 +43,6 @@ def Normalize(in_channels):
|
|||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """Downsample a feature map by a factor of two along its last two dimensions.

    When ``with_conv`` is truthy the input is zero-padded by one element on the
    right and bottom edges and passed through a 3x3, stride-2 convolution that
    keeps ``in_channels`` channels. Otherwise a 2x2 average pooling with stride 2
    is applied and no convolution layer is created.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        # torch convolutions only pad symmetrically, so the layer is built with
        # padding=0 and forward() adds the one-sided padding itself.
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad tuple is (left, right, top, bottom) over the last two dimensions.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
||||
super().__init__()
|
||||
|
@ -207,7 +189,7 @@ class UNetModel(ModelMixin, ConfigMixin):
|
|||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn.functional as F
|
|||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Upsample
|
||||
from .resnet import Downsample, Upsample
|
||||
|
||||
|
||||
def convert_module_to_f16(l):
|
||||
|
@ -124,33 +124,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """
    Downsampling layer that uses either a strided convolution or average pooling.

    :param channels: number of channels the input is expected to have.
    :param use_conv: if truthy, the layer is built with ``conv_nd``; otherwise with ``avg_pool_nd``.
    :param dims: whether the signal is 1D, 2D or 3D. For 3D signals only the inner two dimensions are downsampled.
    :param out_channels: number of output channels; ``channels`` is used when this is not given.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims

        # For 3D signals the first of the three axes keeps stride 1.
        stride = (1, 2, 2) if dims == 3 else 2

        if not use_conv:
            # The pooling path is only allowed when the channel count stays the same.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)

    def forward(self, x):
        n_channels = x.shape[1]
        assert n_channels == self.channels
        return self.op(x)
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
|
@ -198,8 +171,8 @@ class ResBlock(TimestepBlock):
|
|||
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
|
||||
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
|
@ -450,7 +423,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
|||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
else Downsample(
|
||||
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import torch
|
||||
from numpy import pad
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Upsample
|
||||
from .resnet import Downsample, Upsample
|
||||
|
||||
|
||||
class Mish(torch.nn.Module):
|
||||
|
@ -11,15 +12,6 @@ class Mish(torch.nn.Module):
|
|||
return x * torch.tanh(torch.nn.functional.softplus(x))
|
||||
|
||||
|
||||
class Downsample(torch.nn.Module):
    """Downsample with a single 3x3 convolution of stride 2 and padding 1.

    The convolution maps ``dim`` input channels to ``dim`` output channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled
|
||||
|
||||
|
||||
class Rezero(torch.nn.Module):
|
||||
def __init__(self, fn):
|
||||
super(Rezero, self).__init__()
|
||||
|
@ -141,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
|
||||
ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
|
||||
Residual(Rezero(LinearAttention(dim_out))),
|
||||
Downsample(dim_out) if not is_last else torch.nn.Identity(),
|
||||
Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch.nn.functional as F
|
|||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Upsample
|
||||
from .resnet import Downsample, Upsample
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
@ -392,32 +392,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """
    Downsampling layer that uses either a strided convolution or average pooling.

    :param channels: number of channels the input is expected to have.
    :param use_conv: if truthy, the layer is built with ``conv_nd``; otherwise with ``avg_pool_nd``.
    :param dims: whether the signal is 1D, 2D or 3D. For 3D signals only the inner two dimensions are downsampled.
    :param out_channels: number of output channels; ``channels`` is used when this is not given.
    :param padding: padding forwarded to ``conv_nd``; not used by the pooling path.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims

        # For 3D signals the first of the three axes keeps stride 1.
        stride = (1, 2, 2) if dims == 3 else 2

        if not use_conv:
            # The pooling path is only allowed when the channel count stays the same.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)

    def forward(self, x):
        n_channels = x.shape[1]
        assert n_channels == self.channels
        return self.op(x)
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels. :param channels: the number of input channels.
|
||||
|
@ -464,8 +438,8 @@ class ResBlock(TimestepBlock):
|
|||
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
|
||||
self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
|
@ -820,7 +794,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
else Downsample(
|
||||
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
|
@ -1089,7 +1065,9 @@ class EncoderUNetModel(nn.Module):
|
|||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
else Downsample(
|
||||
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
|
|
|
@ -22,7 +22,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from diffusers.models.embeddings import get_timestep_embedding
|
||||
from diffusers.models.resnet import Upsample
|
||||
from diffusers.models.resnet import Downsample, Upsample
|
||||
from diffusers.testing_utils import floats_tensor, slow, torch_device
|
||||
|
||||
|
||||
|
@ -164,3 +164,58 @@ class UpsampleBlockTests(unittest.TestCase):
|
|||
output_slice = upsampled[0, -1, -3:, -3:]
|
||||
expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
|
||||
class DownsampleBlockTests(unittest.TestCase):
    """Regression tests that pin ``Downsample`` outputs for a fixed, seeded random input."""

    @staticmethod
    def _corner_values(expected_shape, **downsample_kwargs):
        """Return the flattened last-channel, bottom-right 3x3 patch of the downsampled output.

        A ``(1, 32, 64, 64)`` sample is drawn after seeding torch with 0 and passed through
        ``Downsample(channels=32, **downsample_kwargs)``. The output shape is checked against
        ``expected_shape`` before the patch is returned.
        """
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        # Build the layer only after drawing the sample so the random generator is
        # consumed in the same order as when the expected values were recorded.
        layer = Downsample(channels=32, **downsample_kwargs)
        with torch.no_grad():
            output = layer(sample)

        assert output.shape == expected_shape
        return output[0, -1, -3:, -3:].flatten()

    def test_downsample_default(self):
        values = self._corner_values((1, 32, 32, 32), use_conv=False)
        expected = torch.tensor([-0.0513, -0.3889, 0.0640, 0.0836, -0.5460, -0.0341, -0.0169, -0.6967, 0.1179])
        # This check bounds the summed absolute error, not the per-element error.
        total_abs_diff = (values - expected).abs().sum().item()
        assert total_abs_diff <= 1e-3

    def test_downsample_with_conv(self):
        values = self._corner_values((1, 32, 32, 32), use_conv=True)
        expected = torch.tensor([0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913])
        assert torch.allclose(values, expected, atol=1e-3)

    def test_downsample_with_conv_pad1(self):
        values = self._corner_values((1, 32, 32, 32), use_conv=True, padding=1)
        expected = torch.tensor([0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913])
        assert torch.allclose(values, expected, atol=1e-3)

    def test_downsample_with_conv_out_dim(self):
        values = self._corner_values((1, 16, 32, 32), use_conv=True, out_channels=16)
        expected = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
        assert torch.allclose(values, expected, atol=1e-3)
|
||||
|
|
Loading…
Reference in New Issue