Downsample / Upsample - clean to 1D and 2D (#68)

* make unet rl work
* upload files / code
* upload files
* make style correct
* finish

This commit is contained in:
parent c524244f49
commit 321f9791d6
@@ -6,46 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-def avg_pool_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D average pooling module.
-    """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def conv_transpose_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.ConvTranspose1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.ConvTranspose2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.ConvTranspose3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class Upsample(nn.Module):
+class Upsample2D(nn.Module):
     """
     An upsampling layer with an optional convolution.
 
@@ -54,21 +15,21 @@ class Upsample(nn.Module):
     upsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
-        self.dims = dims
         self.use_conv_transpose = use_conv_transpose
         self.name = name
 
         conv = None
         if use_conv_transpose:
-            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
             self.conv = conv
         else:
@@ -79,11 +40,9 @@ class Upsample(nn.Module):
         if self.use_conv_transpose:
             return self.conv(x)
 
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
             if self.name == "conv":
                 x = self.conv(x)
@@ -93,7 +52,7 @@ class Upsample(nn.Module):
         return x
 
 
-class Downsample(nn.Module):
+class Downsample2D(nn.Module):
     """
     A downsampling layer with an optional convolution.
 
@@ -102,22 +61,22 @@ class Downsample(nn.Module):
     downsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
+    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
-        self.dims = dims
         self.padding = padding
-        stride = 2 if dims != 3 else (1, 2, 2)
+        stride = 2
         self.name = name
 
         if use_conv:
-            conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
+            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
         else:
             assert self.channels == self.out_channels
-            conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
             self.conv = conv
         elif name == "Conv2d_0":
@@ -127,10 +86,11 @@ class Downsample(nn.Module):
 
     def forward(self, x):
         assert x.shape[1] == self.channels
-        if self.use_conv and self.padding == 0 and self.dims == 2:
+        if self.use_conv and self.padding == 0:
             pad = (0, 1, 0, 1)
             x = F.pad(x, pad, mode="constant", value=0)
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.name == "conv":
             return self.conv(x)
         elif self.name == "Conv2d_0":
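A minimal usage sketch of the renamed 2D blocks (illustration only, not part of this commit); the shapes follow the factor-2 behavior in the forward methods above:

import torch
from diffusers.models.resnet import Downsample2D, Upsample2D

x = torch.randn(1, 32, 64, 64)

# nearest-neighbor upsampling, no learned parameters
up = Upsample2D(channels=32, use_conv=False)
print(up(x).shape)  # torch.Size([1, 32, 128, 128])

# learned 3x3 conv after interpolation, optionally changing the channel count
up_conv = Upsample2D(channels=32, use_conv=True, out_channels=64)
print(up_conv(x).shape)  # torch.Size([1, 64, 128, 128])

# strided 3x3 conv with padding=1 halves the spatial size
down = Downsample2D(channels=32, use_conv=True)
print(down(x).shape)  # torch.Size([1, 32, 32, 32])

# average-pooling path requires channels == out_channels
down_pool = Downsample2D(channels=32, use_conv=False)
print(down_pool(x).shape)  # torch.Size([1, 32, 32, 32])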
@@ -139,8 +99,204 @@ class Downsample(nn.Module):
         return self.op(x)
 
 
+class Upsample1D(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
+    applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+    upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_conv_transpose = use_conv_transpose
+        self.name = name
+
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+        self.conv = None
+        if use_conv_transpose:
+            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+        elif use_conv:
+            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.use_conv_transpose:
+            return self.conv(x)
+
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+        if self.use_conv:
+            x = self.conv(x)
+
+        return x
+
+
+class Downsample1D(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
+    applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+    downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.padding = padding
+        stride = 2
+        self.name = name
+
+        if use_conv:
+            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+        else:
+            assert self.channels == self.out_channels
+            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.conv(x)
+
+
+class FirUpsample2D(nn.Module):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+        super().__init__()
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
+        self.fir_kernel = fir_kernel
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        if self.use_conv:
+            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+        else:
+            h = upsample_2d(x, self.fir_kernel, factor=2)
+
+        return h
+
+
+class FirDownsample2D(nn.Module):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+        super().__init__()
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.fir_kernel = fir_kernel
+        self.use_conv = use_conv
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        if self.use_conv:
+            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+        else:
+            x = downsample_2d(x, self.fir_kernel, factor=2)
+
+        return x
+
+
+def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `Conv2d()` followed by `downsample_2d()`.
+
+    Args:
+    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
+    order.
+        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+          C]`.
+        w: Weight tensor of the shape `[filterH, filterW, inChannels,
+          outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+        k: FIR filter of the shape `[firH, firW]` or `[firN]`
+          (separable). The default is `[1] * factor`, which corresponds to average pooling.
+        factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+
+    Returns:
+        Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
+        as `x`.
+    """
+
+    assert isinstance(factor, int) and factor >= 1
+    _outC, _inC, convH, convW = w.shape
+    assert convW == convH
+    if k is None:
+        k = [1] * factor
+    k = _setup_kernel(k) * gain
+    p = (k.shape[0] - factor) + (convW - 1)
+    s = [factor, factor]
+    x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
+    return F.conv2d(x, w, stride=s, padding=0)
+
+
+def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `upsample_2d()` followed by `Conv2d()`.
+
+    Args:
+    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
+    order.
+        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+          C]`.
+        w: Weight tensor of the shape `[filterH, filterW, inChannels,
+          outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+        k: FIR filter of the shape `[firH, firW]` or `[firN]`
+          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+        factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+
+    Returns:
+        Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
+        `x`.
+    """
+
+    assert isinstance(factor, int) and factor >= 1
+
+    # Check weight shape.
+    assert len(w.shape) == 4
+    convH = w.shape[2]
+    convW = w.shape[3]
+    inC = w.shape[1]
+
+    assert convW == convH
+
+    # Setup filter kernel.
+    if k is None:
+        k = [1] * factor
+    k = _setup_kernel(k) * (gain * (factor**2))
+    p = (k.shape[0] - factor) - (convW - 1)
+
+    stride = (factor, factor)
+
+    # Determine data dimensions.
+    stride = [1, 1, factor, factor]
+    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+    output_padding = (
+        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+    )
+    assert output_padding[0] >= 0 and output_padding[1] >= 0
+    num_groups = x.shape[1] // inC
+
+    # Transpose weights.
+    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
+    w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
+    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
+
+    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
+
+    return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+
+
 # TODO (patil-suraj): needs test
-# class Upsample1d(nn.Module):
+# class Upsample2D1d(nn.Module):
 #     def __init__(self, dim):
 #         super().__init__()
 #         self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
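A rough usage sketch of the new FIR blocks (illustration only, not part of this commit), exercising only the no-conv path, which falls back to upsample_2d / downsample_2d with the default (1, 3, 3, 1) kernel and factor 2:

import torch
from diffusers.models.resnet import FirDownsample2D, FirUpsample2D

x = torch.randn(1, 32, 16, 16)

# use_conv=False: pure FIR filtering, no learned parameters
fir_up = FirUpsample2D(channels=32, use_conv=False)
fir_down = FirDownsample2D(channels=32, use_conv=False)

print(fir_up(x).shape)    # expected: torch.Size([1, 32, 32, 32])
print(fir_down(x).shape)  # expected: torch.Size([1, 32, 8, 8])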
@@ -221,7 +377,7 @@ class ResnetBlock2D(nn.Module):
             elif kernel == "sde_vp":
                 self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
             else:
-                self.upsample = Upsample(in_channels, use_conv=False, dims=2)
+                self.upsample = Upsample2D(in_channels, use_conv=False)
         elif self.down:
             if kernel == "fir":
                 fir_kernel = (1, 3, 3, 1)
@@ -229,7 +385,7 @@ class ResnetBlock2D(nn.Module):
             elif kernel == "sde_vp":
                 self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
             else:
-                self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
 
         self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
 
@@ -257,7 +413,6 @@ class ResnetBlock2D(nn.Module):
             else:
                 self.res_conv = torch.nn.Identity()
         elif self.overwrite_for_ldm:
-            dims = 2
             channels = in_channels
             emb_channels = temb_channels
             use_scale_shift_norm = False
@@ -266,7 +421,7 @@ class ResnetBlock2D(nn.Module):
             self.in_layers = nn.Sequential(
                 normalization(channels, swish=1.0),
                 nn.Identity(),
-                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+                nn.Conv2d(channels, self.out_channels, 3, padding=1),
             )
             self.emb_layers = nn.Sequential(
                 nn.SiLU(),
@@ -279,12 +434,12 @@ class ResnetBlock2D(nn.Module):
                 normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
                 nn.SiLU() if use_scale_shift_norm else nn.Identity(),
                 nn.Dropout(p=dropout),
-                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+                zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
             )
             if self.out_channels == in_channels:
                 self.skip_connection = nn.Identity()
             else:
-                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+                self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
         elif self.overwrite_for_score_vde:
             in_ch = in_channels
             out_ch = out_channels
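The replacements above rely on the deleted conv_nd helper simply dispatching on dims, so with dims hard-coded to 2 each call is a plain nn.Conv2d. A small sanity sketch (illustration only, reproducing the helper from the hunk at the top of resnet.py):

import torch.nn as nn

def conv_nd(dims, *args, **kwargs):
    # the deleted helper, reproduced from the diff above
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

old = conv_nd(2, 64, 128, 3, padding=1)
new = nn.Conv2d(64, 128, 3, padding=1)
assert type(old) is type(new)
assert old.weight.shape == new.weight.shape  # torch.Size([128, 64, 3, 3])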
@@ -631,7 +786,7 @@ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
 
 
 def upsample_2d(x, k=None, factor=2, gain=1):
-    r"""Upsample a batch of 2D images with the given filter.
+    r"""Upsample2D a batch of 2D images with the given filter.
 
     Args:
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@@ -656,7 +811,7 @@ def upsample_2d(x, k=None, factor=2, gain=1):
 
 
 def downsample_2d(x, k=None, factor=2, gain=1):
-    r"""Downsample a batch of 2D images with the given filter.
+    r"""Downsample2D a batch of 2D images with the given filter.
 
     Args:
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
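For intuition about the FIR filter argument used throughout this file: a 1D kernel such as (1, 3, 3, 1) is treated as separable, i.e. expanded to a 2D filter by an outer product and normalized. A standalone sketch of that expansion (my own illustration; the library's _setup_kernel may differ in details):

import numpy as np

def separable_fir(k):
    # outer-product expansion of a 1D separable kernel, normalized to sum to 1
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    return k / k.sum()

print(separable_fir((1, 3, 3, 1)))
# [[0.015625 0.046875 0.046875 0.015625]
#  [0.046875 0.140625 0.140625 0.046875]
#  [0.046875 0.140625 0.140625 0.046875]
#  [0.015625 0.046875 0.046875 0.015625]]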

@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
 
 
 def nonlinearity(x):
@@ -100,7 +100,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
+                down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
             self.down.append(down)
 
@@ -139,7 +139,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             up.block = block
             up.attn = attn
             if i_level != 0:
-                up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
+                up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
                 curr_res = curr_res * 2
             self.up.insert(0, up)  # prepend to get consistent order
 
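The encoder path above constructs Downsample2D with padding=0; in that configuration the forward pass shown earlier applies an asymmetric (0, 1, 0, 1) pad before the stride-2 convolution, so an even input is still halved exactly. A small shape check (illustration only, not part of the diff):

import torch
from diffusers.models.resnet import Downsample2D

down = Downsample2D(channels=32, use_conv=True, padding=0)
x = torch.randn(1, 32, 64, 64)
# pad to 65x65, then 3x3 conv with stride 2 and no padding -> 32x32
print(down(x).shape)  # torch.Size([1, 32, 32, 32])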

@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
 
 
 def convert_module_to_f16(l):
@@ -218,9 +218,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             down=True,
                         )
                         if resblock_updown
-                        else Downsample(
-                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
-                        )
+                        else Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
                     )
                 )
                 ch = out_ch
@@ -299,7 +297,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                         up=True,
                     )
                     if resblock_updown
-                    else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
+                    else Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)
                 )
                 ds //= 2
             self.output_blocks.append(TimestepEmbedSequential(*layers))

@@ -4,7 +4,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
 
 
 class Mish(torch.nn.Module):
@@ -105,7 +105,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                         overwrite_for_grad_tts=True,
                     ),
                     Residual(Rezero(LinearAttention(dim_out))),
-                    Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
+                    Downsample2D(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                 ]
             )
         )
@@ -158,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                         overwrite_for_grad_tts=True,
                     ),
                     Residual(Rezero(LinearAttention(dim_in))),
-                    Upsample(dim_in, use_conv_transpose=True),
+                    Upsample2D(dim_in, use_conv_transpose=True),
                 ]
            )
        )

@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
 
 
 # from .resnet import ResBlock
@@ -350,7 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
+                        Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
                     )
                 )
                 ch = out_ch
@@ -437,7 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                     )
                 if level and i == num_res_blocks:
                     out_ch = ch
-                    layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))
+                    layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch

@@ -6,7 +6,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResidualTemporalBlock, Upsample
+from .resnet import Downsample1D, ResidualTemporalBlock, Upsample1D
 
 
 class SinusoidalPosEmb(nn.Module):
@@ -96,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                 [
                     ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                     ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
-                    Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
+                    Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
                 ]
             )
         )
@@ -116,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                 [
                     ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                     ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
-                    Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
+                    Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
                 ]
             )
         )
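The RL temporal UNet above now resamples along the horizon axis with the 1D blocks. A minimal shape sketch (illustration only), assuming inputs of shape [batch, channels, horizon]:

import torch
from diffusers.models.resnet import Downsample1D, Upsample1D

x = torch.randn(4, 64, 32)  # [batch, channels, horizon]

down = Downsample1D(channels=64, use_conv=True)        # strided Conv1d
up = Upsample1D(channels=64, use_conv_transpose=True)  # ConvTranspose1d(4, 2, 1)

print(down(x).shape)      # torch.Size([4, 64, 16])
print(up(down(x)).shape)  # torch.Size([4, 64, 32])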

@@ -21,13 +21,12 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
 
 
 def _setup_kernel(k):
@@ -40,96 +39,6 @@ def _setup_kernel(k):
     return k
 
 
-def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `upsample_2d()` followed by `Conv2d()`.
-
-    Args:
-    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
-    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
-    order.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
-          C]`.
-        w: Weight tensor of the shape `[filterH, filterW, inChannels,
-          outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]`
-          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-        factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
-        `x`.
-    """
-
-    assert isinstance(factor, int) and factor >= 1
-
-    # Check weight shape.
-    assert len(w.shape) == 4
-    convH = w.shape[2]
-    convW = w.shape[3]
-    inC = w.shape[1]
-
-    assert convW == convH
-
-    # Setup filter kernel.
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * (gain * (factor**2))
-    p = (k.shape[0] - factor) - (convW - 1)
-
-    stride = (factor, factor)
-
-    # Determine data dimensions.
-    stride = [1, 1, factor, factor]
-    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
-    output_padding = (
-        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
-        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
-    )
-    assert output_padding[0] >= 0 and output_padding[1] >= 0
-    num_groups = x.shape[1] // inC
-
-    # Transpose weights.
-    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
-    w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
-    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
-
-    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
-
-    return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
-
-
-def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `Conv2d()` followed by `downsample_2d()`.
-
-    Args:
-    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
-    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
-    order.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
-          C]`.
-        w: Weight tensor of the shape `[filterH, filterW, inChannels,
-          outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]`
-          (separable). The default is `[1] * factor`, which corresponds to average pooling.
-        factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
-        as `x`.
-    """
-
-    assert isinstance(factor, int) and factor >= 1
-    _outC, _inC, convH, convW = w.shape
-    assert convW == convH
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * gain
-    p = (k.shape[0] - factor) + (convW - 1)
-    s = [factor, factor]
-    x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
-    return F.conv2d(x, w, stride=s, padding=0)
-
-
 def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
     scale = 1e-10 if scale == 0 else scale
@@ -183,46 +92,6 @@ class Combine(nn.Module):
             raise ValueError(f"Method {self.method} not recognized.")
 
 
-class FirUpsample(nn.Module):
-    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
-        super().__init__()
-        out_channels = out_channels if out_channels else channels
-        if use_conv:
-            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.use_conv = use_conv
-        self.fir_kernel = fir_kernel
-        self.out_channels = out_channels
-
-    def forward(self, x):
-        if self.use_conv:
-            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
-            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
-        else:
-            h = upsample_2d(x, self.fir_kernel, factor=2)
-
-        return h
-
-
-class FirDownsample(nn.Module):
-    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
-        super().__init__()
-        out_channels = out_channels if out_channels else channels
-        if use_conv:
-            self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.fir_kernel = fir_kernel
-        self.use_conv = use_conv
-        self.out_channels = out_channels
-
-    def forward(self, x):
-        if self.use_conv:
-            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
-            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
-        else:
-            x = downsample_2d(x, self.fir_kernel, factor=2)
-
-        return x
-
-
 class NCSNpp(ModelMixin, ConfigMixin):
     """NCSN++ model"""
 
@@ -313,9 +182,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
 
         if self.fir:
-            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+            Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
         else:
-            Up_sample = functools.partial(Upsample, name="Conv2d_0")
+            Up_sample = functools.partial(Upsample2D, name="Conv2d_0")
 
         if progressive == "output_skip":
             self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
@@ -323,9 +192,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
             pyramid_upsample = functools.partial(Up_sample, use_conv=True)
 
         if self.fir:
-            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+            Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
         else:
-            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
+            Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")
 
         if progressive_input == "input_skip":
             self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
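In NCSNpp the up/down block classes are pre-configured with functools.partial so that later call sites only pass the channel arguments. A small sketch of that pattern (illustration only; use_fir stands in for self.fir and the channel value is made up):

import functools
from diffusers.models.resnet import Downsample2D, FirDownsample2D

fir_kernel = (1, 3, 3, 1)
resamp_with_conv = True
use_fir = True  # stand-in for self.fir in the diff above

if use_fir:
    Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
    Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")

block = Down_sample(channels=128, out_channels=128)  # remaining kwargs supplied at the call site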

@@ -5,7 +5,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
 
 
 def nonlinearity(x):
@@ -65,7 +65,7 @@ class Encoder(nn.Module):
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
+                down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
             self.down.append(down)
 
@@ -179,7 +179,7 @@ class Decoder(nn.Module):
             up.block = block
             up.attn = attn
             if i_level != 0:
-                up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
+                up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
                 curr_res = curr_res * 2
             self.up.insert(0, up)  # prepend to get consistent order
 

@@ -137,7 +137,7 @@ class ResidualBlock(nn.Module):
         # Dilated conv layer
         h = self.dilated_conv_layer(h)
 
-        # Upsample spectrogram to size of audio
+        # Upsample2D spectrogram to size of audio
         mel_spec = torch.unsqueeze(mel_spec, dim=1)
         mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
         mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)

@@ -22,7 +22,7 @@ import numpy as np
 import torch
 
 from diffusers.models.embeddings import get_timestep_embedding
-from diffusers.models.resnet import Downsample, Upsample
+from diffusers.models.resnet import Downsample2D, Upsample2D
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 
 
@@ -116,11 +116,11 @@ class EmbeddingsTests(unittest.TestCase):
         )
 
 
-class UpsampleBlockTests(unittest.TestCase):
+class Upsample2DBlockTests(unittest.TestCase):
     def test_upsample_default(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=False)
+        upsample = Upsample2D(channels=32, use_conv=False)
         with torch.no_grad():
             upsampled = upsample(sample)
 
@@ -132,7 +132,7 @@ class UpsampleBlockTests(unittest.TestCase):
     def test_upsample_with_conv(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=True)
+        upsample = Upsample2D(channels=32, use_conv=True)
         with torch.no_grad():
             upsampled = upsample(sample)
 
@@ -144,7 +144,7 @@ class UpsampleBlockTests(unittest.TestCase):
     def test_upsample_with_conv_out_dim(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=True, out_channels=64)
+        upsample = Upsample2D(channels=32, use_conv=True, out_channels=64)
         with torch.no_grad():
             upsampled = upsample(sample)
 
@@ -156,7 +156,7 @@ class UpsampleBlockTests(unittest.TestCase):
     def test_upsample_with_transpose(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
-        upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
+        upsample = Upsample2D(channels=32, use_conv=False, use_conv_transpose=True)
         with torch.no_grad():
             upsampled = upsample(sample)
 
@@ -166,11 +166,11 @@ class UpsampleBlockTests(unittest.TestCase):
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
 
 
-class DownsampleBlockTests(unittest.TestCase):
+class Downsample2DBlockTests(unittest.TestCase):
     def test_downsample_default(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=False)
+        downsample = Downsample2D(channels=32, use_conv=False)
         with torch.no_grad():
             downsampled = downsample(sample)
 
@@ -184,7 +184,7 @@ class DownsampleBlockTests(unittest.TestCase):
     def test_downsample_with_conv(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True)
+        downsample = Downsample2D(channels=32, use_conv=True)
         with torch.no_grad():
             downsampled = downsample(sample)
 
@@ -199,7 +199,7 @@ class DownsampleBlockTests(unittest.TestCase):
     def test_downsample_with_conv_pad1(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True, padding=1)
+        downsample = Downsample2D(channels=32, use_conv=True, padding=1)
         with torch.no_grad():
             downsampled = downsample(sample)
 
@@ -211,7 +211,7 @@ class DownsampleBlockTests(unittest.TestCase):
     def test_downsample_with_conv_out_dim(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 64, 64)
-        downsample = Downsample(channels=32, use_conv=True, out_channels=16)
+        downsample = Downsample2D(channels=32, use_conv=True, out_channels=16)
         with torch.no_grad():
             downsampled = downsample(sample)
 
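The updated tests only cover the 2D blocks. A hypothetical shape-only test for the new 1D blocks could look like the following (my own sketch, not part of this commit):

import unittest

import torch

from diffusers.models.resnet import Downsample1D, Upsample1D


class Sample1DBlockSketch(unittest.TestCase):
    def test_upsample_1d_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32)
        upsample = Upsample1D(channels=32, use_conv=False)
        with torch.no_grad():
            upsampled = upsample(sample)
        assert upsampled.shape == (1, 32, 64)

    def test_downsample_1d_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64)
        downsample = Downsample1D(channels=32, use_conv=False)
        with torch.no_grad():
            downsampled = downsample(sample)
        assert downsampled.shape == (1, 32, 32)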