feat: enable pytorch xpu support for non-attention models (#2561)
XPU backend is available natively (without IPEX) in pytorch starting from pytorch 2.4. This commit extends TGI to cover the case when user has XPU support thru pytorch 2.4, but does not have IPEX installed. Models which don't require attention can work. For attention required models more work is needed to provide attention implementation. Tested with the following models: * teknium/OpenHermes-2.5-Mistral-7B * bigscience/bloom-560m * google/gemma-7b * google/flan-t5-xxl Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
parent
7a82ddcbd0
commit
58848cb471
|
@ -517,11 +517,10 @@ class CausalLM(Model):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = default_dtype if dtype is None else dtype
|
dtype = default_dtype if dtype is None else dtype
|
||||||
elif SYSTEM == "ipex":
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
device = torch.device(f"xpu:{rank}")
|
||||||
dtype = default_dtype if dtype is None else dtype
|
dtype = default_dtype if dtype is None else dtype
|
||||||
else:
|
elif SYSTEM == "ipex":
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
# Float16 doesn't exist on target.
|
# Float16 doesn't exist on target.
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
|
@ -593,8 +592,14 @@ class CausalLM(Model):
|
||||||
if speculator:
|
if speculator:
|
||||||
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
||||||
|
|
||||||
|
device_count = 0
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
device_count = torch.cuda.device_count()
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
device = torch.device("xpu")
|
||||||
|
device_count = torch.xpu.device_count()
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
|
@ -616,18 +621,17 @@ class CausalLM(Model):
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map=(
|
device_map=(
|
||||||
"auto"
|
"auto"
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
if device_count > 1
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
torch.cuda.is_available()
|
device_count == 1
|
||||||
and torch.cuda.device_count() == 1
|
|
||||||
and quantize != "bitsandbytes"
|
and quantize != "bitsandbytes"
|
||||||
):
|
):
|
||||||
model = model.cuda()
|
model = model.to(device)
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
if tokenizer.pad_token_id is None:
|
||||||
if model.config.pad_token_id is not None:
|
if model.config.pad_token_id is not None:
|
||||||
|
|
|
@ -558,11 +558,10 @@ class Seq2SeqLM(Model):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = default_dtype if dtype is None else dtype
|
dtype = default_dtype if dtype is None else dtype
|
||||||
elif SYSTEM == "ipex":
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
device = torch.device(f"xpu:{rank}")
|
||||||
dtype = default_dtype if dtype is None else dtype
|
dtype = default_dtype if dtype is None else dtype
|
||||||
else:
|
elif SYSTEM == "ipex":
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
# Float16 doesn't exist on target.
|
# Float16 doesn't exist on target.
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
|
@ -630,8 +629,14 @@ class Seq2SeqLM(Model):
|
||||||
if speculator:
|
if speculator:
|
||||||
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
||||||
|
|
||||||
|
device_count = 0
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
device_count = torch.cuda.device_count()
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
device = torch.device("xpu")
|
||||||
|
device_count = torch.xpu.device_count()
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
|
@ -646,14 +651,14 @@ class Seq2SeqLM(Model):
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map=(
|
device_map=(
|
||||||
"auto"
|
"auto"
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
if device_count > 1
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
if device_count == 1:
|
||||||
model = model.cuda()
|
model = model.to(device)
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
|
|
@ -66,6 +66,11 @@ elif is_ipex_available():
|
||||||
empty_cache = noop
|
empty_cache = noop
|
||||||
synchronize = noop
|
synchronize = noop
|
||||||
get_free_memory = get_cpu_free_memory
|
get_free_memory = get_cpu_free_memory
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
SYSTEM = "xpu"
|
||||||
|
empty_cache = torch.xpu.empty_cache
|
||||||
|
synchronize = torch.xpu.synchronize
|
||||||
|
get_free_memory = get_xpu_free_memory
|
||||||
else:
|
else:
|
||||||
SYSTEM = "cpu"
|
SYSTEM = "cpu"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue