Add cross-attention optimization from InvokeAI
* Add cross-attention optimization from InvokeAI (~30% speed improvement on MPS) * Add command line option for it * Make it default when CUDA is unavailable
This commit is contained in:
parent
f7e86aa420
commit
c0484f1b98
|
@ -30,8 +30,11 @@ def apply_optimizations():
|
|||
elif cmd_opts.opt_split_attention_v1:
|
||||
print("Applying v1 cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
|
||||
print("Applying cross attention optimization (InvokeAI).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||
print("Applying cross attention optimization.")
|
||||
print("Applying cross attention optimization (Doggettx).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import math
|
||||
import sys
|
||||
import traceback
|
||||
import psutil
|
||||
|
||||
import torch
|
||||
from torch import einsum
|
||||
|
@ -116,6 +117,84 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||
|
||||
return self.to_out(r2)
|
||||
|
||||
# -- From https://github.com/invoke-ai/InvokeAI/blob/main/ldm/modules/attention.py (with hypernetworks support added) --
|
||||
|
||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
def einsum_op_compvis(q, k, v):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
return einsum('b i j, b j d -> b i d', s, v)
|
||||
|
||||
def einsum_op_slice_0(q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||
return r
|
||||
|
||||
def einsum_op_slice_1(q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return einsum_op_slice_1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(q, k, v):
|
||||
if mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
return einsum_op_slice_0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return einsum_op_compvis(q, k, v)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op(q, k, v):
|
||||
if q.device.type == 'mps':
|
||||
if mem_total_gb >= 32:
|
||||
return einsum_op_mps_v1(q, k, v)
|
||||
return einsum_op_mps_v2(q, k, v)
|
||||
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||
|
||||
if hypernetwork_layers is not None:
|
||||
k = self.to_k(hypernetwork_layers[0](context)) * self.scale
|
||||
v = self.to_v(hypernetwork_layers[1](context))
|
||||
else:
|
||||
k = self.to_k(context) * self.scale
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = einsum_op(q, k, v)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
|
||||
# -- End of code from https://github.com/invoke-ai/InvokeAI/blob/main/ldm/modules/attention.py --
|
||||
|
||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
|
|
|
@ -50,9 +50,10 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
|
|||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
|
|
Loading…
Reference in New Issue