Update attention.py

Christopher Gammage 2022-09-25 16:26:55 -07:00
parent 6304662726
commit 650c887c69
1 changed file with 19 additions and 12 deletions


@@ -174,23 +174,30 @@ class CrossAttention(nn.Module):
         context = default(context, x)
         k = self.to_k(context)
         v = self.to_v(context)
+        del context, x
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
+        # valid values for steps = 2,4,8,16,32,64
+        # higher steps is slower but less memory usage
+        # at 16 can run 1920x1536 on a 3090, at 64 can run over 1920x1920
+        # speed seems to be impacted more on 30x series cards
+        steps = 16
+        slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+            s1 *= self.scale
+            s2 = s1.softmax(dim=-1)
+            del s1
+            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
+        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+        del r1
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-        out = einsum('b i j, b j d -> b i d', attn, v)
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        return self.to_out(out)
+        return self.to_out(r2)
 class BasicTransformerBlock(nn.Module):
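
For context, the change replaces the single full-size attention matrix (one seq_len x seq_len score matrix per head) with per-slice softmax over blocks of query rows, so only a slice_size x seq_len score matrix is held in memory at a time. Below is a minimal standalone sketch of that slicing idea in plain PyTorch (torch.bmm instead of einops/einsum; the shapes, function names, and steps value are illustrative assumptions, not part of this commit), checking that the sliced path matches full attention.

# Sketch of query-sliced attention (illustrative only; names, shapes and
# steps are assumptions, not values taken from this repository).
import torch

def sliced_attention(q, k, v, scale, steps=16):
    # q, k, v: (batch*heads, seq_len, head_dim)
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    # Fall back to a single slice when seq_len is not divisible by steps,
    # mirroring the slice_size logic in the commit.
    slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        # Scores and softmax for this block of query rows only:
        # a (slice_size, seq_len) matrix instead of (seq_len, seq_len).
        attn = torch.softmax(torch.bmm(q[:, i:end], k.transpose(1, 2)) * scale, dim=-1)
        out[:, i:end] = torch.bmm(attn, v)
    return out

def full_attention(q, k, v, scale):
    # Reference path: materialize the full attention matrix at once.
    attn = torch.softmax(torch.bmm(q, k.transpose(1, 2)) * scale, dim=-1)
    return torch.bmm(attn, v)

if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(8, 64, 40)  # (batch*heads, seq_len, head_dim) -- assumed sizes
    k = torch.randn(8, 64, 40)
    v = torch.randn(8, 64, 40)
    scale = q.shape[-1] ** -0.5
    assert torch.allclose(sliced_attention(q, k, v, scale),
                          full_attention(q, k, v, scale), atol=1e-5)

Because softmax is applied row-wise, computing it over a slice of query rows is mathematically identical to computing it over the full matrix, which is why the sliced and full paths agree and only peak memory changes.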