Fix device id

Nuullll 2023-12-18 19:44:52 +08:00
parent e4b4a9c4ac
commit f586f4973a
1 changed file with 5 additions and 2 deletions


@@ -33,7 +33,7 @@ has_xpu = check_for_xpu()
 # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
 # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
 # which is the best trade-off between VRAM usage and performance.
-ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+ARC_SINGLE_ALLOCATION_LIMIT = {}
 orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
 def torch_xpu_scaled_dot_product_attention(
     query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
@@ -49,7 +49,10 @@ def torch_xpu_scaled_dot_product_attention(
     Ev = value.size(-1) # Embedding dimension of the value
     total_batch_size = torch.numel(torch.empty(N))
-    batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size()))
+    device_id = query.device.index
+    if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
+        ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+    batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
     if total_batch_size <= batch_size_limit:
         return orig_sdp_attn_func(
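
The change above replaces a single module-level limit, computed once from shared.cmd_opts.device_id at import time, with a per-device cache keyed by the device the query tensor actually lives on (query.device.index). The pattern can be read in isolation roughly as the sketch below. This is a minimal illustration, not code from the repository: the helper names get_arc_allocation_limit and batch_size_limit_for are hypothetical, and it assumes a torch build with XPU (Intel GPU) support, matching the torch.xpu call used in the diff.

import torch

# Per-device cache: XPU device index -> single-allocation limit in bytes.
ARC_SINGLE_ALLOCATION_LIMIT = {}

def get_arc_allocation_limit(device_id):
    # Lazily compute TOTAL_VRAM // 8, capped at 4 GiB, for the device the
    # tensors actually live on, instead of reading a global command-line
    # device id once at import time -- the point of this commit.
    if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
        total_vram = torch.xpu.get_device_properties(device_id).total_memory
        ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(total_vram // 8, 4 * 1024 * 1024 * 1024)
    return ARC_SINGLE_ALLOCATION_LIMIT[device_id]

def batch_size_limit_for(query, L, S):
    # One attention score matrix is L x S elements of query.element_size()
    # bytes; allow at most as many batch elements per chunk as fit under
    # the per-device allocation limit, but never fewer than 1.
    limit = get_arc_allocation_limit(query.device.index)
    return max(1, limit // (L * S * query.element_size()))

Computing the limit lazily per device index means a multi-GPU setup no longer sizes its SDPA chunks from whichever device happened to be named on the command line.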