diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfcd..da1b76e1c 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,9 @@ import math import torch from torch import einsum - +import xformers.ops +import functorch +xformers._is_functorch_available=True from ldm.util import default from einops import rearrange @@ -92,6 +94,41 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) +def _maybe_init(self, x): + """ + Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x + : B, Head, Length + """ + if self.attention_op is not None: + return + _, M, K = x.shape + try: + self.attention_op = xformers.ops.AttentionOpDispatch( + dtype=x.dtype, + device=x.device, + k=K, + attn_bias_type=type(None), + has_dropout=False, + kv_len=M, + q_len=M, + ).op + except NotImplementedError as err: + raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") + +def xformers_attention_forward(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) + v_in = self.to_v(context) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + self._maybe_init(q) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_)