Downsample / Upsample - clean to 1D and 2D (#68)
* make unet rl work
* upload files / code
* upload files
* make style correct
* finish
This commit is contained in:
parent c524244f49
commit 321f9791d6
@@ -6,46 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F


-def avg_pool_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D average pooling module.
-    """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def conv_transpose_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.ConvTranspose1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.ConvTranspose2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.ConvTranspose3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class Upsample(nn.Module):
+class Upsample2D(nn.Module):
     """
     An upsampling layer with an optional convolution.

@@ -54,21 +15,21 @@ class Upsample(nn.Module):
     upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
-        self.dims = dims
         self.use_conv_transpose = use_conv_transpose
         self.name = name

         conv = None
         if use_conv_transpose:
-            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
             self.conv = conv
         else:
@@ -79,11 +40,9 @@ class Upsample(nn.Module):
         if self.use_conv_transpose:
             return self.conv(x)

-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
             if self.name == "conv":
                 x = self.conv(x)
@@ -93,7 +52,7 @@ class Upsample(nn.Module):
         return x


-class Downsample(nn.Module):
+class Downsample2D(nn.Module):
     """
     A downsampling layer with an optional convolution.

@@ -102,22 +61,22 @@ class Downsample(nn.Module):
     downsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
+    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
-        self.dims = dims
         self.padding = padding
-        stride = 2 if dims != 3 else (1, 2, 2)
+        stride = 2
         self.name = name

         if use_conv:
-            conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
+            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
         else:
             assert self.channels == self.out_channels
-            conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
             self.conv = conv
         elif name == "Conv2d_0":
@@ -127,10 +86,11 @@ class Downsample(nn.Module):

     def forward(self, x):
         assert x.shape[1] == self.channels
-        if self.use_conv and self.padding == 0 and self.dims == 2:
+        if self.use_conv and self.padding == 0:
             pad = (0, 1, 0, 1)
             x = F.pad(x, pad, mode="constant", value=0)

+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.name == "conv":
             return self.conv(x)
         elif self.name == "Conv2d_0":
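A quick sanity sketch of the renamed 2D blocks (not part of the diff): the constructor arguments and tensor shapes below mirror the Upsample2D/Downsample2D tests further down in this commit, and the import path is the one used by that test file.

import torch

from diffusers.models.resnet import Downsample2D, Upsample2D

# Nearest-neighbor upsampling followed by an optional 3x3 conv: spatial size doubles.
upsample = Upsample2D(channels=32, use_conv=True, out_channels=64)
x = torch.randn(1, 32, 32, 32)
with torch.no_grad():
    assert upsample(x).shape == (1, 64, 64, 64)

# Strided 3x3 conv (use_conv=True) or 2x2 average pooling (use_conv=False): spatial size halves.
downsample = Downsample2D(channels=32, use_conv=True, out_channels=16)
y = torch.randn(1, 32, 64, 64)
with torch.no_grad():
    assert downsample(y).shape == (1, 16, 32, 32)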
@@ -139,8 +99,204 @@ class Downsample(nn.Module):
         return self.op(x)


+class Upsample1D(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
+        applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+        upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_conv_transpose = use_conv_transpose
+        self.name = name
+
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+        self.conv = None
+        if use_conv_transpose:
+            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+        elif use_conv:
+            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.use_conv_transpose:
+            return self.conv(x)
+
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+        if self.use_conv:
+            x = self.conv(x)
+
+        return x
+
+
+class Downsample1D(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
+        applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+        downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.padding = padding
+        stride = 2
+        self.name = name
+
+        if use_conv:
+            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+        else:
+            assert self.channels == self.out_channels
+            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.conv(x)
+
+
+class FirUpsample2D(nn.Module):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+        super().__init__()
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
+        self.fir_kernel = fir_kernel
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        if self.use_conv:
+            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+        else:
+            h = upsample_2d(x, self.fir_kernel, factor=2)
+
+        return h
+
+
+class FirDownsample2D(nn.Module):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+        super().__init__()
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.fir_kernel = fir_kernel
+        self.use_conv = use_conv
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        if self.use_conv:
+            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+        else:
+            x = downsample_2d(x, self.fir_kernel, factor=2)
+
+        return x
+
+
+def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `Conv2d()` followed by `downsample_2d()`.
+
+    Args:
+    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
+    order.
+        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+          C]`.
+        w: Weight tensor of the shape `[filterH, filterW, inChannels,
+          outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+        k: FIR filter of the shape `[firH, firW]` or `[firN]`
+          (separable). The default is `[1] * factor`, which corresponds to average pooling.
+        factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+
+    Returns:
+        Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
+        as `x`.
+    """
+
+    assert isinstance(factor, int) and factor >= 1
+    _outC, _inC, convH, convW = w.shape
+    assert convW == convH
+    if k is None:
+        k = [1] * factor
+    k = _setup_kernel(k) * gain
+    p = (k.shape[0] - factor) + (convW - 1)
+    s = [factor, factor]
+    x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
+    return F.conv2d(x, w, stride=s, padding=0)
+
+
+def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `upsample_2d()` followed by `Conv2d()`.
+
+    Args:
+    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
+    order.
+        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+          C]`.
+        w: Weight tensor of the shape `[filterH, filterW, inChannels,
+          outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+        k: FIR filter of the shape `[firH, firW]` or `[firN]`
+          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+        factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+
+    Returns:
+        Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
+        `x`.
+    """
+
+    assert isinstance(factor, int) and factor >= 1
+
+    # Check weight shape.
+    assert len(w.shape) == 4
+    convH = w.shape[2]
+    convW = w.shape[3]
+    inC = w.shape[1]
+
+    assert convW == convH
+
+    # Setup filter kernel.
+    if k is None:
+        k = [1] * factor
+    k = _setup_kernel(k) * (gain * (factor**2))
+    p = (k.shape[0] - factor) - (convW - 1)
+
+    stride = (factor, factor)
+
+    # Determine data dimensions.
+    stride = [1, 1, factor, factor]
+    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+    output_padding = (
+        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+    )
+    assert output_padding[0] >= 0 and output_padding[1] >= 0
+    num_groups = x.shape[1] // inC
+
+    # Transpose weights.
+    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
+    w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
+    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
+
+    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
+
+    return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+
+
 # TODO (patil-suraj): needs test
-# class Upsample1d(nn.Module):
+# class Upsample2D1d(nn.Module):
 #     def __init__(self, dim):
 #         super().__init__()
 #         self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
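For the newly added FIR blocks, use_conv=False skips the fused _upsample_conv_2d/_conv_downsample_2d path and falls back to plain upsample_2d/downsample_2d with the default (1, 3, 3, 1) kernel, so the spatial size simply doubles or halves as stated in the helper docstrings. A minimal sketch of that fallback path (not part of the diff; the tensor sizes are arbitrary illustration values):

import torch

from diffusers.models.resnet import FirDownsample2D, FirUpsample2D

x = torch.randn(1, 3, 16, 16)

# No learned conv: pure FIR resampling with the default (1, 3, 3, 1) kernel.
fir_up = FirUpsample2D(channels=3, use_conv=False)
fir_down = FirDownsample2D(channels=3, use_conv=False)

with torch.no_grad():
    up = fir_up(x)      # [N, C, H * factor, W * factor] per the upsample_2d docstring
    down = fir_down(x)  # [N, C, H // factor, W // factor] per the downsample_2d docstring

assert up.shape == (1, 3, 32, 32)
assert down.shape == (1, 3, 8, 8)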
@@ -221,7 +377,7 @@ class ResnetBlock2D(nn.Module):
             elif kernel == "sde_vp":
                 self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
             else:
-                self.upsample = Upsample(in_channels, use_conv=False, dims=2)
+                self.upsample = Upsample2D(in_channels, use_conv=False)
         elif self.down:
             if kernel == "fir":
                 fir_kernel = (1, 3, 3, 1)
@@ -229,7 +385,7 @@ class ResnetBlock2D(nn.Module):
             elif kernel == "sde_vp":
                 self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
             else:
-                self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

         self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut

@@ -257,7 +413,6 @@ class ResnetBlock2D(nn.Module):
             else:
                 self.res_conv = torch.nn.Identity()
         elif self.overwrite_for_ldm:
-            dims = 2
             channels = in_channels
             emb_channels = temb_channels
             use_scale_shift_norm = False
@@ -266,7 +421,7 @@ class ResnetBlock2D(nn.Module):
             self.in_layers = nn.Sequential(
                 normalization(channels, swish=1.0),
                 nn.Identity(),
-                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+                nn.Conv2d(channels, self.out_channels, 3, padding=1),
             )
             self.emb_layers = nn.Sequential(
                 nn.SiLU(),
@@ -279,12 +434,12 @@ class ResnetBlock2D(nn.Module):
                 normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
                 nn.SiLU() if use_scale_shift_norm else nn.Identity(),
                 nn.Dropout(p=dropout),
-                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+                zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
             )
             if self.out_channels == in_channels:
                 self.skip_connection = nn.Identity()
             else:
-                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+                self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
         elif self.overwrite_for_score_vde:
             in_ch = in_channels
             out_ch = out_channels
@@ -631,7 +786,7 @@ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,


 def upsample_2d(x, k=None, factor=2, gain=1):
-    r"""Upsample a batch of 2D images with the given filter.
+    r"""Upsample2D a batch of 2D images with the given filter.

     Args:
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@@ -656,7 +811,7 @@ def upsample_2d(x, k=None, factor=2, gain=1):


 def downsample_2d(x, k=None, factor=2, gain=1):
-    r"""Downsample a batch of 2D images with the given filter.
+    r"""Downsample2D a batch of 2D images with the given filter.

     Args:
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
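Note on the ResnetBlock2D hunks above: the default (non-FIR, non-"sde_vp") branches now build Upsample2D/Downsample2D without a convolution, which computes the same nearest-neighbor interpolation and 2x2 average pooling that the "sde_vp" branch gets from functools.partial. A small equivalence check, assuming the classes as defined earlier in this diff (illustrative only, not part of the commit):

from functools import partial

import torch
import torch.nn.functional as F

from diffusers.models.resnet import Downsample2D, Upsample2D

x = torch.randn(2, 8, 16, 16)

# Upsample2D without a conv is exactly nearest-neighbor interpolation with scale 2.
up_block = Upsample2D(8, use_conv=False)
up_fn = partial(F.interpolate, scale_factor=2.0, mode="nearest")
assert torch.equal(up_block(x), up_fn(x))

# Downsample2D without a conv is exactly 2x2 average pooling with stride 2.
down_block = Downsample2D(8, use_conv=False)
down_fn = partial(F.avg_pool2d, kernel_size=2, stride=2)
assert torch.equal(down_block(x), down_fn(x))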
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


 def nonlinearity(x):
@@ -100,7 +100,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
+                down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
             self.down.append(down)

@@ -139,7 +139,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             up.block = block
             up.attn = attn
             if i_level != 0:
-                up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
+                up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
                 curr_res = curr_res * 2
             self.up.insert(0, up)  # prepend to get consistent order

@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


 def convert_module_to_f16(l):
@@ -218,9 +218,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             down=True,
                         )
                         if resblock_updown
-                        else Downsample(
-                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
-                        )
+                        else Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
                     )
                 )
                 ch = out_ch
@@ -299,7 +297,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                         up=True,
                     )
                     if resblock_updown
-                    else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
+                    else Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)
                 )
                 ds //= 2
             self.output_blocks.append(TimestepEmbedSequential(*layers))
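The Glide call sites above show the whole migration pattern for downstream code: drop the dims= keyword and add the 2D suffix, everything else stays put. A hedged before/after sketch; ch, conv_resample, and out_ch are placeholder values for illustration, not values from the model config:

from diffusers.models.resnet import Downsample2D

# Placeholder values standing in for the model's ch / conv_resample / out_ch.
ch, conv_resample, out_ch = 64, True, 64

# Old call (removed by this commit):
#   Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
# New call: same keywords minus `dims`; the class name now carries the dimensionality.
layer = Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")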
@@ -4,7 +4,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


 class Mish(torch.nn.Module):
@@ -105,7 +105,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                            overwrite_for_grad_tts=True,
                        ),
                        Residual(Rezero(LinearAttention(dim_out))),
-                        Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
+                        Downsample2D(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                    ]
                )
            )
@@ -158,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                            overwrite_for_grad_tts=True,
                        ),
                        Residual(Rezero(LinearAttention(dim_in))),
-                        Upsample(dim_in, use_conv_transpose=True),
+                        Upsample2D(dim_in, use_conv_transpose=True),
                    ]
                )
            )
@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


 # from .resnet import ResBlock
@@ -350,7 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
+                        Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
                     )
                 )
                 ch = out_ch
@@ -437,7 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                     )
                 if level and i == num_res_blocks:
                     out_ch = ch
-                    layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))
+                    layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
@@ -6,7 +6,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResidualTemporalBlock, Upsample
+from .resnet import Downsample1D, ResidualTemporalBlock, Upsample1D


 class SinusoidalPosEmb(nn.Module):
@@ -96,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
-                        Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
+                        Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
                    ]
                )
            )
@@ -116,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
-                        Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
+                        Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
                    ]
                )
            )
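In the temporal (RL) UNet above, resampling now happens along the horizon axis: Downsample1D(dim, use_conv=True) is a strided Conv1d that halves the sequence length, and Upsample1D(dim, use_conv_transpose=True) is a ConvTranspose1d(dim, dim, 4, 2, 1) that doubles it. A minimal shape sketch (channel count and horizon length are arbitrary illustration values, not from the model config):

import torch

from diffusers.models.resnet import Downsample1D, Upsample1D

x = torch.randn(1, 32, 64)  # [batch, channels, horizon]

down = Downsample1D(32, use_conv=True)        # Conv1d(32, 32, 3, stride=2, padding=1)
up = Upsample1D(32, use_conv_transpose=True)  # ConvTranspose1d(32, 32, 4, 2, 1)

with torch.no_grad():
    assert down(x).shape == (1, 32, 32)
    assert up(x).shape == (1, 32, 128)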
@@ -21,13 +21,12 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D


 def _setup_kernel(k):
@@ -40,96 +39,6 @@ def _setup_kernel(k):
     return k


-def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `upsample_2d()` followed by `Conv2d()`.
-
-    Args:
-    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
-    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
-    order.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
-          C]`.
-        w: Weight tensor of the shape `[filterH, filterW, inChannels,
-          outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]`
-          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-        factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
-        `x`.
-    """
-
-    assert isinstance(factor, int) and factor >= 1
-
-    # Check weight shape.
-    assert len(w.shape) == 4
-    convH = w.shape[2]
-    convW = w.shape[3]
-    inC = w.shape[1]
-
-    assert convW == convH
-
-    # Setup filter kernel.
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * (gain * (factor**2))
-    p = (k.shape[0] - factor) - (convW - 1)
-
-    stride = (factor, factor)
-
-    # Determine data dimensions.
-    stride = [1, 1, factor, factor]
-    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
-    output_padding = (
-        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
-        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
-    )
-    assert output_padding[0] >= 0 and output_padding[1] >= 0
-    num_groups = x.shape[1] // inC
-
-    # Transpose weights.
-    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
-    w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
-    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
-
-    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
-
-    return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
-
-
-def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `Conv2d()` followed by `downsample_2d()`.
-
-    Args:
-    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
-    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
-    order.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
-          C]`.
-        w: Weight tensor of the shape `[filterH, filterW, inChannels,
-          outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]`
-          (separable). The default is `[1] * factor`, which corresponds to average pooling.
-        factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
-        as `x`.
-    """
-
-    assert isinstance(factor, int) and factor >= 1
-    _outC, _inC, convH, convW = w.shape
-    assert convW == convH
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * gain
-    p = (k.shape[0] - factor) + (convW - 1)
-    s = [factor, factor]
-    x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
-    return F.conv2d(x, w, stride=s, padding=0)
-
-
 def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
     scale = 1e-10 if scale == 0 else scale
@@ -183,46 +92,6 @@ class Combine(nn.Module):
         raise ValueError(f"Method {self.method} not recognized.")


-class FirUpsample(nn.Module):
-    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
-        super().__init__()
-        out_channels = out_channels if out_channels else channels
-        if use_conv:
-            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.use_conv = use_conv
-        self.fir_kernel = fir_kernel
-        self.out_channels = out_channels
-
-    def forward(self, x):
-        if self.use_conv:
-            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
-            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
-        else:
-            h = upsample_2d(x, self.fir_kernel, factor=2)
-
-        return h
-
-
-class FirDownsample(nn.Module):
-    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
-        super().__init__()
-        out_channels = out_channels if out_channels else channels
-        if use_conv:
-            self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.fir_kernel = fir_kernel
-        self.use_conv = use_conv
-        self.out_channels = out_channels
-
-    def forward(self, x):
-        if self.use_conv:
-            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
-            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
-        else:
-            x = downsample_2d(x, self.fir_kernel, factor=2)
-
-        return x
-
-
 class NCSNpp(ModelMixin, ConfigMixin):
     """NCSN++ model"""

@@ -313,9 +182,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

         if self.fir:
-            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+            Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
         else:
-            Up_sample = functools.partial(Upsample, name="Conv2d_0")
+            Up_sample = functools.partial(Upsample2D, name="Conv2d_0")

         if progressive == "output_skip":
             self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
@@ -323,9 +192,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
             pyramid_upsample = functools.partial(Up_sample, use_conv=True)

         if self.fir:
-            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+            Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
         else:
-            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
+            Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")

         if progressive_input == "input_skip":
             self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
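In the NCSN++ hunks above, Up_sample/Down_sample are factories rather than modules: functools.partial pre-binds the FIR kernel (or the "Conv2d_0" weight name), and later calls such as self.pyramid_upsample = Up_sample(channels=None, use_conv=False) supply or override the remaining arguments. A small sketch of that pattern using the classes from this diff (the flag and kernel values below are illustrative, mirroring the defaults shown above):

import functools

from diffusers.models.resnet import FirUpsample2D, Upsample2D

fir = True
fir_kernel = (1, 3, 3, 1)
resamp_with_conv = True

if fir:
    Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
    Up_sample = functools.partial(Upsample2D, name="Conv2d_0")

# Later calls fill in the per-block arguments; keywords given at call time override the partial's defaults.
pyramid_upsample = Up_sample(channels=None, use_conv=False)
block_upsample = Up_sample(channels=128, out_channels=128)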
@@ -5,7 +5,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


 def nonlinearity(x):
@@ -65,7 +65,7 @@ class Encoder(nn.Module):
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
+                down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
             self.down.append(down)

@@ -179,7 +179,7 @@ class Decoder(nn.Module):
             up.block = block
             up.attn = attn
             if i_level != 0:
-                up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
+                up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
                 curr_res = curr_res * 2
             self.up.insert(0, up)  # prepend to get consistent order

@@ -137,7 +137,7 @@ class ResidualBlock(nn.Module):
         # Dilated conv layer
         h = self.dilated_conv_layer(h)

-        # Upsample spectrogram to size of audio
+        # Upsample2D spectrogram to size of audio
         mel_spec = torch.unsqueeze(mel_spec, dim=1)
         mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
         mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)
@@ -22,7 +22,7 @@ import numpy as np
 import torch

 from diffusers.models.embeddings import get_timestep_embedding
-from diffusers.models.resnet import Downsample, Upsample
+from diffusers.models.resnet import Downsample2D, Upsample2D
 from diffusers.testing_utils import floats_tensor, slow, torch_device


@@ -116,11 +116,11 @@ class EmbeddingsTests(unittest.TestCase):
         )


-class UpsampleBlockTests(unittest.TestCase):
+class Upsample2DBlockTests(unittest.TestCase):
     def test_upsample_default(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=False)
+        upsample = Upsample2D(channels=32, use_conv=False)
         with torch.no_grad():
             upsampled = upsample(sample)

@@ -132,7 +132,7 @@ class UpsampleBlockTests(unittest.TestCase):
     def test_upsample_with_conv(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=True)
+        upsample = Upsample2D(channels=32, use_conv=True)
         with torch.no_grad():
             upsampled = upsample(sample)

@@ -144,7 +144,7 @@ class UpsampleBlockTests(unittest.TestCase):
     def test_upsample_with_conv_out_dim(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=True, out_channels=64)
+        upsample = Upsample2D(channels=32, use_conv=True, out_channels=64)
         with torch.no_grad():
             upsampled = upsample(sample)

@@ -156,7 +156,7 @@ class UpsampleBlockTests(unittest.TestCase):
     def test_upsample_with_transpose(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
+        upsample = Upsample2D(channels=32, use_conv=False, use_conv_transpose=True)
         with torch.no_grad():
             upsampled = upsample(sample)

@@ -166,11 +166,11 @@ class UpsampleBlockTests(unittest.TestCase):
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


-class DownsampleBlockTests(unittest.TestCase):
+class Downsample2DBlockTests(unittest.TestCase):
     def test_downsample_default(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=False)
+        downsample = Downsample2D(channels=32, use_conv=False)
         with torch.no_grad():
             downsampled = downsample(sample)

@@ -184,7 +184,7 @@ class DownsampleBlockTests(unittest.TestCase):
     def test_downsample_with_conv(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True)
+        downsample = Downsample2D(channels=32, use_conv=True)
         with torch.no_grad():
             downsampled = downsample(sample)

@@ -199,7 +199,7 @@ class DownsampleBlockTests(unittest.TestCase):
     def test_downsample_with_conv_pad1(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True, padding=1)
+        downsample = Downsample2D(channels=32, use_conv=True, padding=1)
         with torch.no_grad():
             downsampled = downsample(sample)

@@ -211,7 +211,7 @@ class DownsampleBlockTests(unittest.TestCase):
     def test_downsample_with_conv_out_dim(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True, out_channels=16)
+        downsample = Downsample2D(channels=32, use_conv=True, out_channels=16)
         with torch.no_grad():
             downsampled = downsample(sample)

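The renamed test classes keep the same structure: fixed seed, random [N, C, H, W] sample, forward pass under torch.no_grad(), then shape and slice checks. A condensed sketch of the shape assertions behind test_upsample_default and test_downsample_default (the expected_slice value checks are omitted here because they depend on the seeded weights):

import torch

from diffusers.models.resnet import Downsample2D, Upsample2D

torch.manual_seed(0)

sample = torch.randn(1, 32, 32, 32)
upsample = Upsample2D(channels=32, use_conv=False)
with torch.no_grad():
    upsampled = upsample(sample)
assert upsampled.shape == (1, 32, 64, 64)  # nearest-neighbor interpolation doubles H and W

sample = torch.randn(1, 32, 64, 64)
downsample = Downsample2D(channels=32, use_conv=False)
with torch.no_grad():
    downsampled = downsample(sample)
assert downsampled.shape == (1, 32, 32, 32)  # 2x2 average pooling halves H and W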