diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 0ebdd5964..f7687a66c 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -33,7 +33,7 @@ has_xpu = check_for_xpu() # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, # which is the best trade-off between VRAM usage and performance. -ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +ARC_SINGLE_ALLOCATION_LIMIT = {} orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention def torch_xpu_scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs @@ -49,7 +49,10 @@ def torch_xpu_scaled_dot_product_attention( Ev = value.size(-1) # Embedding dimension of the value total_batch_size = torch.numel(torch.empty(N)) - batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + device_id = query.device.index + if device_id not in ARC_SINGLE_ALLOCATION_LIMIT: + ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size())) if total_batch_size <= batch_size_limit: return orig_sdp_attn_func(