fix cpu and xpu issue (#2116)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit e563983d90 (parent 9e2fdf57c0)
Author: Wang, Yi
Date:   2024-06-25 22:47:06 +08:00
Committed by: GitHub
7 changed files with 16 additions and 7 deletions

server/text_generation_server/models/flash_causal_lm.py

@@ -768,7 +768,10 @@ class FlashCausalLM(Model):
         empty_cache()
 
         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = BLOCK_SIZE // element_size
+        if SYSTEM == "ipex" and device.type == "xpu":
+            x = 1
+        else:
+            x = BLOCK_SIZE // element_size
 
         if SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
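
For context, x controls how the key cache is packed along the head dimension: the CUDA kernels read keys in chunks of x = BLOCK_SIZE // element_size elements, while the ipex XPU kernel expects the unpacked layout, hence x = 1. Below is a minimal, assumption-level sketch of the allocation this value feeds into, following the vLLM-style paged layout; num_blocks, num_layers, num_heads, and head_size are placeholder values standing in for the surrounding variables.

    import torch

    # Placeholder values; in TGI these come from the surrounding warmup code.
    num_blocks, num_layers, num_heads, head_size = 8, 2, 4, 64
    BLOCK_SIZE = 16
    dtype, device = torch.float16, torch.device("cpu")

    # Mirror the patched branch: ipex on XPU wants the unpacked layout.
    if device.type == "xpu":
        x = 1
    else:
        x = BLOCK_SIZE // torch.tensor([], dtype=dtype).element_size()

    kv_cache = [
        (
            # Key cache: head_size split into head_size // x chunks of width x.
            torch.empty(
                (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                dtype=dtype,
                device=device,
            ),
            # Value cache keeps the flat per-head layout.
            torch.empty(
                (num_blocks, num_heads, head_size, BLOCK_SIZE),
                dtype=dtype,
                device=device,
            ),
        )
        for _ in range(num_layers)
    ]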

server/text_generation_server/models/flash_gpt2.py

@@ -37,9 +37,10 @@ class FlashGPT2(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")
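
The remaining hunks apply the same two-line fix to every flash model: the default dtype now tracks the chosen device, float16 on Intel XPU and bfloat16 on CPU, since bfloat16 is the half precision that ipex actually accelerates on CPUs. As a sketch only, the repeated pattern could be factored into a helper like the hypothetical select_ipex_device_dtype below; this function does not exist in the repository.

    from typing import Optional, Tuple

    import torch

    def select_ipex_device_dtype(
        rank: int, dtype: Optional[torch.dtype] = None
    ) -> Tuple[torch.device, torch.dtype]:
        # Hypothetical helper mirroring the per-model branch this commit fixes.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            # Intel GPU path: float16 is the preferred half precision.
            return torch.device(f"xpu:{rank}"), (
                torch.float16 if dtype is None else dtype
            )
        # CPU path: default to bfloat16 rather than float16.
        return torch.device("cpu"), (
            torch.bfloat16 if dtype is None else dtype
        )

Each model's __init__ would then reduce its ipex branch to a single call: device, dtype = select_ipex_device_dtype(rank, dtype).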

server/text_generation_server/models/flash_llama.py

@@ -37,9 +37,10 @@ class FlashLlama(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

server/text_generation_server/models/flash_mistral.py

@@ -41,9 +41,10 @@ class BaseFlashMistral(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

server/text_generation_server/models/flash_neox.py

@@ -36,9 +36,10 @@ class FlashNeoXSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

server/text_generation_server/models/flash_rw.py

@@ -37,9 +37,10 @@ class FlashRWSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")

server/text_generation_server/models/flash_santacoder.py

@@ -40,9 +40,10 @@ class FlashSantacoderSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
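
For completeness, SYSTEM is resolved once at import time in text_generation_server/utils/import_utils.py. The snippet below is only an approximation of that detection (the real code also distinguishes ROCm), shown to clarify why the patched branches still test the device type inside the "ipex" case.

    import importlib.util

    import torch

    def detect_system() -> str:
        # Assumption-level sketch, not TGI's actual implementation.
        if torch.cuda.is_available():
            return "cuda"
        if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
            # ipex covers both Intel XPU devices and optimized CPU execution,
            # which is why callers still branch on torch.xpu.is_available().
            return "ipex"
        return "cpu"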