Downsample / Upsample - clean to 1D and 2D (#68)

* make unet rl work

* upload files / code

* upload files

* make style correct

* finish
Patrick von Platen 2022-07-03 22:26:33 +02:00 committed by GitHub
parent c524244f49
commit 321f9791d6
10 changed files with 254 additions and 232 deletions


@ -6,46 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.ConvTranspose1d(*args, **kwargs)
elif dims == 2:
return nn.ConvTranspose2d(*args, **kwargs)
elif dims == 3:
return nn.ConvTranspose3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
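
The dims-parameterized factories above are removed because every call site now targets a fixed dimensionality. Illustratively, the dims == 2 case reduces to the concrete torch module (channel counts below are hypothetical):

import torch.nn as nn

# before this commit: conv = conv_nd(2, channels, out_channels, 3, padding=1)
conv = nn.Conv2d(64, 64, 3, padding=1)        # hypothetical channel counts
# before this commit: pool = avg_pool_nd(2, kernel_size=2, stride=2)
pool = nn.AvgPool2d(kernel_size=2, stride=2)
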
class Upsample(nn.Module):
class Upsample2D(nn.Module):
"""
An upsampling layer with an optional convolution.
@ -54,21 +15,21 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose:
conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.conv = conv
else:
@ -79,11 +40,9 @@ class Upsample(nn.Module):
if self.use_conv_transpose:
return self.conv(x)
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
x = self.conv(x)
@ -93,7 +52,7 @@ class Upsample(nn.Module):
return x
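
For reference, a minimal usage sketch of the renamed Upsample2D block (the import path matches the tests further below; shapes are illustrative):

import torch
from diffusers.models.resnet import Upsample2D

sample = torch.randn(1, 32, 16, 16)                # (batch, channels, height, width)
upsample = Upsample2D(channels=32, use_conv=True)  # nearest-neighbor 2x, then a 3x3 conv
with torch.no_grad():
    upsampled = upsample(sample)
assert upsampled.shape == (1, 32, 32, 32)
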
class Downsample(nn.Module):
class Downsample2D(nn.Module):
"""
A downsampling layer with an optional convolution.
@ -102,22 +61,22 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.padding = padding
stride = 2 if dims != 3 else (1, 2, 2)
stride = 2
self.name = name
if use_conv:
conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.conv = conv
elif name == "Conv2d_0":
@ -127,10 +86,11 @@ class Downsample(nn.Module):
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0 and self.dims == 2:
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.name == "conv":
return self.conv(x)
elif self.name == "Conv2d_0":
@ -139,8 +99,204 @@ class Downsample(nn.Module):
return self.op(x)
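
Similarly, a minimal usage sketch of the renamed Downsample2D block (illustrative shapes):

import torch
from diffusers.models.resnet import Downsample2D

sample = torch.randn(1, 32, 64, 64)
conv_down = Downsample2D(channels=32, use_conv=True, padding=1)  # 3x3 conv, stride 2
pool_down = Downsample2D(channels=32, use_conv=False)            # 2x2 average pooling
with torch.no_grad():
    assert conv_down(sample).shape == (1, 32, 32, 32)
    assert pool_down(sample).shape == (1, 32, 32, 32)
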
class Upsample1D(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
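
A minimal usage sketch of the new Upsample1D block (illustrative shapes):

import torch
from diffusers.models.resnet import Upsample1D

sample = torch.randn(1, 32, 16)                              # (batch, channels, length)
upsample = Upsample1D(channels=32, use_conv_transpose=True)  # ConvTranspose1d(32, 32, 4, 2, 1)
with torch.no_grad():
    assert upsample(sample).shape == (1, 32, 32)             # length is doubled
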
class Downsample1D(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.conv(x)
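
And a minimal usage sketch of the new Downsample1D block (illustrative shapes):

import torch
from diffusers.models.resnet import Downsample1D

sample = torch.randn(1, 32, 64)
downsample = Downsample1D(channels=32, use_conv=True, padding=1)  # Conv1d(32, 32, 3, stride=2, padding=1)
with torch.no_grad():
    assert downsample(sample).shape == (1, 32, 32)                # length is halved
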
class FirUpsample2D(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.use_conv = use_conv
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
h = upsample_2d(x, self.fir_kernel, factor=2)
return h
class FirDownsample2D(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
x = downsample_2d(x, self.fir_kernel, factor=2)
return x
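
When no convolution is requested, the FIR variants defer to the module-level upsample_2d / downsample_2d helpers. A minimal sketch, assuming those helpers return the [N, C, H * 2, W * 2] and [N, C, H // 2, W // 2] shapes stated in their docstrings:

import torch
from diffusers.models.resnet import FirDownsample2D, FirUpsample2D

sample = torch.randn(1, 32, 16, 16)
fir_up = FirUpsample2D(channels=32, use_conv=False, fir_kernel=(1, 3, 3, 1))
fir_down = FirDownsample2D(channels=32, use_conv=False, fir_kernel=(1, 3, 3, 1))
with torch.no_grad():
    assert fir_up(sample).shape == (1, 32, 32, 32)   # upsample_2d with the (1, 3, 3, 1) filter
    assert fir_down(sample).shape == (1, 32, 8, 8)   # downsample_2d with the same filter
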
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
as `x`.
"""
assert isinstance(factor, int) and factor >= 1
_outC, _inC, convH, convW = w.shape
assert convW == convH
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = (k.shape[0] - factor) + (convW - 1)
s = [factor, factor]
x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
return F.conv2d(x, w, stride=s, padding=0)
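
To see why the single up-front padding yields exactly H // factor, here is a worked shape check with assumed example values (factor 2, a 3x3 conv weight, and the (1, 3, 3, 1) FIR filter, whose _setup_kernel result has k.shape[0] == 4):

factor, convW, firN = 2, 3, 4
H = 16
p = (firN - factor) + (convW - 1)                      # total padding, applied once: 4
H_filtered = H + (p + 1) // 2 + p // 2 - (firN - 1)    # after upfirdn2d (up = down = 1): 17
H_out = (H_filtered - convW) // factor + 1             # after the stride-2, padding-0 conv: 8
assert H_out == H // factor
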
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
"""
assert isinstance(factor, int) and factor >= 1
# Check weight shape.
assert len(w.shape) == 4
convH = w.shape[2]
convW = w.shape[3]
inC = w.shape[1]
assert convW == convH
# Setup filter kernel.
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = (k.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = x.shape[1] // inC
# Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
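
The mirror-image shape check for the fused upsample + conv, assuming the transposed convolution produces the output_shape computed above (again factor 2, a 3x3 weight, and the (1, 3, 3, 1) filter):

factor, convH, firN = 2, 3, 4
H = 16
p = (firN - factor) - (convH - 1)                                            # 0
H_deconv = (H - 1) * factor + convH                                          # output_shape of the transposed conv: 33
H_out = H_deconv + ((p + 1) // 2 + factor - 1) + (p // 2 + 1) - (firN - 1)   # after upfirdn2d: 32
assert H_out == H * factor
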
# TODO (patil-suraj): needs test
# class Upsample1d(nn.Module):
# class Upsample2D1d(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
@ -221,7 +377,7 @@ class ResnetBlock2D(nn.Module):
elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else:
self.upsample = Upsample(in_channels, use_conv=False, dims=2)
self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
@ -229,7 +385,7 @@ class ResnetBlock2D(nn.Module):
elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else:
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
@ -257,7 +413,6 @@ class ResnetBlock2D(nn.Module):
else:
self.res_conv = torch.nn.Identity()
elif self.overwrite_for_ldm:
dims = 2
channels = in_channels
emb_channels = temb_channels
use_scale_shift_norm = False
@ -266,7 +421,7 @@ class ResnetBlock2D(nn.Module):
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
nn.Conv2d(channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
@ -279,12 +434,12 @@ class ResnetBlock2D(nn.Module):
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == in_channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
elif self.overwrite_for_score_vde:
in_ch = in_channels
out_ch = out_channels
@ -631,7 +786,7 @@ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample a batch of 2D images with the given filter.
r"""Upsample2D a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@ -656,7 +811,7 @@ def upsample_2d(x, k=None, factor=2, gain=1):
def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample a batch of 2D images with the given filter.
r"""Downsample2D a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the


@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock2D, Upsample
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
def nonlinearity(x):
@ -100,7 +100,7 @@ class UNetModel(ModelMixin, ConfigMixin):
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2
self.down.append(down)
@ -139,7 +139,7 @@ class UNetModel(ModelMixin, ConfigMixin):
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order


@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock2D, Upsample
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
def convert_module_to_f16(l):
@ -218,9 +218,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
down=True,
)
if resblock_updown
else Downsample(
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
)
else Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
)
)
ch = out_ch
@ -299,7 +297,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
up=True,
)
if resblock_updown
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
else Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))


@ -4,7 +4,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock2D, Upsample
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
class Mish(torch.nn.Module):
@ -105,7 +105,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
overwrite_for_grad_tts=True,
),
Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
Downsample2D(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
]
)
)
@ -158,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
overwrite_for_grad_tts=True,
),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in, use_conv_transpose=True),
Upsample2D(dim_in, use_conv_transpose=True),
]
)
)


@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResnetBlock2D, Upsample
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
# from .resnet import ResBlock
@ -350,7 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
)
)
ch = out_ch
@ -437,7 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))
layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch


@ -6,7 +6,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResidualTemporalBlock, Upsample
from .resnet import Downsample1D, ResidualTemporalBlock, Upsample1D
class SinusoidalPosEmb(nn.Module):
@ -96,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
[
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
]
)
)
@ -116,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
[
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
]
)
)


@ -21,13 +21,12 @@ import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample, ResnetBlock2D, Upsample, downsample_2d, upfirdn2d, upsample_2d
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
def _setup_kernel(k):
@ -40,96 +39,6 @@ def _setup_kernel(k):
return k
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
"""
assert isinstance(factor, int) and factor >= 1
# Check weight shape.
assert len(w.shape) == 4
convH = w.shape[2]
convW = w.shape[3]
inC = w.shape[1]
assert convW == convH
# Setup filter kernel.
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = (k.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = x.shape[1] // inC
# Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
as `x`.
"""
assert isinstance(factor, int) and factor >= 1
_outC, _inC, convH, convW = w.shape
assert convW == convH
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = (k.shape[0] - factor) + (convW - 1)
s = [factor, factor]
x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
return F.conv2d(x, w, stride=s, padding=0)
def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
scale = 1e-10 if scale == 0 else scale
@ -183,46 +92,6 @@ class Combine(nn.Module):
raise ValueError(f"Method {self.method} not recognized.")
class FirUpsample(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.use_conv = use_conv
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
h = upsample_2d(x, self.fir_kernel, factor=2)
return h
class FirDownsample(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
x = downsample_2d(x, self.fir_kernel, factor=2)
return x
class NCSNpp(ModelMixin, ConfigMixin):
"""NCSN++ model"""
@ -313,9 +182,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
if self.fir:
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Up_sample = functools.partial(Upsample, name="Conv2d_0")
Up_sample = functools.partial(Upsample2D, name="Conv2d_0")
if progressive == "output_skip":
self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
@ -323,9 +192,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
if self.fir:
Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")
if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
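
For illustration, how the two resampling choices above resolve outside the model (a hedged sketch; the channel count is hypothetical):

import functools
from diffusers.models.resnet import Downsample2D, FirDownsample2D

fir_kernel, resamp_with_conv = (1, 3, 3, 1), True
fir_down = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
plain_down = functools.partial(Downsample2D, padding=0, name="Conv2d_0")

block_a = fir_down(channels=128, out_channels=128)   # FIR-filtered downsampling block
block_b = plain_down(channels=128, use_conv=True)    # strided-conv downsampling block
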


@ -5,7 +5,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .resnet import Downsample, ResnetBlock2D, Upsample
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
def nonlinearity(x):
@ -65,7 +65,7 @@ class Encoder(nn.Module):
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2
self.down.append(down)
@ -179,7 +179,7 @@ class Decoder(nn.Module):
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order


@ -137,7 +137,7 @@ class ResidualBlock(nn.Module):
# Dilated conv layer
h = self.dilated_conv_layer(h)
# Upsample spectrogram to size of audio
# Upsample spectrogram to size of audio
mel_spec = torch.unsqueeze(mel_spec, dim=1)
mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)


@ -22,7 +22,7 @@ import numpy as np
import torch
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample, Upsample
from diffusers.models.resnet import Downsample2D, Upsample2D
from diffusers.testing_utils import floats_tensor, slow, torch_device
@ -116,11 +116,11 @@ class EmbeddingsTests(unittest.TestCase):
)
class UpsampleBlockTests(unittest.TestCase):
class Upsample2DBlockTests(unittest.TestCase):
def test_upsample_default(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=False)
upsample = Upsample2D(channels=32, use_conv=False)
with torch.no_grad():
upsampled = upsample(sample)
@ -132,7 +132,7 @@ class UpsampleBlockTests(unittest.TestCase):
def test_upsample_with_conv(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=True)
upsample = Upsample2D(channels=32, use_conv=True)
with torch.no_grad():
upsampled = upsample(sample)
@ -144,7 +144,7 @@ class UpsampleBlockTests(unittest.TestCase):
def test_upsample_with_conv_out_dim(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=True, out_channels=64)
upsample = Upsample2D(channels=32, use_conv=True, out_channels=64)
with torch.no_grad():
upsampled = upsample(sample)
@ -156,7 +156,7 @@ class UpsampleBlockTests(unittest.TestCase):
def test_upsample_with_transpose(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
upsample = Upsample2D(channels=32, use_conv=False, use_conv_transpose=True)
with torch.no_grad():
upsampled = upsample(sample)
@ -166,11 +166,11 @@ class UpsampleBlockTests(unittest.TestCase):
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class DownsampleBlockTests(unittest.TestCase):
class Downsample2DBlockTests(unittest.TestCase):
def test_downsample_default(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64)
downsample = Downsample(channels=32, use_conv=False)
downsample = Downsample2D(channels=32, use_conv=False)
with torch.no_grad():
downsampled = downsample(sample)
@ -184,7 +184,7 @@ class DownsampleBlockTests(unittest.TestCase):
def test_downsample_with_conv(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64)
downsample = Downsample(channels=32, use_conv=True)
downsample = Downsample2D(channels=32, use_conv=True)
with torch.no_grad():
downsampled = downsample(sample)
@ -199,7 +199,7 @@ class DownsampleBlockTests(unittest.TestCase):
def test_downsample_with_conv_pad1(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64)
downsample = Downsample(channels=32, use_conv=True, padding=1)
downsample = Downsample2D(channels=32, use_conv=True, padding=1)
with torch.no_grad():
downsampled = downsample(sample)
@ -211,7 +211,7 @@ class DownsampleBlockTests(unittest.TestCase):
def test_downsample_with_conv_out_dim(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64)
downsample = Downsample(channels=32, use_conv=True, out_channels=16)
downsample = Downsample2D(channels=32, use_conv=True, out_channels=16)
with torch.no_grad():
downsampled = downsample(sample)