Complete cross attention update
parent c84e333622
commit 3b1b1444d4
@@ -11,7 +11,7 @@ from modules.shared import opts, device, cmd_opts
 from ldm.util import default
 from einops import rearrange
 import ldm.modules.attention
-
+import ldm.modules.diffusionmodules.model
 
 # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
@@ -100,6 +100,76 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
 
     return self.to_out(r2)
 
+def nonlinearity_hijack(x):
+    # swish
+    t = torch.sigmoid(x)
+    x *= t
+    del t
+
+    return x
+
+def cross_attention_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q1 = self.q(h_)
+    k1 = self.k(h_)
+    v = self.v(h_)
+
+    # compute attention
+    b, c, h, w = q1.shape
+
+    q2 = q1.reshape(b, c, h*w)
+    del q1
+
+    q = q2.permute(0, 2, 1)   # b,hw,c
+    del q2
+
+    k = k1.reshape(b, c, h*w) # b,c,hw
+    del k1
+
+    h_ = torch.zeros_like(k, device=q.device)
+
+    stats = torch.cuda.memory_stats(q.device)
+    mem_active = stats['active_bytes.all.current']
+    mem_reserved = stats['reserved_bytes.all.current']
+    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+    mem_free_torch = mem_reserved - mem_active
+    mem_free_total = mem_free_cuda + mem_free_torch
+
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+    mem_required = tensor_size * 2.5
+    steps = 1
+
+    if mem_required > mem_free_total:
+        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+    for i in range(0, q.shape[1], slice_size):
+        end = i + slice_size
+
+        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w2 = w1 * (int(c)**(-0.5))
+        del w1
+        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
+        del w2
+
+        # attend to values
+        v1 = v.reshape(b, c, h*w)
+        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
+        del w3
+
+        h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        del v1, w4
+
+    h2 = h_.reshape(b, c, h, w)
+    del h_
+
+    h3 = self.proj_out(h2)
+    del h2
+
+    h3 += x
+
+    return h3
 
 class StableDiffusionModelHijack:
     ids_lookup = {}
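The hunk above is the core of the update: rather than materializing the full (b, hw, hw) attention matrix for the AttnBlock attention in ldm.modules.diffusionmodules.model, cross_attention_attnblock_forward estimates how large that matrix would be, compares it against free CUDA memory, and splits the query dimension into power-of-two slices when it would not fit. A minimal sketch of that slice-count computation, with the helper name and the example numbers invented for illustration:

    import math

    def attention_slice_steps(batch, tokens, element_size, mem_free_total):
        # bytes for one full (batch, tokens, tokens) attention matrix,
        # with the same 2.5x headroom for softmax temporaries used above
        mem_required = batch * tokens * tokens * element_size * 2.5
        steps = 1
        if mem_required > mem_free_total:
            # round the split count up to the next power of two
            steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
        return steps

    # assumed example: batch 1, a 64x64 latent (4096 tokens), float16 (2 bytes),
    # and only 64 MiB reported free -> the ~84 MB requirement forces 2 slices
    print(attention_slice_steps(1, 64 * 64, 2, 64 * 1024 * 1024))  # 2

Each slice of q then attends to the full k, the result is written into the preallocated h_ buffer, and peak usage scales with tokens*tokens/steps instead of tokens*tokens.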
@@ -175,6 +245,8 @@ class StableDiffusionModelHijack:
 
         if cmd_opts.opt_split_attention:
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+            ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
+            ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
         elif cmd_opts.opt_split_attention_v1:
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
 
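The two added assignments are plain class-level monkey-patches, gated by the parsed command-line switches cmd_opts.opt_split_attention / cmd_opts.opt_split_attention_v1: every existing and future AttnBlock instance picks up the memory-aware forward, and the module-level nonlinearity function is swapped for the in-place swish. A toy sketch of the same pattern, with made-up class and function names:

    class AttnBlock:
        def forward(self, x):
            return "original forward"

    def patched_forward(self, x):
        return "patched forward"

    block = AttnBlock()                  # instance created before the patch
    AttnBlock.forward = patched_forward  # class-level assignment, as in the hijack above
    print(block.forward(None))           # "patched forward" -- old instances are affected too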