diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index e55b83e9..22582853 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -637,62 +637,6 @@ class ResnetBlockBigGANpp(nn.Module):
         return (x + h) / np.sqrt(2.0)
 
 
-# unet_score_estimation.py
-class ResnetBlockDDPMpp(nn.Module):
-    """ResBlock adapted from DDPM."""
-
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        conv_shortcut=False,
-        dropout=0.1,
-        skip_rescale=False,
-        init_scale=0.0,
-    ):
-        super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch:
-            if conv_shortcut:
-                self.Conv_2 = conv3x3(in_ch, out_ch)
-            else:
-                self.NIN_0 = NIN(in_ch, out_ch)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.out_ch = out_ch
-        self.conv_shortcut = conv_shortcut
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-        h = self.Conv_0(h)
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-        if x.shape[1] != self.out_ch:
-            if self.conv_shortcut:
-                x = self.Conv_2(x)
-            else:
-                x = self.NIN_0(x)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-
-
 # unet_rl.py
 class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
@@ -957,18 +901,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
 
 
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
 def _setup_kernel(k):
     k = np.asarray(k, dtype=np.float32)
     if k.ndim == 1:
diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py
index dda14445..890678f1 100644
--- a/src/diffusers/models/unet_sde_score_estimation.py
+++ b/src/diffusers/models/unet_sde_score_estimation.py
@@ -28,7 +28,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
+from .resnet import ResnetBlockBigGANpp
 
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -305,32 +305,6 @@ def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1
     return conv
 
 
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
-
-
-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
 def get_act(nonlinearity):
     """Get activation functions from the config file."""
 
@@ -575,30 +549,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
         elif progressive_input == "residual":
             pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
 
-        if resblock_type == "ddpm":
-            ResnetBlock = functools.partial(
-                ResnetBlockDDPMpp,
-                act=act,
-                dropout=dropout,
-                init_scale=init_scale,
-                skip_rescale=skip_rescale,
-                temb_dim=nf * 4,
-            )
-
-        elif resblock_type == "biggan":
-            ResnetBlock = functools.partial(
-                ResnetBlockBigGANpp,
-                act=act,
-                dropout=dropout,
-                fir=fir,
-                fir_kernel=fir_kernel,
-                init_scale=init_scale,
-                skip_rescale=skip_rescale,
-                temb_dim=nf * 4,
-            )
-
-        else:
-            raise ValueError(f"resblock type {resblock_type} unrecognized.")
+        ResnetBlock = functools.partial(
+            ResnetBlockBigGANpp,
+            act=act,
+            dropout=dropout,
+            fir=fir,
+            fir_kernel=fir_kernel,
+            init_scale=init_scale,
+            skip_rescale=skip_rescale,
+            temb_dim=nf * 4,
+        )
 
         # Downsampling block
 
@@ -622,10 +582,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs_c.append(in_ch)
 
             if i_level != self.num_resolutions - 1:
-                if resblock_type == "ddpm":
-                    modules.append(Downsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                modules.append(ResnetBlock(down=True, in_ch=in_ch))
 
                 if progressive_input == "input_skip":
                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@@ -678,10 +635,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     raise ValueError(f"{progressive} is not a valid name")
 
             if i_level != 0:
-                if resblock_type == "ddpm":
-                    modules.append(Upsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                modules.append(ResnetBlock(in_ch=in_ch, up=True))
 
 
         assert not hs_c
@@ -741,12 +695,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs.append(h)
 
             if i_level != self.num_resolutions - 1:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](hs[-1])
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](hs[-1], temb)
-                    m_idx += 1
+                h = modules[m_idx](hs[-1], temb)
+                m_idx += 1
 
                 if self.progressive_input == "input_skip":
                     input_pyramid = self.pyramid_downsample(input_pyramid)
@@ -818,12 +768,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     raise ValueError(f"{self.progressive} is not a valid name")
 
             if i_level != 0:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](h)
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](h, temb)
-                    m_idx += 1
+                h = modules[m_idx](h, temb)
+                m_idx += 1
 
 
         assert not hs
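
Reviewer note (not part of the patch): once ResnetBlockDDPMpp is gone, NIN and its contract_inner/_einsum helpers become dead code, since only the DDPM block's non-conv shortcut used them. The removed NIN layer is just a per-pixel linear map over channels, i.e. a 1x1 convolution. Below is a minimal standalone sketch of that equivalence; it uses a plain random initializer in place of the removed default_init, so the numbers differ from the original but the math is the same.

import torch
import torch.nn as nn


class NIN(nn.Module):
    """The removed layer: y[b, :, h, w] = W^T x[b, :, h, w] + b."""

    def __init__(self, in_dim, num_units):
        super().__init__()
        self.W = nn.Parameter(torch.randn(in_dim, num_units))
        self.b = nn.Parameter(torch.zeros(num_units))

    def forward(self, x):
        # Same contraction contract_inner performed after the NHWC permute:
        # sum over the channel axis of x against the first axis of W.
        return torch.einsum("bchw,cd->bdhw", x, self.W) + self.b[None, :, None, None]


nin = NIN(16, 32)
conv = nn.Conv2d(16, 32, kernel_size=1)
# Copy NIN's parameters into the 1x1 conv:
# W is (in, out); conv.weight is (out, in, 1, 1).
with torch.no_grad():
    conv.weight.copy_(nin.W.t()[:, :, None, None])
    conv.bias.copy_(nin.b)

x = torch.randn(2, 16, 8, 8)
assert torch.allclose(nin(x), conv(x), atol=1e-5)

So if a 1x1 channel projection is ever needed again, nn.Conv2d(in_dim, num_units, kernel_size=1) covers it without reviving the custom einsum helpers.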