diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 4c2dc56d4..0269f1f5b 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -486,18 +486,7 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
     k_in = self.to_k(context_k)
     v_in = self.to_v(context_v)
 
-    def _reshape(t):
-        """rearrange(t, 'b n (h d) -> b n h d', h=h).
-        Using torch native operations to avoid overhead as this function is
-        called frequently. (70 times/it for SDXL)
-        """
-        b, n, _ = t.shape # Get the batch size (b) and sequence length (n)
-        d = t.shape[2] // h # Determine the depth per head
-        return t.reshape(b, n, h, d)
-
-    q = _reshape(q_in)
-    k = _reshape(k_in)
-    v = _reshape(v_in)
+    q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in))
 
     del q_in, k_in, v_in
 
@@ -509,7 +498,6 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
 
     out = out.to(dtype)
 
-    # out = rearrange(out, 'b n h d -> b n (h d)', h=h)
     b, n, h, d = out.shape
     out = out.reshape(b, n, h * d)
     return self.to_out(out)
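
For reference, a minimal sketch (not part of the patch) checking that the torch-native reshape used above matches the einops rearrange it stands in for; the batch size, token count, heads, and head dimension are illustrative assumptions, not values taken from the code.

    # Sketch: torch reshape vs. einops rearrange for head split/merge.
    import torch
    from einops import rearrange

    h = 8                              # assumed number of attention heads
    t = torch.randn(2, 77, h * 64)     # assumed (batch, tokens, h * dim_head)

    # Split heads: 'b n (h d) -> b n h d' is a plain row-major reshape.
    split_native = t.reshape(t.shape[0], t.shape[1], h, -1)
    split_einops = rearrange(t, 'b n (h d) -> b n h d', h=h)
    assert torch.equal(split_native, split_einops)

    # Merge heads back: 'b n h d -> b n (h d)' is likewise a plain reshape,
    # which is what the second hunk does with out.reshape(b, n, h * d).
    b, n, h_, d = split_native.shape
    merged_native = split_native.reshape(b, n, h_ * d)
    merged_einops = rearrange(split_einops, 'b n h d -> b n (h d)', h=h)
    assert torch.equal(merged_native, merged_einops)

Because the split only factors the last axis without permuting elements, reshape and rearrange produce identical tensors here; the generator expression keeps that behavior while replacing the per-call _reshape helper with a single line.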