feat: natively support Granite models (#2682)
* feat: natively support Granite models * Update doc
This commit is contained in:
parent
f58eb70ebf
commit
03c9388bf7
|
@ -8,6 +8,7 @@ Text Generation Inference enables serving optimized models. The following sectio
|
||||||
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
|
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
|
||||||
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
|
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
|
||||||
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
|
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
|
||||||
|
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
|
||||||
- [Gemma](https://huggingface.co/google/gemma-7b)
|
- [Gemma](https://huggingface.co/google/gemma-7b)
|
||||||
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
|
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
|
||||||
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
|
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
|
||||||
|
|
|
@ -67,7 +67,9 @@ mkShell {
|
||||||
[
|
[
|
||||||
cuda_cccl
|
cuda_cccl
|
||||||
cuda_cudart
|
cuda_cudart
|
||||||
|
cuda_nvrtc
|
||||||
cuda_nvtx
|
cuda_nvtx
|
||||||
|
cuda_profiler_api
|
||||||
cudnn
|
cudnn
|
||||||
libcublas
|
libcublas
|
||||||
libcusolver
|
libcusolver
|
||||||
|
|
|
@ -150,6 +150,7 @@ pub enum Config {
|
||||||
Idefics2(Idefics2),
|
Idefics2(Idefics2),
|
||||||
Ssm,
|
Ssm,
|
||||||
GptBigcode,
|
GptBigcode,
|
||||||
|
Granite,
|
||||||
Santacoder,
|
Santacoder,
|
||||||
Bloom,
|
Bloom,
|
||||||
Mpt,
|
Mpt,
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,5 +1,5 @@
|
||||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -195,6 +195,11 @@ class ModelType(enum.Enum):
|
||||||
"name": "Phi 3",
|
"name": "Phi 3",
|
||||||
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
|
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
|
||||||
}
|
}
|
||||||
|
GRANITE = {
|
||||||
|
"type": "granite",
|
||||||
|
"name": "Granite",
|
||||||
|
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
|
||||||
|
}
|
||||||
GEMMA = {
|
GEMMA = {
|
||||||
"type": "gemma",
|
"type": "gemma",
|
||||||
"name": "Gemma",
|
"name": "Gemma",
|
||||||
|
@ -862,7 +867,12 @@ def get_model(
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
|
elif (
|
||||||
|
model_type == LLAMA
|
||||||
|
or model_type == BAICHUAN
|
||||||
|
or model_type == PHI3
|
||||||
|
or model_type == GRANITE
|
||||||
|
):
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashCausalLM(
|
return FlashCausalLM(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -876,7 +886,9 @@ def get_model(
|
||||||
lora_adapter_ids=lora_adapter_ids,
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
raise NotImplementedError(
|
||||||
|
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return CausalLM.fallback(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
|
|
|
@ -156,7 +156,10 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
device=weights.device,
|
device=weights.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.softmax_scale = self.head_size**-0.5
|
# `config.attention_multiplier` is used in Granite
|
||||||
|
self.softmax_scale = getattr(
|
||||||
|
config, "attention_multiplier", self.head_size**-0.5
|
||||||
|
)
|
||||||
|
|
||||||
if self.num_heads % weights.process_group.size() != 0:
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -180,7 +183,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.o_proj",
|
prefix=f"{prefix}.o_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=getattr(config, "attention_bias", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
@ -436,6 +439,11 @@ class FlashLlamaLayer(nn.Module):
|
||||||
eps=config.rms_norm_eps,
|
eps=config.rms_norm_eps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
# This could eventually be baked into the weights like we do for the embeddings/lm_head
|
||||||
|
# but this would mean modifying the lora code
|
||||||
|
self.residual_multiplier = getattr(config, "residual_multiplier", None)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
@ -466,13 +474,16 @@ class FlashLlamaLayer(nn.Module):
|
||||||
max_s,
|
max_s,
|
||||||
adapter_data,
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
if self.residual_multiplier is not None:
|
||||||
|
attn_output *= self.residual_multiplier
|
||||||
|
|
||||||
# faster post attention rms norm
|
|
||||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||||
attn_output, res
|
attn_output, res
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.dense(normed_attn_res_output, adapter_data)
|
mlp_output = self.dense(normed_attn_res_output, adapter_data)
|
||||||
|
if self.residual_multiplier is not None:
|
||||||
|
mlp_output *= self.residual_multiplier
|
||||||
|
|
||||||
return mlp_output, attn_res
|
return mlp_output, attn_res
|
||||||
|
|
||||||
|
@ -624,6 +635,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
suffix = "lm_head"
|
suffix = "lm_head"
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
embedding_multiplier = getattr(config, "embedding_multiplier", None)
|
||||||
|
if embedding_multiplier is not None:
|
||||||
|
self.embed_tokens.weight.data *= embedding_multiplier
|
||||||
|
|
||||||
with no_fp8(weights):
|
with no_fp8(weights):
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
|
@ -631,6 +647,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
self.logits_scaling = getattr(config, "logits_scaling", None)
|
||||||
|
if self.logits_scaling is not None and self.lm_head.head is not None:
|
||||||
|
try:
|
||||||
|
# Scale the weights directly
|
||||||
|
self.lm_head.head.linear.weight.data /= self.logits_scaling
|
||||||
|
self.logits_scaled = True
|
||||||
|
except Exception:
|
||||||
|
self.logits_scaled = False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
@ -664,4 +690,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits, speculative_logits = self.lm_head(hidden_states)
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
if not self.logits_scaled:
|
||||||
|
logits /= self.logits_scaling
|
||||||
|
if speculative_logits is not None:
|
||||||
|
speculative_logits /= self.logits_scaling
|
||||||
|
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
Loading…
Reference in New Issue