some clean up

Patrick von Platen, 2022-06-28 23:09:50 +00:00
parent 31d1f3c8c0
commit c482d7bd4f
6 changed files with 5 additions and 46 deletions


@@ -15,22 +15,12 @@
# helpers functions
-import copy
-import math
-from pathlib import Path
import torch
from torch import nn
-from torch.cuda.amp import GradScaler, autocast
-from torch.optim import Adam
-from torch.utils import data
-from PIL import Image
-from tqdm import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
@@ -219,11 +209,7 @@ class UNetModel(ModelMixin, ConfigMixin):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
-                    # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
-                    # h = self.down[i_level].attn_2[i_block](h)
                    h = self.down[i_level].attn[i_block](h)
-                    # print("Result", (h - h_2).abs().sum())
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
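Note: the comment lines deleted above were leftover debugging code for cross-checking a reimplemented attention block (attn_2) against the existing one: copy the weights across, run both on the same input, and print the absolute difference of the outputs. A minimal sketch of that verification pattern, using hypothetical nn.Linear stand-ins rather than the repository's actual attention blocks:

import torch
from torch import nn

# Hypothetical stand-ins for the reference and reimplemented blocks.
reference_block = nn.Linear(16, 16)
reimplemented_block = nn.Linear(16, 16)

# Copy the reference weights into the reimplementation ...
reimplemented_block.load_state_dict(reference_block.state_dict())

# ... then confirm both produce the same output on the same input.
x = torch.randn(4, 16)
print("Result", (reference_block(x) - reimplemented_block(x)).abs().sum())  # should be ~0
assert torch.allclose(reference_block(x), reimplemented_block(x), atol=1e-6)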


@@ -6,7 +6,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample


@@ -1,9 +1,8 @@
import torch
-from numpy import pad
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
-from .attention2d import LinearAttention
+from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
@@ -55,32 +54,6 @@ class ResnetBlock(torch.nn.Module):
        return output

-class old_LinearAttention(torch.nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super(LinearAttention, self).__init__()
-        self.heads = heads
-        self.dim_head = dim_head
-        hidden_dim = dim_head * heads
-        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
-        q, k, v = (
-            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
-            .permute(1, 0, 2, 3, 4, 5)
-            .reshape(3, b, self.heads, self.dim_head, -1)
-        )
-        k = k.softmax(dim=-1)
-        context = torch.einsum("bhdn,bhen->bhde", k, v)
-        out = torch.einsum("bhde,bhdn->bhen", context, q)
-        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
-        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
-        return self.to_out(out)

class Residual(torch.nn.Module):
    def __init__(self, fn):
        super(Residual, self).__init__()
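Note: the commented-out rearrange calls inside the deleted old_LinearAttention class describe what the reshape/permute chains compute. A minimal sketch showing that the pure-PyTorch version kept in the class body is equivalent to the einops form from the comments (assumes einops is installed; the shapes are illustrative):

import torch
from einops import rearrange

b, heads, dim_head, h, w = 2, 4, 32, 8, 8
qkv = torch.randn(b, 3 * heads * dim_head, h, w)

# einops form, as written in the removed comment.
q1, k1, v1 = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3)

# Pure-PyTorch form, as used in the deleted class.
q2, k2, v2 = (
    qkv.reshape(b, 3, heads, dim_head, h, w)
    .permute(1, 0, 2, 3, 4, 5)
    .reshape(3, b, heads, dim_head, -1)
)

# Both split the channel dimension into (qkv, heads, dim_head) and flatten the
# spatial dimensions, so the two results are identical.
assert torch.equal(q1, q2) and torch.equal(k1, k2) and torch.equal(v1, v2)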


@@ -9,7 +9,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample


@@ -26,7 +26,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding