Enable HuggingFaceM4/idefics-9b on Intel GPU (#2338)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi authored 2024-08-01 17:08:36 +08:00, committed by GitHub
parent 7451041ecd
commit 9ab9937414
2 changed files with 23 additions and 1 deletion


@@ -351,7 +351,19 @@ class IdeficsRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
+        if SYSTEM == "ipex":
+            import intel_extension_for_pytorch as ipex
+
+            out = ipex.llm.functional.add_rms_norm(
+                residual,
+                hidden_states,
+                self.weight,
+                None,
+                self.variance_epsilon,
+                residual is not None,
+            )
+            return out
+        elif hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
                 residual = hidden_states
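
For readers unfamiliar with the fused kernel, here is a minimal eager-PyTorch sketch of what the ipex.llm.functional.add_rms_norm call above is expected to compute: a residual add folded into an RMSNorm. The helper name add_rms_norm_reference is hypothetical, and the exact upcast and residual write-back semantics belong to the ipex kernel itself:

    import torch

    def add_rms_norm_reference(residual, hidden_states, weight, eps, add_residual=True):
        # Fold the residual into the activations first, mirroring the
        # fused kernel's behavior when its last argument is True.
        if residual is not None and add_residual:
            hidden_states = hidden_states + residual
        # RMSNorm: normalize by the root-mean-square over the last dim,
        # computed in float32 for stability, then rescale by the weight.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
        return weight * normed.to(weight.dtype)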


@@ -20,6 +20,8 @@ from text_generation_server.utils import (
 )
 from text_generation_server.utils.quantization import get_loader
 
+from text_generation_server.utils.import_utils import SYSTEM
+
 
 class IDEFICSSharded(IdeficsCausalLM):
     def __init__(
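
For context, SYSTEM is a string flag exported by text_generation_server.utils.import_utils and compared against "ipex" below. A rough, hypothetical sketch of how such a flag could be derived; the real module's detection logic may differ:

    import importlib.util

    import torch

    def detect_system() -> str:
        # Hypothetical stand-in for the SYSTEM flag; the real
        # text_generation_server.utils.import_utils may use other checks.
        if torch.cuda.is_available():
            return "cuda"
        if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
            return "ipex"
        return "cpu"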
@@ -37,6 +39,14 @@ class IDEFICSSharded(IdeficsCausalLM):
             # 9b seems to work correctly enough in float16, but 80b seems
             # to be really saturating for f16.
             dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
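
Pulled out of the constructor, the device/dtype selection added above amounts to the following standalone sketch. The function name and the explicit system argument are illustrative, not part of the patch, and the CUDA branch preceding the hunk is omitted:

    import torch

    def select_device_and_dtype(system: str, rank: int, dtype=None):
        # Illustrative restatement of the branch added in IDEFICSSharded.__init__.
        if system == "ipex" and hasattr(torch, "xpu") and torch.xpu.is_available():
            # Intel GPU: float16 is supported, one XPU device per shard rank.
            device = torch.device(f"xpu:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        elif system == "ipex":
            # ipex on CPU: float16 kernels don't exist on the target,
            # so the patch defaults to bfloat16 instead.
            device = torch.device("cpu")
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            # Plain CPU fallback keeps the original float32 default.
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype
        return device, dtype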