Merge pull request #15804 from huchenlei/rearrange_fix
[Performance 2/6] Replace einops.rearrange with torch native ops
This commit is contained in:
commit
de7f5cdc62
|
@ -486,7 +486,19 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||||
k_in = self.to_k(context_k)
|
k_in = self.to_k(context_k)
|
||||||
v_in = self.to_v(context_v)
|
v_in = self.to_v(context_v)
|
||||||
|
|
||||||
q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
|
def _reshape(t):
|
||||||
|
"""rearrange(t, 'b n (h d) -> b n h d', h=h).
|
||||||
|
Using torch native operations to avoid overhead as this function is
|
||||||
|
called frequently. (70 times/it for SDXL)
|
||||||
|
"""
|
||||||
|
b, n, _ = t.shape # Get the batch size (b) and sequence length (n)
|
||||||
|
d = t.shape[2] // h # Determine the depth per head
|
||||||
|
return t.reshape(b, n, h, d)
|
||||||
|
|
||||||
|
q = _reshape(q_in)
|
||||||
|
k = _reshape(k_in)
|
||||||
|
v = _reshape(v_in)
|
||||||
|
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
|
@ -497,7 +509,9 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||||
|
|
||||||
out = out.to(dtype)
|
out = out.to(dtype)
|
||||||
|
|
||||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
# out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||||
|
b, n, h, d = out.shape
|
||||||
|
out = out.reshape(b, n, h * d)
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue