From abfa4ad8bc995dcaf832c07a7cf75b6e295a8ca9 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 8 May 2023 18:16:01 -0400 Subject: [PATCH] Use fixed size for sub-quadratic chunking on MPS Even if this causes chunks to be much smaller, performance isn't significantly impacted. This will usually reduce memory usage but should also help with poor performance when free memory is low. --- modules/sd_hijack_optimizations.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 0e810eec8..b3e712707 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,6 +1,7 @@ from __future__ import annotations import math import psutil +import platform import torch from torch import einsum @@ -427,7 +428,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens if chunk_threshold is None: - chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + if q.device.type == 'mps': + chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token) + else: + chunk_threshold_bytes = int(get_available_vram() * 0.7) elif chunk_threshold == 0: chunk_threshold_bytes = None else: