Update attention.py
parent 6304662726
commit 650c887c69
@@ -174,23 +174,30 @@ class CrossAttention(nn.Module):
         context = default(context, x)
         k = self.to_k(context)
         v = self.to_v(context)
         del context, x

         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
+        # valid values for steps = 2,4,8,16,32,64
+        # higher steps is slower but less memory usage
+        # at 16 can run 1920x1536 on a 3090, at 64 can run over 1920x1920
+        # speed seems to be impacted more on 30x series cards
+        steps = 16
+        slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+            s1 *= self.scale
+            s2 = s1.softmax(dim=-1)
+            del s1
+            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
+        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+        del r1

-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-
-        out = einsum('b i j, b j d -> b i d', attn, v)
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        return self.to_out(out)
+        return self.to_out(r2)


 class BasicTransformerBlock(nn.Module):
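The change replaces the single full-resolution attention matrix with a loop over slices of the query rows. This is exact rather than an approximation: softmax normalizes each query row over the keys independently, so attending with one block of rows at a time and writing the results into r1 produces the same output while only ever materializing a slice_size x n_keys score matrix. Below is a minimal, self-contained sketch of that equivalence, assuming illustrative tensor shapes and the commit's default of steps = 16; it is not the commit's code.

import torch

def full_attention(q, k, v, scale):
    # materializes the full (n_q x n_k) score matrix at once
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale
    return torch.einsum('b i j, b j d -> b i d', sim.softmax(dim=-1), v)

def sliced_attention(q, k, v, scale, steps=16):
    # same math, but only slice_size rows of scores are alive at a time
    n = q.shape[1]
    slice_size = n // steps if n % steps == 0 else n
    out = torch.zeros(q.shape[0], n, v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, n, slice_size):
        end = i + slice_size
        s = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        out[:, i:end] = torch.einsum('b i j, b j d -> b i d', s.softmax(dim=-1), v)
    return out

if __name__ == '__main__':
    # hypothetical shapes: batch*heads = 2, 64 tokens, head dim 40
    q, k, v = (torch.randn(2, 64, 40) for _ in range(3))
    scale = 40 ** -0.5
    assert torch.allclose(full_attention(q, k, v, scale),
                          sliced_attention(q, k, v, scale), atol=1e-5)

With steps = 16 the peak size of the score matrix shrinks by roughly a factor of 16 while the arithmetic cost stays the same, which matches the in-code comment that higher step counts trade speed for memory. Note that when the query length is not divisible by steps, slice_size falls back to the full length and the loop degenerates to ordinary single-pass attention.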