Use other MPS optimization for large q.shape[0] * q.shape[1]
Check if q.shape[0] * q.shape[1] is 2**18 or larger and use the lower memory usage MPS optimization if it is. This should prevent most crashes that were occurring at certain resolutions (e.g. 1024x1024, 2048x512, 512x2048). Also included is a change to check slice_size and prevent it from being divisible by 4096 which also results in a crash. Otherwise a crash can occur at 1024x512 or 512x1024 resolution.
This commit is contained in:
parent
685f9631b5
commit
35b1775b32
|
@ -127,7 +127,7 @@ def check_for_psutil():
|
||||||
|
|
||||||
invokeAI_mps_available = check_for_psutil()
|
invokeAI_mps_available = check_for_psutil()
|
||||||
|
|
||||||
# -- Taken from https://github.com/invoke-ai/InvokeAI --
|
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
||||||
if invokeAI_mps_available:
|
if invokeAI_mps_available:
|
||||||
import psutil
|
import psutil
|
||||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||||
|
@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def einsum_op_mps_v1(q, k, v):
|
def einsum_op_mps_v1(q, k, v):
|
||||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
|
||||||
return einsum_op_compvis(q, k, v)
|
return einsum_op_compvis(q, k, v)
|
||||||
else:
|
else:
|
||||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||||
|
if slice_size % 4096 == 0:
|
||||||
|
slice_size -= 1
|
||||||
return einsum_op_slice_1(q, k, v, slice_size)
|
return einsum_op_slice_1(q, k, v, slice_size)
|
||||||
|
|
||||||
def einsum_op_mps_v2(q, k, v):
|
def einsum_op_mps_v2(q, k, v):
|
||||||
if mem_total_gb > 8 and q.shape[1] <= 4096:
|
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
|
||||||
return einsum_op_compvis(q, k, v)
|
return einsum_op_compvis(q, k, v)
|
||||||
else:
|
else:
|
||||||
return einsum_op_slice_0(q, k, v, 1)
|
return einsum_op_slice_0(q, k, v, 1)
|
||||||
|
@ -188,7 +190,7 @@ def einsum_op(q, k, v):
|
||||||
return einsum_op_cuda(q, k, v)
|
return einsum_op_cuda(q, k, v)
|
||||||
|
|
||||||
if q.device.type == 'mps':
|
if q.device.type == 'mps':
|
||||||
if mem_total_gb >= 32:
|
if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
|
||||||
return einsum_op_mps_v1(q, k, v)
|
return einsum_op_mps_v1(q, k, v)
|
||||||
return einsum_op_mps_v2(q, k, v)
|
return einsum_op_mps_v2(q, k, v)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue