feat: natively support Granite models (#2682)
* feat: natively support Granite models
* Update doc
This commit is contained in:
parent
f58eb70ebf
commit
03c9388bf7
|
@ -8,6 +8,7 @@ Text Generation Inference enables serving optimized models. The following sectio
|
|||
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
|
||||
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
|
||||
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
|
||||
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
|
||||
- [Gemma](https://huggingface.co/google/gemma-7b)
|
||||
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
|
||||
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
|
||||
|
|
|
@ -67,7 +67,9 @@ mkShell {
|
|||
[
|
||||
cuda_cccl
|
||||
cuda_cudart
|
||||
cuda_nvrtc
|
||||
cuda_nvtx
|
||||
cuda_profiler_api
|
||||
cudnn
|
||||
libcublas
|
||||
libcusolver
|
||||
|
|
|
@ -150,6 +150,7 @@ pub enum Config {
|
|||
Idefics2(Idefics2),
|
||||
Ssm,
|
||||
GptBigcode,
|
||||
Granite,
|
||||
Santacoder,
|
||||
Bloom,
|
||||
Mpt,
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,5 +1,5 @@
|
|||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
|
|||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
|||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
|
|||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
|||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
|
|||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
|||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
|
|
@ -195,6 +195,11 @@ class ModelType(enum.Enum):
|
|||
"name": "Phi 3",
|
||||
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
|
||||
}
|
||||
GRANITE = {
|
||||
"type": "granite",
|
||||
"name": "Granite",
|
||||
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
|
||||
}
|
||||
GEMMA = {
|
||||
"type": "gemma",
|
||||
"name": "Gemma",
|
||||
|
@ -862,7 +867,12 @@ def get_model(
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
|
||||
elif (
|
||||
model_type == LLAMA
|
||||
or model_type == BAICHUAN
|
||||
or model_type == PHI3
|
||||
or model_type == GRANITE
|
||||
):
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
|
@ -876,7 +886,9 @@ def get_model(
|
|||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
|
||||
)
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
|
|
|
@ -156,7 +156,10 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
device=weights.device,
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
# `config.attention_multiplier` is used in Granite
|
||||
self.softmax_scale = getattr(
|
||||
config, "attention_multiplier", self.head_size**-0.5
|
||||
)
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
|
@ -180,7 +183,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
bias=getattr(config, "attention_bias", False),
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
|
@ -436,6 +439,11 @@ class FlashLlamaLayer(nn.Module):
|
|||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
# Used in Granite
|
||||
# This could eventually be baked into the weights like we do for the embeddings/lm_head
|
||||
# but this would mean modifying the lora code
|
||||
self.residual_multiplier = getattr(config, "residual_multiplier", None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
|
@ -466,13 +474,16 @@ class FlashLlamaLayer(nn.Module):
|
|||
max_s,
|
||||
adapter_data,
|
||||
)
|
||||
if self.residual_multiplier is not None:
|
||||
attn_output *= self.residual_multiplier
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.dense(normed_attn_res_output, adapter_data)
|
||||
if self.residual_multiplier is not None:
|
||||
mlp_output *= self.residual_multiplier
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
@ -624,6 +635,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
else:
|
||||
suffix = "lm_head"
|
||||
|
||||
# Used in Granite
|
||||
embedding_multiplier = getattr(config, "embedding_multiplier", None)
|
||||
if embedding_multiplier is not None:
|
||||
self.embed_tokens.weight.data *= embedding_multiplier
|
||||
|
||||
with no_fp8(weights):
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
|
@ -631,6 +647,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
weights=weights,
|
||||
)
|
||||
|
||||
# Used in Granite
|
||||
self.logits_scaling = getattr(config, "logits_scaling", None)
|
||||
if self.logits_scaling is not None and self.lm_head.head is not None:
|
||||
try:
|
||||
# Scale the weights directly
|
||||
self.lm_head.head.linear.weight.data /= self.logits_scaling
|
||||
self.logits_scaled = True
|
||||
except Exception:
|
||||
self.logits_scaled = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
@ -664,4 +690,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
|
||||
# Used in Granite
|
||||
if not self.logits_scaled:
|
||||
logits /= self.logits_scaling
|
||||
if speculative_logits is not None:
|
||||
speculative_logits /= self.logits_scaling
|
||||
|
||||
return logits, speculative_logits
|
||||
|
|
Loading…
Reference in New Issue