add xformers attention
This commit is contained in:
parent
2995107fa2
commit
f174fb2922
|
@ -1,7 +1,9 @@
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
import xformers.ops
|
||||||
|
import functorch
|
||||||
|
xformers._is_functorch_available=True
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
|
@ -92,6 +94,41 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||||
|
|
||||||
return self.to_out(r2)
|
return self.to_out(r2)
|
||||||
|
|
||||||
|
def _maybe_init(self, x):
|
||||||
|
"""
|
||||||
|
Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x
|
||||||
|
: B, Head, Length
|
||||||
|
"""
|
||||||
|
if self.attention_op is not None:
|
||||||
|
return
|
||||||
|
_, M, K = x.shape
|
||||||
|
try:
|
||||||
|
self.attention_op = xformers.ops.AttentionOpDispatch(
|
||||||
|
dtype=x.dtype,
|
||||||
|
device=x.device,
|
||||||
|
k=K,
|
||||||
|
attn_bias_type=type(None),
|
||||||
|
has_dropout=False,
|
||||||
|
kv_len=M,
|
||||||
|
q_len=M,
|
||||||
|
).op
|
||||||
|
except NotImplementedError as err:
|
||||||
|
raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
|
||||||
|
|
||||||
|
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
q_in = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k_in = self.to_k(context)
|
||||||
|
v_in = self.to_v(context)
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||||
|
del q_in, k_in, v_in
|
||||||
|
self._maybe_init(q)
|
||||||
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||||
|
|
||||||
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
def cross_attention_attnblock_forward(self, x):
|
def cross_attention_attnblock_forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
|
|
Loading…
Reference in New Issue