final fix

This commit is contained in:
Patrick von Platen 2022-06-28 22:59:21 +00:00
parent 635da72374
commit 31d1f3c8c0
5 changed files with 30 additions and 19 deletions

View File

@ -91,11 +91,15 @@ class AttentionBlock(nn.Module):
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
self.is_overwritten = False
def set_weights(self, module):
if self.overwrite_qkv:
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
self.qkv.weight.data = qkv_weight
@ -107,14 +111,19 @@ class AttentionBlock(nn.Module):
self.proj_out = proj_out
elif self.overwrite_linear:
self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
self.qkv.weight.data = torch.concat(
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
)[:, :, None]
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
self.proj_out.bias.data = self.NIN_3.b.data
self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
def forward(self, x, encoder_out=None):
if self.overwrite_qkv and not self.is_overwritten:
if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
self.set_weights(self)
self.is_overwritten = True
@ -152,7 +161,7 @@ class AttentionBlock(nn.Module):
# unet_score_estimation.py
#class AttnBlockpp(nn.Module):
# class AttnBlockpp(nn.Module):
# """Channel-wise self-attention block. Modified from DDPM."""
#
# def __init__(
@ -187,14 +196,11 @@ class AttentionBlock(nn.Module):
# self.num_heads = channels // num_head_channels
#
# self.use_checkpoint = use_checkpoint
# self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None)
# self.qkv = conv_nd(1, channels, channels * 3, 1)
# self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
# self.qkv = nn.Conv1d(channels, channels * 3, 1)
# self.n_heads = self.num_heads
#
# if encoder_channels is not None:
# self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
#
# self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
# self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
#
# self.is_weight_set = False
#
@ -205,6 +211,9 @@ class AttentionBlock(nn.Module):
# self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
# self.proj_out.bias.data = self.NIN_3.b.data
#
# self.norm.weight.data = self.GroupNorm_0.weight.data
# self.norm.bias.data = self.GroupNorm_0.bias.data
#
# def forward(self, x):
# if not self.is_weight_set:
# self.set_weights()
@ -261,6 +270,7 @@ class AttentionBlock(nn.Module):
#
# return (x + h) / np.sqrt(2.0)
# TODO(Patrick) - this can and should be removed
def zero_module(module):
"""

View File

@ -30,9 +30,9 @@ from tqdm import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
from .attention2d import AttentionBlock
def nonlinearity(x):
@ -219,11 +219,11 @@ class UNetModel(ModelMixin, ConfigMixin):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
# self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
# h = self.down[i_level].attn_2[i_block](h)
# self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
# h = self.down[i_level].attn_2[i_block](h)
h = self.down[i_level].attn[i_block](h)
# print("Result", (h - h_2).abs().sum())
# print("Result", (h - h_2).abs().sum())
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))

View File

@ -3,9 +3,9 @@ from numpy import pad
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention2d import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
from .attention2d import LinearAttention
class Mish(torch.nn.Module):

View File

@ -16,18 +16,18 @@
# helpers functions
import functools
import math
import string
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .attention2d import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@ -728,7 +728,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
nn.init.zeros_(modules[-1].bias)
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
if progressive == "output_skip":

View File

@ -859,7 +859,9 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105])
expected_slice = torch.tensor(
[-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow