some clean up
This commit is contained in:
parent
31d1f3c8c0
commit
c482d7bd4f
|
@ -15,22 +15,12 @@
|
|||
|
||||
# helpers functions
|
||||
|
||||
import copy
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from torch.optim import Adam
|
||||
from torch.utils import data
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention2d import AttentionBlock
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample, Upsample
|
||||
|
||||
|
@ -219,11 +209,7 @@ class UNetModel(ModelMixin, ConfigMixin):
|
|||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
# self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
|
||||
# h = self.down[i_level].attn_2[i_block](h)
|
||||
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
# print("Result", (h - h_2).abs().sum())
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention2d import AttentionBlock
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample, Upsample
|
||||
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import torch
|
||||
from numpy import pad
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention2d import LinearAttention
|
||||
from .attention import LinearAttention
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample, Upsample
|
||||
|
||||
|
@ -55,32 +54,6 @@ class ResnetBlock(torch.nn.Module):
|
|||
return output
|
||||
|
||||
|
||||
class old_LinearAttention(torch.nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super(LinearAttention, self).__init__()
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
||||
q, k, v = (
|
||||
qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
|
||||
.permute(1, 0, 2, 3, 4, 5)
|
||||
.reshape(3, b, self.heads, self.dim_head, -1)
|
||||
)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
|
||||
out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Residual(torch.nn.Module):
|
||||
def __init__(self, fn):
|
||||
super(Residual, self).__init__()
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch.nn.functional as F
|
|||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention2d import AttentionBlock
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample, Upsample
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ import torch.nn.functional as F
|
|||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention2d import AttentionBlock
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue