2022-10-02 06:03:39 -06:00
|
|
|
import math
|
2022-10-08 08:02:18 -06:00
|
|
|
import sys
|
|
|
|
import traceback
|
2022-10-10 21:55:48 -06:00
|
|
|
import importlib
|
2022-10-08 08:02:18 -06:00
|
|
|
|
2022-10-02 06:03:39 -06:00
|
|
|
import torch
|
|
|
|
from torch import einsum
|
2022-10-08 07:33:39 -06:00
|
|
|
|
2022-10-02 06:03:39 -06:00
|
|
|
from ldm.util import default
|
|
|
|
from einops import rearrange
|
|
|
|
|
2022-10-11 05:53:02 -06:00
|
|
|
from modules import shared
|
2022-10-11 06:51:22 -06:00
|
|
|
from modules.hypernetworks import hypernetwork
|
2022-10-11 02:09:51 -06:00
|
|
|
|
2022-10-07 01:17:52 -06:00
|
|
|
|
2022-10-08 10:25:10 -06:00
|
|
|
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    # Only try xformers when the user asked for it. A failed import is reported
    # on stderr (with the traceback) but is not fatal; in that case
    # shared.xformers_available is simply not set here.
    try:
        import xformers.ops
    except Exception:
        print("Cannot import xformers", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
    else:
        shared.xformers_available = True
|
|
|
|
|
2022-10-02 06:03:39 -06:00
|
|
|
|
|
|
|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
|
|
|
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    """Cross-attention forward pass that processes two (batch*head) rows at a time.

    Intended as a replacement ``forward`` for an attention module: it reads
    ``self.heads``, ``self.scale``, ``self.to_q``/``self.to_k``/``self.to_v`` and
    ``self.to_out``. ``context`` defaults to ``x`` (self-attention) and is passed
    through the loaded hypernetwork to obtain separate key/value contexts.
    ``mask`` is accepted for signature compatibility but is not used.

    Peak memory is bounded by materializing the attention-score matrix for only
    two entries of the flattened (batch * heads) axis per loop iteration.
    """
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    # Drop references early so these tensors can be freed before the big einsums.
    del context, context_k, context_v, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    # Allocate the result in q's dtype. Without an explicit dtype, torch.zeros
    # uses the default dtype (normally float32), which doubles the buffer size
    # under half precision and hands a float32 tensor to self.to_out.
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], 2):
        end = i + 2
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2
    del q, k, v

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
|
|
|
|
|
|
|
|
|
2022-10-11 02:09:51 -06:00
|
|
|
# taken from https://github.com/Doggettx/stable-diffusion and modified
|
2022-10-02 06:03:39 -06:00
|
|
|
def split_cross_attention_forward(self, x, context=None, mask=None):
    """Cross-attention forward pass that slices over query tokens to fit free CUDA memory.

    Intended as a replacement ``forward`` for an attention module: it reads
    ``self.heads``, ``self.scale``, ``self.to_q``/``self.to_k``/``self.to_v`` and
    ``self.to_out``. ``context`` defaults to ``x`` (self-attention) and is passed
    through the loaded hypernetwork. ``mask`` is accepted but not used.

    The memory needed for the full (batch*heads, n_q, n_k) score tensor, times a
    headroom factor, is compared with the free memory on ``q``'s CUDA device.
    If it does not fit, the query tokens are processed in ``steps`` slices, where
    ``steps`` is a power of two.

    Raises:
        RuntimeError: if more than 64 slices would be required.
    """
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    # Scale the keys once up front instead of scaling every score slice.
    k_in *= self.scale

    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    # Usable memory = free memory reported by the driver + memory PyTorch has
    # reserved but is not actively using.
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    # Query the device the tensors actually live on, not whichever device is current.
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    # Headroom multiplier for the intermediate score/softmax tensors.
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        # mem_required / 64 is the requirement when split into the maximum 64 slices.
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

    # Ceiling division keeps every slice at most ceil(n / steps) tokens even when
    # n is not divisible by steps. Falling back to the full length there would
    # silently disable slicing. The clamp avoids a zero range() step for n == 0.
    slice_size = max(1, -(-q.shape[1] // steps))
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

        s2 = s1.softmax(dim=-1, dtype=q.dtype)
        del s1

        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
        del s2

    del q, k, v

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
|
|
|
|
|
2022-10-10 20:48:54 -06:00
|
|
|
|
2022-10-10 21:55:48 -06:00
|
|
|
def check_for_psutil():
    """Return True if the ``psutil`` package can be found on the import path.

    Uses ``importlib.util.find_spec``, which locates a top-level module without
    importing it. ``importlib.util`` is imported explicitly because a plain
    ``import importlib`` does not guarantee the ``util`` submodule is loaded.
    """
    import importlib.util

    try:
        spec = importlib.util.find_spec('psutil')
    except ModuleNotFoundError:
        return False
    return spec is not None
|
|
|
|
|
|
|
|
invokeAI_mps_available = check_for_psutil()

# -- Taken from https://github.com/invoke-ai/InvokeAI --

if invokeAI_mps_available:
    import psutil

    # Total physical memory, floored to whole GiB. Note: mem_total_gb is only
    # defined when psutil is available, yet the MPS code paths below read it.
    mem_total_gb = psutil.virtual_memory().total // 2**30
|
2022-10-10 20:48:54 -06:00
|
|
|
|
|
|
|
def einsum_op_compvis(q, k, v):
    """Unsliced attention: softmax(q @ k^T) @ v, batched over the leading axis.

    No scaling is applied here; any scale factor must already be folded into
    q or k. The softmax is taken over the key axis in the scores' own dtype.
    """
    scores = einsum('b i d, b j d -> b i j', q, k)
    probs = scores.softmax(dim=-1, dtype=scores.dtype)
    return einsum('b i j, b j d -> b i d', probs, v)
|
|
|
|
|
|
|
|
def einsum_op_slice_0(q, k, v, slice_size):
    """Attention computed in chunks of ``slice_size`` along axis 0 of q, k and v.

    Each chunk is handled by ``einsum_op_compvis`` and written into a
    preallocated output with q's dtype and device.
    """
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for start in range(0, q.shape[0], slice_size):
        rows = slice(start, start + slice_size)
        out[rows] = einsum_op_compvis(q[rows], k[rows], v[rows])
    return out
|
|
|
|
|
|
|
|
def einsum_op_slice_1(q, k, v, slice_size):
    """Attention computed in chunks of ``slice_size`` query positions (axis 1 of q).

    Every chunk attends over the full k and v via ``einsum_op_compvis``; results
    go into a preallocated output with q's dtype and device.
    """
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for start in range(0, q.shape[1], slice_size):
        tokens = slice(start, start + slice_size)
        out[:, tokens] = einsum_op_compvis(q[:, tokens], k, v)
    return out
|
|
|
|
|
|
|
|
def einsum_op_mps_v1(q, k, v):
    """MPS attention strategy: unsliced up to 4096 query tokens, otherwise sliced over tokens.

    Beyond 4096 query tokens the slice length is ``2**30 / (q.shape[0] * q.shape[1])``.
    For self-attention (k.shape[1] == q.shape[1]) this keeps each slice's score
    tensor near 2**30 elements.
    """
    if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    # Clamp to 1: for an extremely large q the formula floors to 0, and a zero
    # step would make range() raise inside einsum_op_slice_1.
    slice_size = max(math.floor(2**30 / (q.shape[0] * q.shape[1])), 1)
    return einsum_op_slice_1(q, k, v, slice_size)
|
|
|
|
|
|
|
|
def einsum_op_mps_v2(q, k, v):
    """MPS attention strategy for machines with less RAM.

    Runs unsliced only when the machine has more than 8 GB of RAM
    (``mem_total_gb``) and there are at most 4096 query tokens. Otherwise it
    processes one entry of axis 0 at a time.
    """
    fits_unsliced = mem_total_gb > 8 and q.shape[1] <= 4096
    if not fits_unsliced:
        return einsum_op_slice_0(q, k, v, 1)
    return einsum_op_compvis(q, k, v)
|
|
|
|
|
|
|
|
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    """Attention sliced so the score tensor stays under ``max_tensor_mb`` MiB.

    The full score tensor has ``q.shape[0] * q.shape[1] * k.shape[1]`` elements.
    If its size (floored to MiB) fits the budget, attention runs in one pass.
    Otherwise the work is split into ``div`` parts, where ``div`` is a power of two.
    The split is along axis 0 when that axis has at least ``div`` entries, and
    along the query-token axis otherwise, never with a slice smaller than one.

    A non-positive budget (e.g. a caller that found no free memory) falls back to
    the finest slicing, one query token at a time. Previously it raised
    ZeroDivisionError; the finest slicing is also the limit of the ``div``
    formula as the budget approaches zero.
    """
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    if max_tensor_mb <= 0:
        return einsum_op_slice_1(q, k, v, 1)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
|
|
|
|
2022-10-11 01:32:11 -06:00
|
|
|
def einsum_op_cuda(q, k, v):
    """CUDA attention whose slicing budget comes from the memory free on q's device.

    Usable memory is the driver-reported free memory plus what PyTorch has
    reserved but is not actively using. That total is divided by 3.3 and
    converted to MiB to form the budget for ``einsum_op_tensor_mem``.
    """
    stats = torch.cuda.memory_stats(q.device)
    reusable_in_torch = stats['reserved_bytes.all.current'] - stats['active_bytes.all.current']
    free_on_device, _total = torch.cuda.mem_get_info(q.device)
    available = free_on_device + reusable_in_torch

    # Divide factor of safety as there's copying and fragmentation
    budget_mb = available / 3.3 / (1 << 20)
    return einsum_op_tensor_mem(q, k, v, budget_mb)
|
2022-10-11 01:32:11 -06:00
|
|
|
|
2022-10-10 20:48:54 -06:00
|
|
|
def einsum_op(q, k, v):
    """Dispatch attention to the strategy matching q's device type.

    - ``cuda``: budget derived from free CUDA memory (``einsum_op_cuda``).
    - ``mps``: ``einsum_op_mps_v1`` with at least 32 GB of RAM, otherwise
      ``einsum_op_mps_v2``; both read the module-level ``mem_total_gb``.
    - anything else: memory-capped slicing with a fixed 32 MiB budget.
    """
    device_type = q.device.type
    if device_type == 'cuda':
        return einsum_op_cuda(q, k, v)
    if device_type == 'mps':
        mps_op = einsum_op_mps_v1 if mem_total_gb >= 32 else einsum_op_mps_v2
        return mps_op(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)
|
|
|
|
|
|
|
|
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
    """Cross-attention forward pass that delegates the attention math to ``einsum_op``.

    Reads ``self.heads``, ``self.scale``, ``self.to_q``/``self.to_k``/``self.to_v``
    and ``self.to_out``. ``context`` defaults to ``x`` and goes through the loaded
    hypernetwork. ``mask`` is accepted but unused.
    """
    heads = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
    # Keys are pre-scaled so einsum_op can apply a plain softmax.
    k = self.to_k(context_k) * self.scale
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    def to_heads(t):
        # Fold the head axis into the batch axis: (b, n, h*d) -> (b*h, n, d).
        return rearrange(t, 'b n (h d) -> (b h) n d', h=heads)

    q, k, v = to_heads(q), to_heads(k), to_heads(v)
    r = einsum_op(q, k, v)
    merged = rearrange(r, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(merged)
|
|
|
|
|
2022-10-10 21:55:48 -06:00
|
|
|
# -- End of code from https://github.com/invoke-ai/InvokeAI --
|
2022-10-10 20:48:54 -06:00
|
|
|
|
2022-10-06 20:21:49 -06:00
|
|
|
def xformers_attention_forward(self, x, context=None, mask=None):
    """Cross-attention forward pass using ``xformers.ops.memory_efficient_attention``.

    Reads ``self.heads``, ``self.to_q``/``self.to_k``/``self.to_v`` and ``self.to_out``.
    ``context`` defaults to ``x`` and goes through the loaded hypernetwork.
    ``mask`` is accepted but unused; attention runs with ``attn_bias=None``.
    """
    heads = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    # Split heads into their own axis: (b, n, h*d) -> (b, n, h, d).
    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=heads) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

    merged = rearrange(out, 'b n h d -> b n (h d)', h=heads)
    return self.to_out(merged)
|
|
|
|
|
2022-10-02 06:03:39 -06:00
|
|
|
def cross_attention_attnblock_forward(self, x):
    """Spatial self-attention block forward, sliced over query positions to fit CUDA memory.

    Reads ``self.norm``, ``self.q``, ``self.k``, ``self.v`` and ``self.proj_out``;
    ``self.q(...)`` must return a 4-D tensor (it is unpacked as b, c, h, w).
    Returns ``x + proj_out(attention)``. CUDA memory statistics for ``q``'s device
    decide how many slices of query positions to process, so ``x`` is presumably
    expected to live on a CUDA device.
    """
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    # Usable memory = driver-reported free memory + PyTorch reserved-but-unused memory.
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    # Query the device the tensors live on, not whichever device is current.
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    # Ceiling division keeps slicing active when hw is not divisible by steps.
    # Falling back to the full length there would compute everything at once.
    slice_size = max(1, -(-q.shape[1] // steps))

    # Loop-invariant values: v's (b, c, hw) view and the 1/sqrt(c) scale.
    v1 = v.reshape(b, c, h*w)
    scale = int(c)**(-0.5)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * scale
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del w4
    del v1

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3
|
2022-10-08 02:55:02 -06:00
|
|
|
|
2022-10-17 13:18:59 -06:00
|
|
|
def xformers_attnblock_forward(self, x):
    """Spatial self-attention block forward using ``xformers.ops.memory_efficient_attention``.

    Reads ``self.norm``, ``self.q``, ``self.k``, ``self.v`` and ``self.proj_out``
    and returns ``x + proj_out(attention)``. If anything in the xformers path
    raises NotImplementedError, the computation is redone with
    ``cross_attention_attnblock_forward``.
    """
    try:
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)
        _, _, h, _ = q.shape

        # Flatten spatial positions into a token axis, (b, c, h, w) -> (b, h*w, c),
        # and make each tensor contiguous before handing it to xformers.
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c').contiguous() for t in (q, k, v))

        attended = xformers.ops.memory_efficient_attention(q, k, v)
        attended = rearrange(attended, 'b (h w) c -> b c h w', h=h)
        return x + self.proj_out(attended)
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)
|