[IPEX] Fix torch.Generator hijack
This commit is contained in:
parent
b00b429477
commit
73786c047f
|
@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention(
|
||||||
return torch.reshape(result, (*N, L, Ev))
|
return torch.reshape(result, (*N, L, Ev))
|
||||||
|
|
||||||
|
|
||||||
|
def is_xpu_device(device: str | torch.device = None):
|
||||||
|
if device is None:
|
||||||
|
return False
|
||||||
|
if isinstance(device, str):
|
||||||
|
return device.startswith("xpu")
|
||||||
|
return device.type == "xpu"
|
||||||
|
|
||||||
|
|
||||||
if has_xpu:
|
if has_xpu:
|
||||||
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
|
try:
|
||||||
|
# torch.Generator supports "xpu" device since 2.1
|
||||||
|
torch.Generator("xpu")
|
||||||
|
except:
|
||||||
|
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1)
|
||||||
CondFunc('torch.Generator',
|
CondFunc('torch.Generator',
|
||||||
lambda orig_func, device=None: torch.xpu.Generator(device),
|
lambda orig_func, device=None: torch.xpu.Generator(device),
|
||||||
lambda orig_func, device=None: device is not None and device.type == "xpu")
|
lambda orig_func, device=None: is_xpu_device(device))
|
||||||
|
|
||||||
# W/A for some OPs that could not handle different input dtypes
|
# W/A for some OPs that could not handle different input dtypes
|
||||||
CondFunc('torch.nn.functional.layer_norm',
|
CondFunc('torch.nn.functional.layer_norm',
|
||||||
|
|
Loading…
Reference in New Issue