Improve Transformers support (#2970)

* Much better support

* add gpt neox

* bump transformers version

* bump version
Cyril Vallez 2025-02-18 19:04:34 +01:00 committed by GitHub
parent 5543fdc765
commit a7448661f7
7 changed files with 111 additions and 79 deletions

View File

@@ -1,3 +1,3 @@
transformers==4.48.2
transformers==4.49
huggingface-hub==0.28.1
hf-transfer==0.1.9
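
The dependency pin moves from transformers 4.48.2 to 4.49, presumably to pick up the API surface used further down in this diff (`logits_to_keep`, `tp_plan="auto"`, the pluggable attention registry). A minimal, illustrative guard (not part of the commit) that fails fast when an older transformers is installed:

```python
# Illustrative only: fail fast if the installed transformers predates the
# features this backend now relies on. `packaging` ships with transformers.
from packaging.version import Version

import transformers

MIN_TRANSFORMERS = Version("4.49")

if Version(transformers.__version__) < MIN_TRANSFORMERS:
    raise ImportError(
        f"transformers>={MIN_TRANSFORMERS} is required for the flash Transformers "
        f"backend, found {transformers.__version__}"
    )
```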

View File

@@ -346,7 +346,7 @@ tqdm==4.66.5
# outlines
# peft
# transformers
transformers==4.48.2
transformers==4.49
# via
# text-generation-server (pyproject.toml)
# compressed-tensors

View File

@@ -158,7 +158,7 @@ tqdm==4.67.1
# via
# huggingface-hub
# transformers
transformers==4.48.2
transformers==4.49
# via text-generation-server (pyproject.toml)
typer==0.15.1
# via text-generation-server (pyproject.toml)

View File

@@ -331,7 +331,7 @@ tqdm==4.66.5
# outlines
# peft
# transformers
transformers==4.48.2
transformers==4.49
# via
# text-generation-server (pyproject.toml)
# compressed-tensors

View File

@@ -331,7 +331,7 @@ tqdm==4.66.5
# outlines
# peft
# transformers
transformers==4.48.2
transformers==4.49
# via
# text-generation-server (pyproject.toml)
# compressed-tensors

View File

@@ -6,17 +6,18 @@ from compressed_tensors.compressors.model_compressors.model_compressor import (
)
from compressed_tensors.quantization import QuantizationType
from pydantic import ValidationError
import torch
import enum
import os
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict
from pathlib import Path
from loguru import logger
import torch
import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from huggingface_hub import hf_hub_download, HfApi
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
@@ -736,7 +737,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id=model_id,
revision=revision,
quantize=quantize,
@@ -857,6 +858,15 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
return CausalLM(
model_id=model_id,
@@ -1054,6 +1064,15 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else:
@@ -1467,42 +1486,37 @@ def get_model(
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
# Fast transformers if available
transformers_model_class = getattr(
transformers,
modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
None,
)
if (
FLASH_TRANSFORMERS_BACKEND
and transformers_model_class is not None
and transformers_model_class._supports_flex_attn
):
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if sharded:
raise NotImplementedError("sharded is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys():
model_class = None
# If the model is already in the library
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
model_class = getattr(
transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
)
elif (
trust_remote_code
and auto_map is not None
and "AutoModelForCausalLM" in auto_map.keys()
):
model_class = get_class_from_dynamic_module(
config_dict["auto_map"]["AutoModelForCausalLM"], model_id
)
# This means the model is ForCausalLM
if model_class is not None:
if FLASH_TRANSFORMERS_BACKEND and model_class.is_backend_compatible():
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError("sharded is not supported for AutoModel")
else:
return CausalLM.fallback(
model_id,
revision,
@@ -1511,15 +1525,25 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if "AutoModelForSeq2SeqLM" in auto_map.keys():
return Seq2SeqLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
# Not supported at this point
if sharded:
raise NotImplementedError("sharded is not supported for AutoModel")
# This means it is a ForSeq2SeqLM model
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES or (
trust_remote_code
and auto_map is not None
and "AutoModelForSeq2SeqLM" in auto_map.keys()
):
return Seq2SeqLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise ValueError(f"Unsupported model type {model_type}")
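
The rewritten AutoModel fallback above first resolves a concrete `*ForCausalLM` class, either from the transformers auto-mapping or, for `trust_remote_code` models, from the `auto_map` in `config.json` via `get_class_from_dynamic_module`, and only then picks a runtime: `TransformersFlashCausalLM.fallback` when `FLASH_TRANSFORMERS_BACKEND` is enabled and `model_class.is_backend_compatible()` returns true, a `NotImplementedError` for sharded AutoModel, and the plain `CausalLM`/`Seq2SeqLM` fallbacks otherwise. A condensed sketch of just the class-resolution step (the helper name and argument list are illustrative; the transformers calls are the ones used above):

```python
# Condensed, illustrative sketch of the class-resolution step in the new
# AutoModel dispatch; only the transformers helpers are real API.
import transformers
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.models.auto import modeling_auto


def resolve_causal_lm_class(model_type, config_dict, model_id, trust_remote_code):
    # 1) Model already ships with transformers: look it up in the auto-mapping.
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return getattr(
            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
        )
    # 2) Remote-code model: import the class advertised in config.json's auto_map.
    auto_map = config_dict.get("auto_map") or {}
    if trust_remote_code and "AutoModelForCausalLM" in auto_map:
        return get_class_from_dynamic_module(auto_map["AutoModelForCausalLM"], model_id)
    # 3) Not a ForCausalLM model: the caller falls through to the Seq2SeqLM check.
    return None
```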

View File

@@ -81,6 +81,15 @@ def tgi_flash_attention_forward(
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due
# to internal constraints it was not (yet?) possible to circumvent
REPLICATED_ATTENTION_MODELS = [
"olmo2",
"phi3",
]
class TransformersFlashCausalLM(FlashCausalLM):
def __init__(
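
The first hunk above plugs TGI's attention kernel into transformers' attention registry under the key `"tgi"`, so stock modeling code can run unmodified while every attention call is routed through TGI's kernels. A minimal sketch of the same mechanism with a plain-SDPA stand-in (the real `tgi_flash_attention_forward` drives TGI's paged KV cache and flash kernels; the key name and function below are illustrative):

```python
# Sketch of the pluggable-attention registration used above. Assumes the
# (batch, num_heads, seq_len, head_dim) layout transformers passes to
# registered attention functions; the body is a plain SDPA stand-in.
import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def sdpa_stand_in(module, query, key, value, attention_mask=None,
                  dropout=0.0, scaling=None, **kwargs):
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    # transformers expects (attn_output, attn_weights); weights may be None.
    return out.transpose(1, 2).contiguous(), None


# Register under a custom key, exactly like the "tgi" entry above; models then
# opt in at load time (presumably via attn_implementation="tgi" in this backend).
ALL_ATTENTION_FUNCTIONS["sdpa_stand_in"] = sdpa_stand_in
```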
@@ -119,6 +128,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
@@ -130,6 +140,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
tp_plan="auto" if world_size > 1 else None,
)
torch.distributed.barrier(group=self.process_group)
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
@@ -143,15 +155,19 @@ class TransformersFlashCausalLM(FlashCausalLM):
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.num_layers = model.config.num_hidden_layers
self.num_heads = model.config.num_attention_heads // self.process_group.size()
self.num_heads = model.config.num_attention_heads
self.num_kv_heads = model.config.num_key_value_heads
self.num_kv_heads = (
self.num_kv_heads // self.process_group.size()
if self.num_kv_heads > 1
else self.num_kv_heads
)
self.head_size = model.config.hidden_size // model.config.num_attention_heads
# Skip it for models in the exception list
if model.config.model_type not in REPLICATED_ATTENTION_MODELS:
self.num_heads = self.num_heads // self.process_group.size()
self.num_kv_heads = (
self.num_kv_heads // self.process_group.size()
if self.num_kv_heads > 1
else self.num_kv_heads
)
self.cuda_graphs = {}
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
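
The head bookkeeping above changes so that models whose transformers tensor-parallel plan replicates q/k/v across ranks (`olmo2`, `phi3`) keep their full head counts per rank instead of always dividing by the world size; everything else still shards heads, with a single KV head (MQA) never split. The rule, as a small illustrative helper:

```python
# Sketch of the per-rank head bookkeeping above; REPLICATED_ATTENTION_MODELS
# and the config fields mirror the diff, the helper itself is illustrative.
REPLICATED_ATTENTION_MODELS = ["olmo2", "phi3"]


def per_rank_heads(config, world_size: int) -> tuple[int, int]:
    num_heads = config.num_attention_heads
    num_kv_heads = config.num_key_value_heads
    if config.model_type not in REPLICATED_ATTENTION_MODELS:
        num_heads //= world_size
        # A single KV head (MQA) is never split across ranks.
        if num_kv_heads > 1:
            num_kv_heads //= world_size
    return num_heads, num_kv_heads
```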
@@ -186,7 +202,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
torch.tensor(1.0, device=device),
)
torch.distributed.barrier(group=self.process_group)
# Skip FlashCausalLM init.
super(FlashCausalLM, self).__init__(
model_id=model_id,
@@ -204,6 +219,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
self.model.original_forward = self.model.forward
self.model.forward = self._model_forward
torch.distributed.barrier(group=self.process_group)
@classmethod
def fallback(
cls,
@@ -237,11 +254,16 @@ class TransformersFlashCausalLM(FlashCausalLM):
prefill_cache_indices=None, # not used, but passed to match original signature
adapter_data=None, # not supported, but passed to match original signature
):
hidden_states = self.model.model.forward(
# A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
# This is equivalent to `self.model.forward`, see the monkey patch in __init__
logits = self.model.original_forward(
input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers
position_ids=position_ids.unsqueeze(0), # expand dim to fit Transformers
past_key_values=None, # we use self.kv_cache instead of transformers cache object
use_cache=False, # we use self.kv_cache instead of transformers cache object
logits_to_keep=logits_to_keep,
return_dict=True,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
@@ -251,20 +273,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
max_s=max_s,
kv_head_mapping=self.kv_head_mapping,
kv_scales=self.kv_scales,
)[0].squeeze(dim=0)
# And compute logits from the lm_head, slicing correctly the indices
# NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
# To update with full Transformers support asap
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.model.lm_head(hidden_states)
# For Granite while next transformers version is released and we can use `lm_head_indices` natively
if hasattr(self.model.config, "logits_scaling"):
logits = logits / self.model.config.logits_scaling
# For Cohere for similar reasons
elif hasattr(self.model, "logit_scale"):
logits = logits * self.model.logit_scale
).logits.squeeze(dim=0)
return logits, None
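
The forward path above no longer splits the model into `model.model` plus a manual `lm_head` call with hand-rolled logit slicing and per-architecture fixups (Granite's `logits_scaling`, Cohere's `logit_scale`); it calls the saved `original_forward` and lets transformers slice via `logits_to_keep` and apply any model-specific scaling itself. A compact sketch of the new calling convention (TGI-specific kwargs such as `kv_cache`, `cu_seqlen_prefill` and `slots` are collapsed into `**tgi_kwargs` here):

```python
# Compact sketch of the new forward contract, assuming the monkey patch in
# __init__ (self.model.original_forward is the unpatched transformers forward).
import torch


def model_forward(self, input_ids, position_ids, lm_head_indices=None, **tgi_kwargs):
    # None (no slicing) maps to 0, which transformers treats as "keep all logits";
    # a tensor of indices keeps only the rows needed for next-token prediction.
    logits_to_keep = lm_head_indices if lm_head_indices is not None else 0

    outputs = self.model.original_forward(
        input_ids=input_ids.unsqueeze(0),        # transformers expects a batch dim
        position_ids=position_ids.unsqueeze(0),
        past_key_values=None,                    # TGI manages its own KV cache
        use_cache=False,
        logits_to_keep=logits_to_keep,
        return_dict=True,
        **tgi_kwargs,
    )
    return outputs.logits.squeeze(dim=0), None   # drop the fake batch dim again
```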