Fix device id
This commit is contained in:
parent
e4b4a9c4ac
commit
f586f4973a
|
@ -33,7 +33,7 @@ has_xpu = check_for_xpu()
|
|||
# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
|
||||
# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
|
||||
# which is the best trade-off between VRAM usage and performance.
|
||||
ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
|
||||
ARC_SINGLE_ALLOCATION_LIMIT = {}
|
||||
orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
|
||||
def torch_xpu_scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
|
||||
|
@ -49,7 +49,10 @@ def torch_xpu_scaled_dot_product_attention(
|
|||
Ev = value.size(-1) # Embedding dimension of the value
|
||||
|
||||
total_batch_size = torch.numel(torch.empty(N))
|
||||
batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size()))
|
||||
device_id = query.device.index
|
||||
if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
|
||||
ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
|
||||
batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
|
||||
|
||||
if total_batch_size <= batch_size_limit:
|
||||
return orig_sdp_attn_func(
|
||||
|
|
Loading…
Reference in New Issue