fix cpu and xpu issue (#2116)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit e563983d90 (parent 9e2fdf57c0)
@@ -768,7 +768,10 @@ class FlashCausalLM(Model):
         empty_cache()

         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = BLOCK_SIZE // element_size
+        if SYSTEM == "ipex" and device.type == "xpu":
+            x = 1
+        else:
+            x = BLOCK_SIZE // element_size

         if SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
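What this hunk changes: the factor x, used a few lines below when allocating self.kv_cache, is no longer computed unconditionally from the block size. A minimal sketch of the resulting logic, assuming BLOCK_SIZE = 16 as the paged-attention block size; the function name kv_packing_factor is illustrative and not part of the source:

import torch

BLOCK_SIZE = 16  # assumed paged-attention block size

def kv_packing_factor(system: str, device: torch.device, dtype: torch.dtype) -> int:
    # Bytes per cache element: 2 for float16/bfloat16, 4 for float32.
    element_size = torch.tensor([], dtype=dtype).element_size()
    # After this change, ipex on an XPU device forces the factor to 1;
    # every other backend keeps the original BLOCK_SIZE // element_size value.
    if system == "ipex" and device.type == "xpu":
        return 1
    return BLOCK_SIZE // element_size

# e.g. float16 on CUDA gives 16 // 2 == 8, float16 on an ipex XPU gives 1.

The context line "if SYSTEM == "ipex" and device == torch.device("cpu")" below the change is the CPU-only allocation path that already existed; only the XPU case gains the new x = 1 branch.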
@@ -37,9 +37,10 @@ class FlashGPT2(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")

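This hunk, and the identical ones for FlashLlama, BaseFlashMistral, FlashNeoXSharded, FlashRWSharded, and FlashSantacoderSharded below, moves the dtype default inside the device branch: ipex on an XPU keeps float16, while ipex on CPU now defaults to bfloat16. A minimal sketch of the selection these classes end up with, assuming the usual CUDA branch above the hunk (unchanged, so not shown in the diff); select_device_and_dtype is an illustrative name, not a real helper in the repository:

import torch

def select_device_and_dtype(system: str, rank: int, dtype=None):
    if torch.cuda.is_available():
        # Unchanged CUDA path, assumed from the surrounding context.
        return torch.device(f"cuda:{rank}"), (torch.float16 if dtype is None else dtype)
    if system == "ipex":
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            # Intel XPU: keep the float16 default.
            return torch.device(f"xpu:{rank}"), (torch.float16 if dtype is None else dtype)
        # ipex on CPU: the default changes from float16 to bfloat16.
        return torch.device("cpu"), (torch.bfloat16 if dtype is None else dtype)
    raise NotImplementedError("Flash attention models are only available on GPU")

Before this commit, the single "dtype = torch.float16 if dtype is None else dtype" after the if/else meant CPU hosts also defaulted to float16; splitting it per branch is what lets the CPU fallback pick bfloat16.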
@@ -37,9 +37,10 @@ class FlashLlama(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

@@ -41,9 +41,10 @@ class BaseFlashMistral(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

@@ -36,9 +36,10 @@ class FlashNeoXSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

@@ -37,9 +37,10 @@ class FlashRWSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")

@@ -40,9 +40,10 @@ class FlashSantacoderSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
