Downsample / Upsample - clean to 1D and 2D (#68)

* make the RL UNet work

* upload files / code

* upload files

* make style correct

* finish
Patrick von Platen 2022-07-03 22:26:33 +02:00 committed by GitHub
parent c524244f49
commit 321f9791d6
10 changed files with 254 additions and 232 deletions


@ -6,46 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
-def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.ConvTranspose1d(*args, **kwargs)
elif dims == 2:
return nn.ConvTranspose2d(*args, **kwargs)
elif dims == 3:
return nn.ConvTranspose3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
-class Upsample(nn.Module):
+class Upsample2D(nn.Module):
""" """
An upsampling layer with an optional convolution. An upsampling layer with an optional convolution.
@ -54,21 +15,21 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
-def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
+def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.use_conv = use_conv self.use_conv = use_conv
-self.dims = dims
self.use_conv_transpose = use_conv_transpose self.use_conv_transpose = use_conv_transpose
self.name = name self.name = name
conv = None conv = None
if use_conv_transpose: if use_conv_transpose:
-conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
-conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv": if name == "conv":
self.conv = conv self.conv = conv
else: else:
@ -79,11 +40,9 @@ class Upsample(nn.Module):
if self.use_conv_transpose: if self.use_conv_transpose:
return self.conv(x) return self.conv(x)
-if self.dims == 3:
-    x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-else:
-    x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv: if self.use_conv:
if self.name == "conv": if self.name == "conv":
x = self.conv(x) x = self.conv(x)
@ -93,7 +52,7 @@ class Upsample(nn.Module):
return x return x
-class Downsample(nn.Module):
+class Downsample2D(nn.Module):
""" """
A downsampling layer with an optional convolution. A downsampling layer with an optional convolution.
@ -102,22 +61,22 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions. downsampling occurs in the inner-two dimensions.
""" """
-def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
+def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.use_conv = use_conv self.use_conv = use_conv
-self.dims = dims
self.padding = padding self.padding = padding
-stride = 2 if dims != 3 else (1, 2, 2)
+stride = 2
self.name = name self.name = name
if use_conv: if use_conv:
-conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
+conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
-conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
+# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv": if name == "conv":
self.conv = conv self.conv = conv
elif name == "Conv2d_0": elif name == "Conv2d_0":
@ -127,10 +86,11 @@ class Downsample(nn.Module):
def forward(self, x): def forward(self, x):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
-if self.use_conv and self.padding == 0 and self.dims == 2:
+if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1) pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0) x = F.pad(x, pad, mode="constant", value=0)
+# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.name == "conv": if self.name == "conv":
return self.conv(x) return self.conv(x)
elif self.name == "Conv2d_0": elif self.name == "Conv2d_0":
@ -139,8 +99,204 @@ class Downsample(nn.Module):
return self.op(x) return self.op(x)
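# Illustration, not part of the commit: a minimal shape check for the renamed 2D blocks.
# The import path matches the tests at the bottom of this diff; the expected shapes
# follow from the constructors above.
import torch

from diffusers.models.resnet import Downsample2D, Upsample2D

x = torch.randn(1, 32, 64, 64)

up = Upsample2D(channels=32, use_conv=True)      # nearest-neighbor x2, then a 3x3 conv
down = Downsample2D(channels=32, use_conv=True)  # 3x3 conv with stride 2, padding 1

with torch.no_grad():
    assert up(x).shape == (1, 32, 128, 128)
    assert down(x).shape == (1, 32, 32, 32)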
class Upsample1D(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.conv(x)
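# Illustration, not part of the commit: the new 1D variants mirror the 2D API but act on
# (batch, channels, length) tensors, as used by the RL/trajectory UNet further below.
# Assumes they are importable from diffusers.models.resnet like the 2D blocks.
import torch

from diffusers.models.resnet import Downsample1D, Upsample1D

x = torch.randn(1, 32, 64)

up = Upsample1D(channels=32, use_conv_transpose=True)  # ConvTranspose1d(4, stride=2, padding=1) doubles the length
down = Downsample1D(channels=32, use_conv=True)        # Conv1d(3, stride=2, padding=1) halves the length

with torch.no_grad():
    assert up(x).shape == (1, 32, 128)
    assert down(x).shape == (1, 32, 32)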
class FirUpsample2D(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.use_conv = use_conv
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
h = upsample_2d(x, self.fir_kernel, factor=2)
return h
class FirDownsample2D(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
x = downsample_2d(x, self.fir_kernel, factor=2)
return x
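# Illustration, not part of the commit: the FIR blocks resample through upfirdn2d with the
# (1, 3, 3, 1) kernel instead of plain interpolation/pooling; spatial sizes still double or
# halve. Assumes both classes are importable from diffusers.models.resnet.
import torch

from diffusers.models.resnet import FirDownsample2D, FirUpsample2D

x = torch.randn(1, 16, 32, 32)

fir_up = FirUpsample2D(channels=16, out_channels=16, use_conv=True)
fir_down = FirDownsample2D(channels=16, out_channels=16, use_conv=True)

with torch.no_grad():
    assert fir_up(x).shape == (1, 16, 64, 64)    # fused transposed conv + FIR filter
    assert fir_down(x).shape == (1, 16, 16, 16)  # FIR pre-filter + strided 3x3 conv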
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
as `x`.
"""
assert isinstance(factor, int) and factor >= 1
_outC, _inC, convH, convW = w.shape
assert convW == convH
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = (k.shape[0] - factor) + (convW - 1)
s = [factor, factor]
x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
return F.conv2d(x, w, stride=s, padding=0)
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
"""
assert isinstance(factor, int) and factor >= 1
# Check weight shape.
assert len(w.shape) == 4
convH = w.shape[2]
convW = w.shape[3]
inC = w.shape[1]
assert convW == convH
# Setup filter kernel.
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = (k.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = x.shape[1] // inC
# Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
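# Illustration, not part of the commit: a worked check of the padding arithmetic in
# _upsample_conv_2d above, for H = W = 32, factor = 2 and a 3x3 weight.
H, factor, convW = 32, 2, 3
k_len = 4                                      # len((1, 3, 3, 1)) after _setup_kernel

conv_out = (H - 1) * factor + convW            # 65: height after the transposed convolution
p = (k_len - factor) - (convW - 1)             # 0
pad = ((p + 1) // 2 + factor - 1, p // 2 + 1)  # (1, 1)

# upfirdn2d with up = down = 1 yields: in + pad_before + pad_after - (k_len - 1)
assert conv_out + pad[0] + pad[1] - (k_len - 1) == H * factor  # 64, i.e. exactly H * 2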
# TODO (patil-suraj): needs test # TODO (patil-suraj): needs test
-# class Upsample1d(nn.Module):
+# class Upsample2D1d(nn.Module):
# def __init__(self, dim): # def __init__(self, dim):
# super().__init__() # super().__init__()
# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) # self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
@ -221,7 +377,7 @@ class ResnetBlock2D(nn.Module):
elif kernel == "sde_vp": elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else: else:
-self.upsample = Upsample(in_channels, use_conv=False, dims=2)
+self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down: elif self.down:
if kernel == "fir": if kernel == "fir":
fir_kernel = (1, 3, 3, 1) fir_kernel = (1, 3, 3, 1)
@ -229,7 +385,7 @@ class ResnetBlock2D(nn.Module):
elif kernel == "sde_vp": elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else: else:
-self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
@ -257,7 +413,6 @@ class ResnetBlock2D(nn.Module):
else: else:
self.res_conv = torch.nn.Identity() self.res_conv = torch.nn.Identity()
elif self.overwrite_for_ldm: elif self.overwrite_for_ldm:
-dims = 2
channels = in_channels channels = in_channels
emb_channels = temb_channels emb_channels = temb_channels
use_scale_shift_norm = False use_scale_shift_norm = False
@ -266,7 +421,7 @@ class ResnetBlock2D(nn.Module):
self.in_layers = nn.Sequential( self.in_layers = nn.Sequential(
normalization(channels, swish=1.0), normalization(channels, swish=1.0),
nn.Identity(), nn.Identity(),
-conv_nd(dims, channels, self.out_channels, 3, padding=1),
+nn.Conv2d(channels, self.out_channels, 3, padding=1),
) )
self.emb_layers = nn.Sequential( self.emb_layers = nn.Sequential(
nn.SiLU(), nn.SiLU(),
@ -279,12 +434,12 @@ class ResnetBlock2D(nn.Module):
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0), normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(), nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
-zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
) )
if self.out_channels == in_channels: if self.out_channels == in_channels:
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
else: else:
-self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
elif self.overwrite_for_score_vde: elif self.overwrite_for_score_vde:
in_ch = in_channels in_ch = in_channels
out_ch = out_channels out_ch = out_channels
@ -631,7 +786,7 @@ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
def upsample_2d(x, k=None, factor=2, gain=1): def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample a batch of 2D images with the given filter. r"""Upsample2D a batch of 2D images with the given filter.
Args: Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@ -656,7 +811,7 @@ def upsample_2d(x, k=None, factor=2, gain=1):
def downsample_2d(x, k=None, factor=2, gain=1): def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample a batch of 2D images with the given filter. r"""Downsample2D a batch of 2D images with the given filter.
Args: Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the


@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
def nonlinearity(x): def nonlinearity(x):
@ -100,7 +100,7 @@ class UNetModel(ModelMixin, ConfigMixin):
down.block = block down.block = block
down.attn = attn down.attn = attn
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
-down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
+down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2 curr_res = curr_res // 2
self.down.append(down) self.down.append(down)
@ -139,7 +139,7 @@ class UNetModel(ModelMixin, ConfigMixin):
up.block = block up.block = block
up.attn = attn up.attn = attn
if i_level != 0: if i_level != 0:
-up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
+up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2 curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order self.up.insert(0, up) # prepend to get consistent order


@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
def convert_module_to_f16(l): def convert_module_to_f16(l):
@ -218,9 +218,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
down=True, down=True,
) )
if resblock_updown if resblock_updown
-else Downsample(
-    ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
-)
+else Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
) )
) )
ch = out_ch ch = out_ch
@ -299,7 +297,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
up=True, up=True,
) )
if resblock_updown if resblock_updown
-else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
+else Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)
) )
ds //= 2 ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers)) self.output_blocks.append(TimestepEmbedSequential(*layers))


@ -4,7 +4,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import LinearAttention from .attention import LinearAttention
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
class Mish(torch.nn.Module): class Mish(torch.nn.Module):
@ -105,7 +105,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
overwrite_for_grad_tts=True, overwrite_for_grad_tts=True,
), ),
Residual(Rezero(LinearAttention(dim_out))), Residual(Rezero(LinearAttention(dim_out))),
-Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
+Downsample2D(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
] ]
) )
) )
@ -158,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
overwrite_for_grad_tts=True, overwrite_for_grad_tts=True,
), ),
Residual(Rezero(LinearAttention(dim_in))), Residual(Rezero(LinearAttention(dim_in))),
-Upsample(dim_in, use_conv_transpose=True),
+Upsample2D(dim_in, use_conv_transpose=True),
] ]
) )
) )


@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
# from .resnet import ResBlock # from .resnet import ResBlock
@ -350,7 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
out_ch = ch out_ch = ch
self.input_blocks.append( self.input_blocks.append(
TimestepEmbedSequential( TimestepEmbedSequential(
-Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
+Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
) )
) )
ch = out_ch ch = out_ch
@ -437,7 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
) )
if level and i == num_res_blocks: if level and i == num_res_blocks:
out_ch = ch out_ch = ch
-layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))
+layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
ds //= 2 ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers)) self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch self._feature_size += ch


@ -6,7 +6,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResidualTemporalBlock, Upsample
+from .resnet import Downsample1D, ResidualTemporalBlock, Upsample1D
class SinusoidalPosEmb(nn.Module): class SinusoidalPosEmb(nn.Module):
@ -96,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
[ [
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon), ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon), ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
-Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
+Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
] ]
) )
) )
@ -116,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
[ [
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon), ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon), ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
-Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
+Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
] ]
) )
) )


@ -21,13 +21,12 @@ import math
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
-import torch.nn.functional as F
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import Downsample, ResnetBlock2D, Upsample, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
def _setup_kernel(k): def _setup_kernel(k):
@ -40,96 +39,6 @@ def _setup_kernel(k):
return k return k
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
"""
assert isinstance(factor, int) and factor >= 1
# Check weight shape.
assert len(w.shape) == 4
convH = w.shape[2]
convW = w.shape[3]
inC = w.shape[1]
assert convW == convH
# Setup filter kernel.
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = (k.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = x.shape[1] // inC
# Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
as `x`.
"""
assert isinstance(factor, int) and factor >= 1
_outC, _inC, convH, convW = w.shape
assert convW == convH
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = (k.shape[0] - factor) + (convW - 1)
s = [factor, factor]
x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
return F.conv2d(x, w, stride=s, padding=0)
def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"): def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX.""" """Ported from JAX."""
scale = 1e-10 if scale == 0 else scale scale = 1e-10 if scale == 0 else scale
@ -183,46 +92,6 @@ class Combine(nn.Module):
raise ValueError(f"Method {self.method} not recognized.") raise ValueError(f"Method {self.method} not recognized.")
class FirUpsample(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.use_conv = use_conv
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
h = upsample_2d(x, self.fir_kernel, factor=2)
return h
class FirDownsample(nn.Module):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def forward(self, x):
if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
x = downsample_2d(x, self.fir_kernel, factor=2)
return x
class NCSNpp(ModelMixin, ConfigMixin): class NCSNpp(ModelMixin, ConfigMixin):
"""NCSN++ model""" """NCSN++ model"""
@ -313,9 +182,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
if self.fir: if self.fir:
-Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
-Up_sample = functools.partial(Upsample, name="Conv2d_0")
+Up_sample = functools.partial(Upsample2D, name="Conv2d_0")
if progressive == "output_skip": if progressive == "output_skip":
self.pyramid_upsample = Up_sample(channels=None, use_conv=False) self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
@ -323,9 +192,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
pyramid_upsample = functools.partial(Up_sample, use_conv=True) pyramid_upsample = functools.partial(Up_sample, use_conv=True)
if self.fir: if self.fir:
-Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
-Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
+Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")
if progressive_input == "input_skip": if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(channels=None, use_conv=False) self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
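# Illustration, not part of the commit: the fir flag above pre-binds either resampler with
# functools.partial so call sites construct a downsampler the same way in both branches.
import functools

from diffusers.models.resnet import Downsample2D, FirDownsample2D

fir, fir_kernel, resamp_with_conv = True, (1, 3, 3, 1), True

if fir:
    Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
    Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")

down = Down_sample(channels=64, out_channels=64)  # identical call regardless of branch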


@ -5,7 +5,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
-from .resnet import Downsample, ResnetBlock2D, Upsample
+from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
def nonlinearity(x): def nonlinearity(x):
@ -65,7 +65,7 @@ class Encoder(nn.Module):
down.block = block down.block = block
down.attn = attn down.attn = attn
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
-down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
+down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2 curr_res = curr_res // 2
self.down.append(down) self.down.append(down)
@ -179,7 +179,7 @@ class Decoder(nn.Module):
up.block = block up.block = block
up.attn = attn up.attn = attn
if i_level != 0: if i_level != 0:
-up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
+up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2 curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order self.up.insert(0, up) # prepend to get consistent order


@ -137,7 +137,7 @@ class ResidualBlock(nn.Module):
# Dilated conv layer # Dilated conv layer
h = self.dilated_conv_layer(h) h = self.dilated_conv_layer(h)
-# Upsample spectrogram to size of audio
+# Upsample2D spectrogram to size of audio
mel_spec = torch.unsqueeze(mel_spec, dim=1) mel_spec = torch.unsqueeze(mel_spec, dim=1)
mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False) mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False) mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)


@ -22,7 +22,7 @@ import numpy as np
import torch import torch
from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.embeddings import get_timestep_embedding
-from diffusers.models.resnet import Downsample, Upsample
+from diffusers.models.resnet import Downsample2D, Upsample2D
from diffusers.testing_utils import floats_tensor, slow, torch_device from diffusers.testing_utils import floats_tensor, slow, torch_device
@ -116,11 +116,11 @@ class EmbeddingsTests(unittest.TestCase):
) )
-class UpsampleBlockTests(unittest.TestCase):
+class Upsample2DBlockTests(unittest.TestCase):
def test_upsample_default(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
-upsample = Upsample(channels=32, use_conv=False)
+upsample = Upsample2D(channels=32, use_conv=False)
with torch.no_grad(): with torch.no_grad():
upsampled = upsample(sample) upsampled = upsample(sample)
@ -132,7 +132,7 @@ class UpsampleBlockTests(unittest.TestCase):
def test_upsample_with_conv(self): def test_upsample_with_conv(self):
torch.manual_seed(0) torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32) sample = torch.randn(1, 32, 32, 32)
-upsample = Upsample(channels=32, use_conv=True)
+upsample = Upsample2D(channels=32, use_conv=True)
with torch.no_grad(): with torch.no_grad():
upsampled = upsample(sample) upsampled = upsample(sample)
@ -144,7 +144,7 @@ class UpsampleBlockTests(unittest.TestCase):
def test_upsample_with_conv_out_dim(self): def test_upsample_with_conv_out_dim(self):
torch.manual_seed(0) torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32) sample = torch.randn(1, 32, 32, 32)
-upsample = Upsample(channels=32, use_conv=True, out_channels=64)
+upsample = Upsample2D(channels=32, use_conv=True, out_channels=64)
with torch.no_grad(): with torch.no_grad():
upsampled = upsample(sample) upsampled = upsample(sample)
@ -156,7 +156,7 @@ class UpsampleBlockTests(unittest.TestCase):
def test_upsample_with_transpose(self): def test_upsample_with_transpose(self):
torch.manual_seed(0) torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32) sample = torch.randn(1, 32, 32, 32)
-upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
+upsample = Upsample2D(channels=32, use_conv=False, use_conv_transpose=True)
with torch.no_grad(): with torch.no_grad():
upsampled = upsample(sample) upsampled = upsample(sample)
@ -166,11 +166,11 @@ class UpsampleBlockTests(unittest.TestCase):
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
-class DownsampleBlockTests(unittest.TestCase):
+class Downsample2DBlockTests(unittest.TestCase):
def test_downsample_default(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64)
-downsample = Downsample(channels=32, use_conv=False)
+downsample = Downsample2D(channels=32, use_conv=False)
with torch.no_grad(): with torch.no_grad():
downsampled = downsample(sample) downsampled = downsample(sample)
@ -184,7 +184,7 @@ class DownsampleBlockTests(unittest.TestCase):
def test_downsample_with_conv(self): def test_downsample_with_conv(self):
torch.manual_seed(0) torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64) sample = torch.randn(1, 32, 64, 64)
-downsample = Downsample(channels=32, use_conv=True)
+downsample = Downsample2D(channels=32, use_conv=True)
with torch.no_grad(): with torch.no_grad():
downsampled = downsample(sample) downsampled = downsample(sample)
@ -199,7 +199,7 @@ class DownsampleBlockTests(unittest.TestCase):
def test_downsample_with_conv_pad1(self): def test_downsample_with_conv_pad1(self):
torch.manual_seed(0) torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64) sample = torch.randn(1, 32, 64, 64)
-downsample = Downsample(channels=32, use_conv=True, padding=1)
+downsample = Downsample2D(channels=32, use_conv=True, padding=1)
with torch.no_grad(): with torch.no_grad():
downsampled = downsample(sample) downsampled = downsample(sample)
@ -211,7 +211,7 @@ class DownsampleBlockTests(unittest.TestCase):
def test_downsample_with_conv_out_dim(self): def test_downsample_with_conv_out_dim(self):
torch.manual_seed(0) torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64) sample = torch.randn(1, 32, 64, 64)
-downsample = Downsample(channels=32, use_conv=True, out_channels=16)
+downsample = Downsample2D(channels=32, use_conv=True, out_channels=16)
with torch.no_grad(): with torch.no_grad():
downsampled = downsample(sample) downsampled = downsample(sample)