Merge pull request #14559 from Nuullll/ipex-sdpa-fix

[IPEX] Fix SDPA attn_mask dtype
This commit is contained in:
AUTOMATIC1111 2024-01-06 13:14:18 +03:00 committed by GitHub
commit b00b429477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions

View File

@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention(
# cast to same dtype first
key = key.to(query.dtype)
value = value.to(query.dtype)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(query.dtype)
N = query.shape[:-2] # Batch size
L = query.size(-2) # Target sequence length