From 73786c047f14d6ae658b2c12f493f05486ba1789 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 19:09:56 +0800 Subject: [PATCH] [IPEX] Fix torch.Generator hijack --- modules/xpu_specific.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 4e11125b2..1137891a6 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention( return torch.reshape(result, (*N, L, Ev)) +def is_xpu_device(device: str | torch.device = None): + if device is None: + return False + if isinstance(device, str): + return device.startswith("xpu") + return device.type == "xpu" + + if has_xpu: - # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device.type == "xpu") + try: + # torch.Generator supports "xpu" device since 2.1 + torch.Generator("xpu") + except: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1) + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: is_xpu_device(device)) # W/A for some OPs that could not handle different input dtypes CondFunc('torch.nn.functional.layer_norm',