integrated edits as recommended in the PR #15804

This commit is contained in:
AUTOMATIC1111 2024-06-08 09:05:35 +03:00
parent de7f5cdc62
commit 0769aa318a
1 changed files with 1 additions and 13 deletions

View File

@ -486,18 +486,7 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
k_in = self.to_k(context_k) k_in = self.to_k(context_k)
v_in = self.to_v(context_v) v_in = self.to_v(context_v)
def _reshape(t): q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in))
"""rearrange(t, 'b n (h d) -> b n h d', h=h).
Using torch native operations to avoid overhead as this function is
called frequently. (70 times/it for SDXL)
"""
b, n, _ = t.shape # Get the batch size (b) and sequence length (n)
d = t.shape[2] // h # Determine the depth per head
return t.reshape(b, n, h, d)
q = _reshape(q_in)
k = _reshape(k_in)
v = _reshape(v_in)
del q_in, k_in, v_in del q_in, k_in, v_in
@ -509,7 +498,6 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
out = out.to(dtype) out = out.to(dtype)
# out = rearrange(out, 'b n h d -> b n (h d)', h=h)
b, n, h, d = out.shape b, n, h, d = out.shape
out = out.reshape(b, n, h * d) out = out.reshape(b, n, h * d)
return self.to_out(out) return self.to_out(out)