Complete cross attention update
This commit is contained in:
parent c84e333622
commit 3b1b1444d4
@@ -11,7 +11,7 @@ from modules.shared import opts, device, cmd_opts
 from ldm.util import default
 from einops import rearrange
 import ldm.modules.attention
-
+import ldm.modules.diffusionmodules.model
 
 
 # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
@@ -100,6 +100,76 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
 
     return self.to_out(r2)
 
+def nonlinearity_hijack(x):
+    # swish
+    t = torch.sigmoid(x)
+    x *= t
+    del t
+
+    return x
+
+def cross_attention_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q1 = self.q(h_)
+    k1 = self.k(h_)
+    v = self.v(h_)
+
+    # compute attention
+    b, c, h, w = q1.shape
+
+    q2 = q1.reshape(b, c, h*w)
+    del q1
+
+    q = q2.permute(0, 2, 1)   # b,hw,c
+    del q2
+
+    k = k1.reshape(b, c, h*w) # b,c,hw
+    del k1
+
+    h_ = torch.zeros_like(k, device=q.device)
+
+    stats = torch.cuda.memory_stats(q.device)
+    mem_active = stats['active_bytes.all.current']
+    mem_reserved = stats['reserved_bytes.all.current']
+    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+    mem_free_torch = mem_reserved - mem_active
+    mem_free_total = mem_free_cuda + mem_free_torch
+
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+    mem_required = tensor_size * 2.5
+    steps = 1
+
+    if mem_required > mem_free_total:
+        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+    for i in range(0, q.shape[1], slice_size):
+        end = i + slice_size
+
+        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w2 = w1 * (int(c)**(-0.5))
+        del w1
+        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
+        del w2
+
+        # attend to values
+        v1 = v.reshape(b, c, h*w)
+        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
+        del w3
+
+        h_[:, :, i:end] = torch.bmm(v1, w4)     # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        del v1, w4
+
+    h2 = h_.reshape(b, c, h, w)
+    del h_
+
+    h3 = self.proj_out(h2)
+    del h2
+
+    h3 += x
+
+    return h3
 
 class StableDiffusionModelHijack:
     ids_lookup = {}
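The heart of the new cross_attention_attnblock_forward is that attention rows are independent: each slice of query positions can be attended on its own and written into its column range of h_, so the result matches the unsliced computation while only ever materializing a b,slice,hw weight tensor instead of the full b,hw,hw one. (nonlinearity_hijack is ordinary swish, x * sigmoid(x), rewritten to run in-place.) Below is a small CPU sketch of that equivalence; the toy shapes and tolerance are illustrative, not part of the commit.

# Illustration only (toy shapes, not from the commit): slicing q row-wise and
# writing each chunk into h_ reproduces the full attention product exactly.
import torch

b, c, hw = 2, 4, 16            # batch, channels, height*width
q = torch.randn(b, hw, c)      # b,hw,c - one query row per spatial position
k = torch.randn(b, c, hw)      # b,c,hw
v1 = torch.randn(b, c, hw)     # b,c,hw

# full-size reference: same operations as the diff, without the loop
w_full = torch.nn.functional.softmax(torch.bmm(q, k) * (int(c)**(-0.5)), dim=2)
h_full = torch.bmm(v1, w_full.permute(0, 2, 1))

# sliced version, mirroring the for-loop in cross_attention_attnblock_forward
h_ = torch.zeros_like(k)
slice_size = 4
for i in range(0, hw, slice_size):
    end = i + slice_size
    w = torch.nn.functional.softmax(torch.bmm(q[:, i:end], k) * (int(c)**(-0.5)), dim=2)
    h_[:, :, i:end] = torch.bmm(v1, w.permute(0, 2, 1))

assert torch.allclose(h_full, h_, atol=1e-6)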
@@ -175,6 +245,8 @@ class StableDiffusionModelHijack:
 
         if cmd_opts.opt_split_attention:
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+            ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
+            ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
         elif cmd_opts.opt_split_attention_v1:
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
 
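This last hunk wires the new functions in by class-level monkey-patching: assigning to CrossAttention.forward or AttnBlock.forward rebinds the method on the class itself, so every instance, including ones already constructed inside the loaded model, picks up the new implementation. A minimal sketch with a toy class (a stand-in, not the real ldm.modules.diffusionmodules.model.AttnBlock):

# Toy illustration of the class-level monkey-patching the hijack relies on.
class AttnBlock:
    def forward(self, x):
        return x

def patched_forward(self, x):
    return x + 1

block = AttnBlock()                 # instantiated before the patch
AttnBlock.forward = patched_forward # rebind on the class, not the instance
assert block.forward(1) == 2        # existing instances use the patched method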
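The slicing granularity in cross_attention_attnblock_forward is chosen from live VRAM numbers: mem_required estimates peak usage at 2.5x the size of one b,hw,hw attention-weight tensor, and steps rounds mem_required / mem_free_total up to the next power of two. A worked example with assumed numbers (illustrative, not from the commit):

# Worked example of the steps heuristic; the memory figures are made up.
import math

mem_free_total = 2 * 1024**3         # pretend 2 GiB of VRAM is available
tensor_size = 8 * 8192 * 8192 * 2    # b * hw * hw * element_size (fp16) = 1 GiB
mem_required = tensor_size * 2.5     # same 2.5x safety factor as the diff

steps = 1
if mem_required > mem_free_total:
    steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

print(steps)  # 2.5 GiB needed vs 2 GiB free -> steps = 2, two half-size slices

Note the fallback in the slice_size expression: when q.shape[1] is not divisible by steps, it reverts to the full q.shape[1], i.e. a single unsliced pass.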