From 9ab9937414cd1c6a3299f322268d1bacb235f7b6 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Thu, 1 Aug 2024 17:08:36 +0800
Subject: [PATCH] enable HuggingFaceM4/idefics-9b in intel gpu (#2338)

Signed-off-by: Wang, Yi A
---
 .../models/custom_modeling/idefics_modeling.py   | 14 +++++++++++++-
 server/text_generation_server/models/idefics.py  | 10 ++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index 771bc234..fc6becc4 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -351,7 +351,19 @@ class IdeficsRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
+        if SYSTEM == "ipex":
+            import intel_extension_for_pytorch as ipex
+
+            out = ipex.llm.functional.add_rms_norm(
+                residual,
+                hidden_states,
+                self.weight,
+                None,
+                self.variance_epsilon,
+                residual is not None,
+            )
+            return out
+        elif hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index af224254..29929b98 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -20,6 +20,8 @@ from text_generation_server.utils import (
 )
 from text_generation_server.utils.quantization import get_loader
 
+from text_generation_server.utils.import_utils import SYSTEM
+
 
 class IDEFICSSharded(IdeficsCausalLM):
     def __init__(
@@ -37,6 +39,14 @@ class IDEFICSSharded(IdeficsCausalLM):
             # 9b seems to work correctly enough in float16, but 80b seems
             # to be really saturating for f16.
             dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
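
Note on the fused call above: the patch routes IdeficsRMSNorm.forward through
ipex.llm.functional.add_rms_norm, which fuses the residual add and the RMS
normalization into a single kernel on Intel hardware. Below is a minimal
eager-mode sketch of what that call is assumed to compute, modeled on the
non-ipex fallback path in the same forward(); the `None` argument is read as a
bias placeholder and the trailing boolean as a toggle for the residual add.
Both readings are assumptions for illustration, not confirmed IPEX
documentation.

    import torch

    def add_rms_norm_reference(residual, hidden_states, weight, bias, eps, add_residual):
        # Fold the residual into the activations first, mirroring the
        # `residual is not None` flag passed in the patched forward().
        # (Assumed semantics of the fused IPEX kernel, for illustration only.)
        if add_residual and residual is not None:
            hidden_states = hidden_states + residual
        # RMS-normalize in float32 for numerical stability, as the eager
        # fallback path of IdeficsRMSNorm does, then cast back.
        hs32 = hidden_states.to(torch.float32)
        variance = hs32.pow(2).mean(-1, keepdim=True)
        normed = hs32 * torch.rsqrt(variance + eps)
        out = weight * normed.to(hidden_states.dtype)
        return out if bias is None else out + bias

    # Usage example matching the call shape in the patch: weight from the
    # module, bias=None, eps=variance_epsilon, add-back chosen by the caller.
    h = torch.randn(2, 16, dtype=torch.bfloat16)
    r = torch.randn_like(h)
    w = torch.ones(16, dtype=torch.bfloat16)
    y = add_rms_norm_reference(r, h, w, None, 1e-6, True)

Fusing the add and the norm avoids materializing the intermediate
`hidden_states + residual` tensor, which is the same motivation as the
existing >8192 branch that the patch leaves in place for non-ipex systems.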