refine get xpu free memory/enable Qwen2/gemma2/gemma/phi in intel platform (#2132)
* refine get xpu free memory Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable qwen2 in xpu Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable gemma/gemma2/phi in intel platform Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
9d0ca503a8
commit
5da4cfab1c
|
@ -14,6 +14,7 @@ def attention(
|
|||
max_s,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
):
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
return ipex.llm.functional.varlen_attention(
|
||||
|
@ -28,7 +29,7 @@ def attention(
|
|||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
|||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
@ -32,6 +33,13 @@ class FlashGemma(FlashCausalLM):
|
|||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGemma is only available on GPU")
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
|||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
@ -32,6 +33,13 @@ class FlashGemma2(FlashCausalLM):
|
|||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGemma2 is only available on GPU")
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
|||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
@ -32,6 +33,13 @@ class FlashPhi(FlashCausalLM):
|
|||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashPhi is only available on GPU")
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from text_generation_server.utils import (
|
|||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
@ -37,6 +38,13 @@ class FlashQwen2(BaseFlashMistral):
|
|||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashQwen2 is only available on GPU")
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
from loguru import logger
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
|
||||
def is_ipex_available():
|
||||
|
@ -21,10 +22,13 @@ def get_cuda_free_memory(device, memory_fraction):
|
|||
def get_xpu_free_memory(device, memory_fraction):
|
||||
total_memory = torch.xpu.get_device_properties(device).total_memory
|
||||
device_id = device.index
|
||||
query = f"xpu-smi dump -d {device_id} -m 18 -n 1"
|
||||
output = subprocess.check_output(query.split()).decode("utf-8").split("\n")
|
||||
used_memory = float(output[1].split(",")[-1]) * 1024 * 1024
|
||||
free_memory = int(total_memory * 0.95 - used_memory)
|
||||
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
|
||||
free_memory = max(
|
||||
0,
|
||||
int(
|
||||
total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
|
||||
),
|
||||
)
|
||||
return free_memory
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue