Use fixed size for sub-quadratic chunking on MPS
Even if this causes chunks to be much smaller, performance isn't significantly impacted. This will usually reduce memory usage, and it should also help with poor performance when free memory is low.
This commit is contained in:
parent
3163d1269a
commit
abfa4ad8bc
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
import math
|
||||
import psutil
|
||||
import platform
|
||||
|
||||
import torch
|
||||
from torch import einsum
|
||||
|
@ -427,7 +428,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||
|
||||
if chunk_threshold is None:
|
||||
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
|
||||
if q.device.type == 'mps':
|
||||
chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
|
||||
else:
|
||||
chunk_threshold_bytes = int(get_available_vram() * 0.7)
|
||||
elif chunk_threshold == 0:
|
||||
chunk_threshold_bytes = None
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue