final fix
parent 635da72374
commit 31d1f3c8c0
@@ -91,11 +91,15 @@ class AttentionBlock(nn.Module):
         self.NIN_2 = NIN(channels, channels)
         self.NIN_3 = NIN(channels, channels)

         self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)

         self.is_overwritten = False

     def set_weights(self, module):
         if self.overwrite_qkv:
-            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
+                :, :, :, 0
+            ]
             qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)

             self.qkv.weight.data = qkv_weight
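Note on the hunk above: the wrapped line fuses the old module's separate q/k/v projection weights into one qkv projection, and the trailing [:, :, :, 0] appears to drop the last kernel axis of 1x1 Conv2d weights so they fit a Conv1d layer. A minimal sketch of why concatenating along dim=0 gives an equivalent fused projection, using illustrative Conv1d layers rather than the repo's actual modules:

# Hypothetical sketch (not part of this commit); names are illustrative.
import torch
import torch.nn as nn

channels = 8
x = torch.randn(2, channels, 16)  # (batch, channels, length)

# Three separate 1x1 projections, as in the pre-fusion layout.
q = nn.Conv1d(channels, channels, 1)
k = nn.Conv1d(channels, channels, 1)
v = nn.Conv1d(channels, channels, 1)

# One fused projection whose weight/bias are the concatenation along dim=0.
qkv = nn.Conv1d(channels, channels * 3, 1)
qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)

# Splitting the fused output recovers the three projections.
q_out, k_out, v_out = qkv(x).chunk(3, dim=1)
assert torch.allclose(q_out, q(x), atol=1e-5)
assert torch.allclose(k_out, k(x), atol=1e-5)
assert torch.allclose(v_out, v(x), atol=1e-5)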
@@ -107,14 +111,19 @@ class AttentionBlock(nn.Module):

             self.proj_out = proj_out
         elif self.overwrite_linear:
-            self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
+            self.qkv.weight.data = torch.concat(
+                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
+            )[:, :, None]
             self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)

             self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
             self.proj_out.bias.data = self.NIN_3.b.data

             self.norm.weight.data = self.GroupNorm_0.weight.data
             self.norm.bias.data = self.GroupNorm_0.bias.data

     def forward(self, x, encoder_out=None):
-        if self.overwrite_qkv and not self.is_overwritten:
+        if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
            self.set_weights(self)
            self.is_overwritten = True

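The overwrite_linear branch above copies NIN (position-wise linear) weights into kernel-size-1 convolutions; the NIN module appears to store its weight as (in, out), hence the transpose before the trailing kernel axis is added with [:, :, None]. The forward gate now triggers this one-time lazy copy for either overwrite mode. A small sketch of the linear-to-Conv1d weight mapping, using nn.Linear (whose weight is already (out, in), so no transpose is needed):

# Hypothetical sketch (not from this commit): a linear layer applied per
# position is equivalent to a kernel-size-1 Conv1d whose weight is the
# linear weight with a trailing kernel dimension.
import torch
import torch.nn as nn

c_in, c_out, length = 8, 8, 16
x = torch.randn(2, c_in, length)

linear = nn.Linear(c_in, c_out)   # weight: (c_out, c_in)
conv = nn.Conv1d(c_in, c_out, 1)  # weight: (c_out, c_in, 1)

conv.weight.data = linear.weight.data[:, :, None]
conv.bias.data = linear.bias.data

# Apply the linear layer position-wise: (B, C, L) -> (B, L, C) -> (B, C, L)
ref = linear(x.transpose(1, 2)).transpose(1, 2)
assert torch.allclose(conv(x), ref, atol=1e-5)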
@@ -152,7 +161,7 @@ class AttentionBlock(nn.Module):


 # unet_score_estimation.py
-#class AttnBlockpp(nn.Module):
+# class AttnBlockpp(nn.Module):
 #     """Channel-wise self-attention block. Modified from DDPM."""
 #
 #     def __init__(
@@ -187,14 +196,11 @@ class AttentionBlock(nn.Module):
 #             self.num_heads = channels // num_head_channels
 #
 #         self.use_checkpoint = use_checkpoint
-#         self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None)
-#         self.qkv = conv_nd(1, channels, channels * 3, 1)
+#         self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+#         self.qkv = nn.Conv1d(channels, channels * 3, 1)
 #         self.n_heads = self.num_heads
 #
 #         if encoder_channels is not None:
 #             self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
 #
-#         self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+#         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 #
 #         self.is_weight_set = False
 #
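The commented reference block above swaps guided-diffusion style helpers for plain torch.nn modules. For the 1-D case the two are interchangeable, since a conv_nd-style helper only dispatches on dimensionality; a sketch of such a helper (assumed for illustration, not defined in this file):

# Hypothetical conv_nd-style helper: with this definition,
# conv_nd(1, channels, channels * 3, 1) builds the same layer as
# nn.Conv1d(channels, channels * 3, 1).
import torch.nn as nn

def conv_nd(dims, *args, **kwargs):
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    if dims == 2:
        return nn.Conv2d(*args, **kwargs)
    if dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")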
@@ -205,6 +211,9 @@ class AttentionBlock(nn.Module):
 #         self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
 #         self.proj_out.bias.data = self.NIN_3.b.data
 #
+#         self.norm.weight.data = self.GroupNorm_0.weight.data
+#         self.norm.bias.data = self.GroupNorm_0.bias.data
+#
 #     def forward(self, x):
 #         if not self.is_weight_set:
 #             self.set_weights()
@@ -261,6 +270,7 @@ class AttentionBlock(nn.Module):
 #
 #         return (x + h) / np.sqrt(2.0)


+# TODO(Patrick) - this can and should be removed
 def zero_module(module):
     """
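For context on the TODO above: zero_module zeroes a module's parameters in place so that, at initialization, the output projection contributes nothing and the attention block starts out as a (scaled) residual pass-through. A sketch of the pattern as it appears in guided-diffusion style code; the definition kept in this file may differ in details:

# Sketch only, shown for context; assumed to match the repo's own definition.
import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    """Zero out the parameters of a module and return it."""
    for p in module.parameters():
        p.detach().zero_()
    return module

# Typical use: the projection starts at zero, so the block initially
# behaves like an identity/residual pass-through.
proj_out = zero_module(nn.Conv1d(64, 64, 1))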
@@ -30,9 +30,9 @@ from tqdm import tqdm

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention2d import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
-from .attention2d import AttentionBlock


 def nonlinearity(x):
@@ -219,11 +219,11 @@ class UNetModel(ModelMixin, ConfigMixin):
             for i_block in range(self.num_res_blocks):
                 h = self.down[i_level].block[i_block](hs[-1], temb)
                 if len(self.down[i_level].attn) > 0:
-                    # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
-                    # h = self.down[i_level].attn_2[i_block](h)
+                    # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
+                    # h = self.down[i_level].attn_2[i_block](h)

                     h = self.down[i_level].attn[i_block](h)
-                    # print("Result", (h - h_2).abs().sum())
+                    # print("Result", (h - h_2).abs().sum())
                 hs.append(h)
             if i_level != self.num_resolutions - 1:
                 hs.append(self.down[i_level].downsample(hs[-1]))
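The commented-out lines in the hunk above were evidently used to check that the new AttentionBlock reproduces the old per-level attention modules numerically. A hypothetical helper sketching that kind of parity check (check_parity is not part of the repo; set_weights is the method touched in the hunks above):

# Hypothetical parity-check sketch mirroring the commented-out debug lines.
import torch

def check_parity(old_block, new_block, x, atol=1e-5):
    # Copy weights from the old block into the new one, then compare outputs.
    new_block.set_weights(old_block)
    with torch.no_grad():
        h_old = old_block(x)
        h_new = new_block(x)
    print("max abs diff:", (h_old - h_new).abs().max().item())
    assert torch.allclose(h_old, h_new, atol=atol)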
@@ -3,9 +3,9 @@ from numpy import pad

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention2d import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
-from .attention2d import LinearAttention


 class Mish(torch.nn.Module):
@@ -16,18 +16,18 @@
 # helpers functions

 import functools
+import math
 import string

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import math

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .embeddings import GaussianFourierProjection, get_timestep_embedding
+from .attention2d import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding


 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -728,7 +728,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
             nn.init.zeros_(modules[-1].bias)

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

         Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

         if progressive == "output_skip":
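The AttnBlock partial above pre-binds the constructor arguments that are the same for every attention block in NCSNpp, so call sites only pass what varies per layer; rescale_output_factor=math.sqrt(2.0) matches the (x + h) / np.sqrt(2.0) rescaling in the commented reference code earlier in the file. A minimal, self-contained sketch of the pattern with a stand-in class (the real AttentionBlock signature is assumed):

# Sketch of the functools.partial pattern; Block is a stand-in, not the
# repo's AttentionBlock.
import functools
import math

class Block:
    def __init__(self, channels, overwrite_linear=False, rescale_output_factor=1.0):
        self.channels = channels
        self.overwrite_linear = overwrite_linear
        self.rescale_output_factor = rescale_output_factor

AttnBlock = functools.partial(Block, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

# Call sites only supply the per-layer channel count:
block = AttnBlock(channels=128)
assert block.overwrite_linear and block.rescale_output_factor == math.sqrt(2.0)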
@@ -859,7 +859,9 @@ class PipelineTesterMixin(unittest.TestCase):
         image_slice = image[0, -1, -3:, -3:].cpu()

         assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105])
+        expected_slice = torch.tensor(
+            [-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]
+        )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

     @slow
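The test hunk above compares a fixed 3x3 corner slice of the generated image against hard-coded reference values with an absolute tolerance of 1e-2, which absorbs small cross-device numerical differences. A sketch of the pattern with a stand-in tensor instead of a real pipeline output:

# Sketch of the slice-regression pattern; the tensor and expected values
# here are placeholders, not the real pipeline output or reference slice.
import torch

image = torch.zeros(1, 3, 32, 32)  # pretend pipeline output
image_slice = image[0, -1, -3:, -3:].cpu()

expected_slice = torch.zeros(9)  # in the real test: hard-coded reference values

assert image.shape == (1, 3, 32, 32)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2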