From 5da4cfab1c211ff3e2aefbd0358f714970fb8360 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Mon, 1 Jul 2024 20:32:54 +0800
Subject: [PATCH] refine get xpu free memory/enable Qwen2/gemma2/gemma/phi in intel platform (#2132)

* refine get xpu free memory

Signed-off-by: Wang, Yi A

* enable qwen2 in xpu

Signed-off-by: Wang, Yi A

* enable gemma/gemma2/phi in intel platform

Signed-off-by: Wang, Yi A

---------

Signed-off-by: Wang, Yi A
---
 .../text_generation_server/layers/attention/ipex.py  |  3 ++-
 server/text_generation_server/models/flash_gemma.py  |  8 ++++++++
 server/text_generation_server/models/flash_gemma2.py |  8 ++++++++
 server/text_generation_server/models/flash_phi.py    |  8 ++++++++
 server/text_generation_server/models/flash_qwen2.py  |  8 ++++++++
 server/text_generation_server/utils/import_utils.py  | 12 ++++++++----
 6 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index bfab0119..7f086b68 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -14,6 +14,7 @@ def attention(
     max_s,
     softmax_scale,
     window_size_left=-1,
+    causal=True,
 ):
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
@@ -28,7 +29,7 @@ def attention(
         0.0,
         softmax_scale,
         False,
-        True,
+        causal,
         False,
         None,
     )
diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index aa1ae9ac..7e2b8780 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -32,6 +33,13 @@ class FlashGemma(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGemma is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_gemma2.py b/server/text_generation_server/models/flash_gemma2.py
index 9608113b..86cfc7e2 100644
--- a/server/text_generation_server/models/flash_gemma2.py
+++ b/server/text_generation_server/models/flash_gemma2.py
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -32,6 +33,13 @@ class FlashGemma2(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGemma2 is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py
index 7e108d05..a530d1c3 100644
--- a/server/text_generation_server/models/flash_phi.py
+++ b/server/text_generation_server/models/flash_phi.py
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -32,6 +33,13 @@ class FlashPhi(FlashCausalLM):
         if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashPhi is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py
index 23528f0b..cd6078f1 100644
--- a/server/text_generation_server/models/flash_qwen2.py
+++ b/server/text_generation_server/models/flash_qwen2.py
@@ -19,6 +19,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -37,6 +38,13 @@ class FlashQwen2(BaseFlashMistral):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashQwen2 is only available on GPU")
 
diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py
index 6d921721..011e0f63 100644
--- a/server/text_generation_server/utils/import_utils.py
+++ b/server/text_generation_server/utils/import_utils.py
@@ -1,6 +1,7 @@
 import torch
 from loguru import logger
 import subprocess
+import os
 
 
 def is_ipex_available():
@@ -21,10 +22,13 @@ def get_cuda_free_memory(device, memory_fraction):
 def get_xpu_free_memory(device, memory_fraction):
     total_memory = torch.xpu.get_device_properties(device).total_memory
     device_id = device.index
-    query = f"xpu-smi dump -d {device_id} -m 18 -n 1"
-    output = subprocess.check_output(query.split()).decode("utf-8").split("\n")
-    used_memory = float(output[1].split(",")[-1]) * 1024 * 1024
-    free_memory = int(total_memory * 0.95 - used_memory)
+    memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
+    free_memory = max(
+        0,
+        int(
+            total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
+        ),
+    )
     return free_memory