Refactor dead code - Removing all `flash_xxx.py` files. (#2166)

* Refactor dead code.

* First working step.

* Remove a lot of duplicated code.

* More dead code.

* More cleanup.

* Fix Santacoder test.

* Fixing the simple tests.

* Fixing sharding.

* Fixes for VLM.

* Fixing santacoder (num_kv_heads hardcoded).

* Removing more dead code.

* Fixing `config.n_head`.

* Stopping earlier because of `<end_of_utterance>` in idefics2.

* Addresses comments.

* Removing the dead code.

* Fuse back mistral into FlashCausalLM.

* Finish removal.

* Fixing docs + causal_lm `batch_class`.

* Fixing docs + causal_lm.

* Add default to Gemma `causal` argument.

* Default value for gemma/gemma2.

* Wrong default.
Nicolas Patry 2024-07-05 10:29:56 +02:00 committed by GitHub
parent c6bcadf883
commit fb2f74e2b9
43 changed files with 689 additions and 2451 deletions
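
For orientation before the per-file diffs: this PR deletes the thin per-architecture wrappers (`FlashMistral`, `FlashCohere`, `FlashGemma`, `SantaCoder`, ...) and has `get_model` construct the generic `FlashCausalLM` / `CausalLM` / `VlmCausalLM` / `Seq2SeqLM` classes directly, passing the modeling class and any architecture quirks as arguments. A minimal sketch of the new call shape, using Mistral as in the "Fuse back mistral" step above (the model id is illustrative, and actually loading it still needs a GPU plus the safetensors weights):

```python
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
)

# Before this PR: return FlashMistral(model_id, revision, quantize=quantize, ...)
# After: the generic loader is parameterized by the modeling class.
model = FlashCausalLM(
    model_id="mistralai/Mistral-7B-Instruct-v0.2",  # illustrative model id
    model_class=FlashMistralForCausalLM,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
    lora_adapter_ids=[],
)
```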

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "2.1.1-dev0" "version": "2.1.2-dev0"
}, },
"paths": { "paths": {
"/": { "/": {

View File

@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b) - [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/google/gemma2-9b) - [Gemma2](https://huggingface.co/google/gemma2-9b)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct)

View File

@ -1,130 +1,124 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "eos_token",
"generated_tokens": 20, "generated_tokens": 19,
"prefill": [], "prefill": [],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 415, "id": 415,
"logprob": -0.039886475, "logprob": -0.03665161,
"special": false, "special": false,
"text": " The" "text": " The"
}, },
{ {
"id": 12072, "id": 12072,
"logprob": -0.1430664, "logprob": -0.13549805,
"special": false, "special": false,
"text": " cow" "text": " cow"
}, },
{ {
"id": 349, "id": 349,
"logprob": -0.056488037, "logprob": -0.05819702,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 6328, "id": 6328,
"logprob": -0.6855469, "logprob": -0.6826172,
"special": false, "special": false,
"text": " standing" "text": " standing"
}, },
{ {
"id": 356, "id": 356,
"logprob": -0.1685791, "logprob": -0.1607666,
"special": false, "special": false,
"text": " on" "text": " on"
}, },
{ {
"id": 272, "id": 272,
"logprob": -0.50097656, "logprob": -0.5073242,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 10305, "id": 10305,
"logprob": -0.017303467, "logprob": -0.016418457,
"special": false, "special": false,
"text": " beach" "text": " beach"
}, },
{ {
"id": 304, "id": 304,
"logprob": -1.3564453, "logprob": -1.3916016,
"special": false, "special": false,
"text": " and" "text": " and"
}, },
{ {
"id": 272, "id": 272,
"logprob": -0.017868042, "logprob": -0.020217896,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 13088, "id": 13088,
"logprob": -0.0027103424, "logprob": -0.0028133392,
"special": false, "special": false,
"text": " chicken" "text": " chicken"
}, },
{ {
"id": 349, "id": 349,
"logprob": -0.003156662, "logprob": -0.003145218,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 6398, "id": 6398,
"logprob": -0.37304688, "logprob": -0.37060547,
"special": false, "special": false,
"text": " sitting" "text": " sitting"
}, },
{ {
"id": 356, "id": 356,
"logprob": -0.034576416, "logprob": -0.034851074,
"special": false, "special": false,
"text": " on" "text": " on"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.29418945, "logprob": -0.2878418,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 17972, "id": 17972,
"logprob": -0.042877197, "logprob": -0.046051025,
"special": false, "special": false,
"text": " pile" "text": " pile"
}, },
{ {
"id": 302, "id": 302,
"logprob": -0.00028443336, "logprob": -0.00028848648,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 2445, "id": 2445,
"logprob": -0.023223877, "logprob": -0.025772095,
"special": false, "special": false,
"text": " money" "text": " money"
}, },
{ {
"id": 28723, "id": 28723,
"logprob": -0.018157959, "logprob": -0.018127441,
"special": false, "special": false,
"text": "." "text": "."
}, },
{ {
"id": 32002, "id": 32002,
"logprob": -0.00018393993, "logprob": -0.00019824505,
"special": true, "special": true,
"text": "<end_of_utterance>" "text": "<end_of_utterance>"
},
{
"id": 2,
"logprob": -1.1920929e-07,
"special": true,
"text": "</s>"
} }
], ],
"top_tokens": null "top_tokens": null

View File

@ -57,7 +57,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot)
response.generated_text response.generated_text
== " The cow is standing on the beach and the chicken is sitting on a pile of money." == " The cow is standing on the beach and the chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}" ), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 20 assert response.details.generated_tokens == 19
assert response == response_snapshot assert response == response_snapshot
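
The snapshot and test changes above follow from the new stopping behaviour noted in the commit message: generation now ends at the `<end_of_utterance>` special token (id 32002), so the trailing `</s>` token disappears, `finish_reason` becomes `eos_token`, and `generated_tokens` drops from 20 to 19. A small sketch of the updated expectations, assuming the client exposes the snapshot fields as attributes the way the existing assertion does:

```python
# Idefics2 now stops at its own end-of-utterance marker instead of running
# to the length limit and then emitting </s>.
assert response.details.finish_reason == "eos_token"
assert response.details.generated_tokens == 19
assert response.details.tokens[-1].special
assert response.details.tokens[-1].text == "<end_of_utterance>"
```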

View File

@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -16,7 +19,10 @@ def default_bloom():
revision = "main" revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors") filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision) download_weights(filenames, model_id, revision)
return BLOOMSharded(model_id) return BLOOMSharded(
model_id,
model_class=BloomForCausalLM,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_causal_lm(): def default_causal_lm():
return CausalLM("gpt2") return CausalLM.fallback("gpt2")
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -1,13 +1,12 @@
import pytest import pytest
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM
from text_generation_server.models.santacoder import SantaCoder
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_santacoder(): def default_santacoder():
return SantaCoder("bigcode/santacoder") return CausalLM.fallback(model_id="bigcode/santacoder")
@pytest.fixture @pytest.fixture

View File

@ -20,7 +20,7 @@ def mt0_small_tokenizer():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_seq2seq_lm(): def default_seq2seq_lm():
return Seq2SeqLM("bigscience/mt0-small") return Seq2SeqLM.fallback("bigscience/mt0-small")
@pytest.fixture @pytest.fixture

View File

@ -11,17 +11,26 @@ from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.mpt import MPTSharded from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.opt import OPTSharded from text_generation_server.models.custom_modeling.neox_modeling import (
from text_generation_server.models.galactica import GalacticaSharded GPTNeoxForCausalLM,
from text_generation_server.models.santacoder import SantaCoder )
from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.custom_modeling.phi_modeling import (
from text_generation_server.models.gpt_neox import GPTNeoxSharded PhiConfig,
from text_generation_server.models.phi import Phi PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
@ -41,9 +50,6 @@ __all__ = [
"CausalLM", "CausalLM",
"GalacticaSharded", "GalacticaSharded",
"Seq2SeqLM", "Seq2SeqLM",
"SantaCoder",
"OPTSharded",
"T5Sharded",
"get_model", "get_model",
] ]
@ -53,38 +59,65 @@ FLASH_ATTENTION = True
try: try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_rw import FlashRWSharded from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.flash_gpt2 import FlashGPT2 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
from text_generation_server.models.flash_neox import FlashNeoXSharded FlashLlamaForCausalLM,
from text_generation_server.models.flash_llama import (
FlashLlama,
) )
from text_generation_server.models.flash_qwen2 import ( from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashQwen2, FlashCohereForCausalLM,
) )
from text_generation_server.models.flash_cohere import ( from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashCohere, FlashGemmaForCausalLM,
) )
from text_generation_server.models.flash_gemma import ( from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma, FlashGemma2ForCausalLM,
) )
from text_generation_server.models.flash_gemma2 import ( from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashGemma2, FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
) )
from text_generation_server.models.pali_gemma import ( from text_generation_server.models.pali_gemma import (
PaliGemma, PaliGemmaBatch,
) )
from text_generation_server.models.flash_santacoder import ( from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
FlashSantacoderSharded, PaliGemmaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
) )
from text_generation_server.models.idefics import IDEFICSSharded from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.llava_next import LlavaNext from text_generation_server.models.custom_modeling.llava_next import (
from text_generation_server.models.idefics2 import Idefics2 LlavaNextForConditionalGeneration,
from text_generation_server.models.flash_mistral import FlashMistral )
from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 FlashSantacoderForCausalLM,
from text_generation_server.models.flash_dbrx import FlashDbrx )
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
FlashStarcoder2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
FlashMixtralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e: except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}") logger.warning(f"Could not import Flash Attention enabled models: {e}")
@ -93,21 +126,7 @@ except ImportError as e:
if FLASH_ATTENTION: if FLASH_ATTENTION:
__all__.append(FlashCausalLM) __all__.append(FlashCausalLM)
__all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama)
__all__.append(IDEFICSSharded) __all__.append(IDEFICSSharded)
__all__.append(FlashMistral)
__all__.append(FlashMixtral)
__all__.append(FlashDbrx)
__all__.append(FlashPhi)
__all__.append(FlashQwen2)
__all__.append(FlashStarcoder2)
__all__.append(FlashGemma)
__all__.append(FlashGemma2)
__all__.append(FlashCohere)
MAMBA_AVAILABLE = True MAMBA_AVAILABLE = True
try: try:
@ -148,6 +167,11 @@ class ModelType(enum.Enum):
"name": "Gemma", "name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b", "url": "https://huggingface.co/google/gemma-7b",
} }
PALIGEMMA = {
"type": "paligemma",
"name": "PaliGemma",
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
}
GEMMA2 = { GEMMA2 = {
"type": "gemma2", "type": "gemma2",
"name": "Gemma2", "name": "Gemma2",
@ -445,13 +469,16 @@ def get_model(
) )
if model_id.startswith("facebook/galactica"): if model_id.startswith("facebook/galactica"):
return GalacticaSharded( return CausalLM(
model_id, model_id=model_id,
revision, # Yes galactica is just an OPT model.
model_class=OPTForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=GalacticaCausalLMBatch,
) )
if ( if (
@ -460,22 +487,26 @@ def get_model(
and model_id.startswith("bigcode/") and model_id.startswith("bigcode/")
): ):
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashSantacoderSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashSantacoderForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
num_kv_heads=1,
) )
elif sharded: elif sharded:
raise NotImplementedError( raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
) )
else: else:
return SantaCoder( return CausalLM.fallback(
model_id, model_id=model_id,
revision, revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -483,38 +514,44 @@ def get_model(
) )
if model_type == BLOOM: if model_type == BLOOM:
return BLOOMSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=BloomForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=CausalLMBatchKeysLast,
) )
elif model_type == MPT: elif model_type == MPT:
return MPTSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=MPTForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=CausalLMBatchKeysLast,
) )
elif model_type == GPT2: elif model_type == GPT2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
try: try:
return FlashGPT2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGPT2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
except RuntimeError as e: except RuntimeError as e:
# Lots of legacy models with various weight names. # Lots of legacy models with various weight names.
logger.warning(f"Couldn't load flash gpt2 variant: {e}") logger.warning(f"Couldn't load flash gpt2 variant: {e}")
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -525,7 +562,7 @@ def get_model(
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -535,25 +572,28 @@ def get_model(
) )
elif model_type == GPT_NEOX: elif model_type == GPT_NEOX:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashNeoXSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGPTNeoXForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
return GPTNeoxSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=GPTNeoxForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -564,16 +604,18 @@ def get_model(
elif model_type == PHI: elif model_type == PHI:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashPhi( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashPhiForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -588,9 +630,11 @@ def get_model(
"Legacy phi-msft is not supported with Flash Attention" "Legacy phi-msft is not supported with Flash Attention"
) )
else: else:
return Phi( return CausalLM(
model_id, model_id=model_id,
revision, model_class=PhiForCausalLM,
config_class=PhiConfig,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -599,9 +643,10 @@ def get_model(
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashLlama( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -611,7 +656,7 @@ def get_model(
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -621,18 +666,22 @@ def get_model(
) )
if model_type == GEMMA: if model_type == GEMMA:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashGemma( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGemmaForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -642,18 +691,22 @@ def get_model(
) )
elif model_type == GEMMA2: elif model_type == GEMMA2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashGemma2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGemma2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -664,18 +717,20 @@ def get_model(
if model_type == COHERE: if model_type == COHERE:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashCohere( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashCohereForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -686,18 +741,23 @@ def get_model(
if model_type == DBRX: if model_type == DBRX:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashDbrx( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashDbrxForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Dbrx works better in bfloat16.
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DbrxConfig,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -711,27 +771,37 @@ def get_model(
if FLASH_ATTENTION: if FLASH_ATTENTION:
if config_dict.get("alibi", False): if config_dict.get("alibi", False):
raise NotImplementedError("sharded is not supported for this model") raise NotImplementedError("sharded is not supported for this model")
return FlashRWSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
) )
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
else: else:
if FLASH_ATTENTION and not config_dict.get("alibi", False): if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashRWSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
) )
else: else:
return RW( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -742,18 +812,20 @@ def get_model(
if model_type == MISTRAL: if model_type == MISTRAL:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashMistral( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashMistralForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -764,18 +836,20 @@ def get_model(
if model_type == MIXTRAL: if model_type == MIXTRAL:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashMixtral( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashMixtralForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -786,19 +860,22 @@ def get_model(
if model_type == STARCODER2: if model_type == STARCODER2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashStarcoder2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashStarcoder2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError( raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
) )
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -809,17 +886,20 @@ def get_model(
if model_type == QWEN2: if model_type == QWEN2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashQwen2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=Qwen2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -829,9 +909,10 @@ def get_model(
) )
if model_type == OPT: if model_type == OPT:
return OPTSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=OPTForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -839,13 +920,20 @@ def get_model(
) )
if model_type == T5: if model_type == T5:
return T5Sharded( return Seq2SeqLM(
model_id, model_id=model_id,
revision, model_class=T5ForConditionalGeneration,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
aliases={
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
},
) )
if model_type == IDEFICS: if model_type == IDEFICS:
if FLASH_ATTENTION: if FLASH_ATTENTION:
@ -861,34 +949,45 @@ def get_model(
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS2: if model_type == IDEFICS2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return Idefics2( return VlmCausalLM(
model_id, model_id=model_id,
revision, model_class=Idefics2ForConditionalGeneration,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "paligemma": if model_type == PALIGEMMA:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return PaliGemma( return VlmCausalLM(
model_id, model_id=model_id,
revision, model_class=PaliGemmaForConditionalGeneration,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == LLAVA_NEXT: if model_type == LLAVA_NEXT:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return LlavaNext( return VlmCausalLM(
model_id, model_class=LlavaNextForConditionalGeneration,
revision, model_id=model_id,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -912,7 +1011,7 @@ def get_model(
elif quantize == "exl2": elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel") raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -921,7 +1020,7 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM( return Seq2SeqLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -933,7 +1032,7 @@ def get_model(
auto_map = config_dict.get("auto_map", None) auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None: if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys(): if "AutoModelForCausalLM" in auto_map.keys():
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -942,7 +1041,7 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if "AutoModelForSeq2SeqLM" in auto_map.keys(): if "AutoModelForSeq2SeqLM" in auto_map.keys():
return Seq2SeqLM( return Seq2SeqLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
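
Taken together, the `get_model` hunks above reduce the per-architecture surface to a modeling file plus a dispatch branch; the duplicated tokenizer/weights/sharding boilerplate now lives in the generic classes. A hypothetical sketch of such a branch, written the way the existing ones read inside `get_model` (`NEWMODEL` and `FlashNewModelForCausalLM` are placeholders, not part of this PR):

```python
    elif model_type == NEWMODEL:  # hypothetical model-type constant
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashNewModelForCausalLM,  # hypothetical modeling class
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded NewModel"))
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
```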

View File

@ -4,22 +4,12 @@ import torch.distributed
from typing import Optional, Type from typing import Optional, Type
from transformers import ( from transformers import (
AutoTokenizer,
AutoConfig,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class BloomCausalLMBatch(CausalLMBatch): class BloomCausalLMBatch(CausalLMBatch):
@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch):
class BLOOMSharded(CausalLM): class BLOOMSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
slow_but_exact=False,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.pad_token_id = 3
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
prefix="transformer",
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = BloomForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property @property
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch return BloomCausalLMBatch
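
With the sharded constructor gone, `BLOOMSharded` is only a thin `CausalLM` subclass that pins the batch type, and the loading logic moves into the shared `CausalLM.__init__`. A sketch of the two call shapes this PR uses for BLOOM (the model id is illustrative, and the safetensors weights must already be available locally):

```python
from text_generation_server.models.bloom import BLOOMSharded, BloomCausalLMBatch
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.bloom_modeling import BloomForCausalLM

# Test-fixture path (test_bloom.py above): the subclass only overrides batch_type,
# the modeling class is forwarded to CausalLM.__init__.
bloom = BLOOMSharded("bigscience/bloom-560m", model_class=BloomForCausalLM)
assert bloom.batch_type is BloomCausalLMBatch

# get_model path (server __init__.py above): CausalLM is used directly and the
# keys-last batch layout is selected via batch_class.
bloom = CausalLM(
    model_id="bigscience/bloom-560m",
    model_class=BloomForCausalLM,
    batch_class=CausalLMBatchKeysLast,
)
```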

View File

@ -1,13 +1,25 @@
import torch import torch
import time import time
import torch.distributed
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedTokenizerBase,
)
from typing import Optional, Tuple, List, Type, Dict from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import ( from text_generation_server.models.types import (
Batch, Batch,
@ -478,10 +490,87 @@ class CausalLMBatch(Batch):
return len(self.requests) return len(self.requests)
@dataclass
class CausalLMBatchKeysLast(Batch):
keys_head_dim_last: bool = False
class CausalLM(Model): class CausalLM(Model):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
default_dtype=torch.float16,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
config_class=AutoConfig,
batch_class=CausalLMBatch,
):
self.batch_class = batch_class
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = config_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = model_class(config, weights)
torch.distributed.barrier(group=self.process_group)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@classmethod
def fallback(
cls,
model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
speculator: Optional[str] = None, speculator: Optional[str] = None,
@ -537,7 +626,12 @@ class CausalLM(Model):
else: else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__( self = cls.__new__(
cls,
)
self.batch_class = CausalLMBatch
super().__init__(
self,
model_id=model_id, model_id=model_id,
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -545,15 +639,11 @@ class CausalLM(Model):
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
return self
@property @property
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch return self.batch_class
def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
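
The new `CausalLM.__init__` above is the sharded safetensors loader (tokenizer, config, and `Weights` are assembled in-process), while the old transformers `AutoModelForCausalLM` path survives as the `fallback` classmethod that the conftests now call. A sketch of the two call shapes (model ids are illustrative; both still need the model files on disk):

```python
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)

# In-house loader: the caller picks the modeling class; batch_class defaults
# to CausalLMBatch and is returned through the batch_type property.
neox = CausalLM(
    model_id="EleutherAI/pythia-160m",  # illustrative gpt_neox checkpoint
    model_class=GPTNeoxForCausalLM,
)
assert neox.batch_type is CausalLMBatch

# Fallback: plain AutoModelForCausalLM, as used by the gpt2 test fixture above.
gpt2 = CausalLM.fallback("gpt2")
assert gpt2.batch_type is CausalLMBatch
```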

View File

@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):
class FlashGemma2ForCausalLM(torch.nn.Module): class FlashGemma2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix, config, weights, *, causal: bool = True):
super().__init__() super().__init__()
embed_norm = config.hidden_size**0.5 embed_norm = config.hidden_size**0.5

View File

@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):
class FlashGemmaForCausalLM(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix, config, weights, *, causal: bool = True):
super().__init__() super().__init__()
embed_norm = config.hidden_size**0.5 embed_norm = config.hidden_size**0.5
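
Both Gemma hunks above turn `causal` into a keyword-only argument with a default of `True`, which is what lets the generic loader call `model_class(prefix, config, weights)` for Gemma without special-casing (the `FlashCausalLM` constructor shown later does exactly that call). A tiny sketch of what the signature change means, checked via introspection rather than by instantiating a model:

```python
import inspect

from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
    FlashGemmaForCausalLM,
)

# Old: def __init__(self, prefix, config, weights, causal: bool)  -> required positional.
# New: def __init__(self, prefix, config, weights, *, causal: bool = True)
sig = inspect.signature(FlashGemmaForCausalLM.__init__)
assert sig.parameters["causal"].kind is inspect.Parameter.KEYWORD_ONLY
assert sig.parameters["causal"].default is True
```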

View File

@ -464,8 +464,9 @@ class FlashSantacoderModel(nn.Module):
class FlashSantacoderForCausalLM(nn.Module): class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
config.transpose = config.architectures[0].startswith("GPT2")
self.transformer = FlashSantacoderModel(config, weights) self.transformer = FlashSantacoderModel(config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights config, prefix="transformer.wte", weights=weights
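
`FlashSantacoderForCausalLM` now takes a `prefix` like the other flash modeling classes, and the `config.transpose` assignment is dropped from its constructor; the Santacoder-specific knowledge that used to live in `FlashSantacoderSharded` is instead passed in from `get_model`, as the bigcode branch above shows. A sketch of that call (arguments echo the diff; the model id is illustrative):

```python
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
    FlashSantacoderForCausalLM,
)

santacoder = FlashCausalLM(
    model_id="bigcode/santacoder",
    model_class=FlashSantacoderForCausalLM,
    # Multi-query attention and tied wte/lm_head weights are now told to the
    # generic loader explicitly instead of being hardcoded in a subclass.
    aliases={"transformer.wte.weight": ["lm_head.weight"]},
    num_kv_heads=1,
)
```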

View File

@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
self.config = config self.config = config
config.text_config.quantize = config.quantize config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator config.text_config.speculator = config.speculator
self.language_model = load_text_model( self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model", prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config, config=config.text_config,
weights=weights, weights=weights,
@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.LongTensor] = None, image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
): ):
inputs_embeds = self.language_model.embed_tokens(input_ids) inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0: if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum() # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
input_ids, inputs_embeds, image_features input_ids, inputs_embeds, image_features
) )
hidden_states = self.language_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module):
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states) logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits return logits, speculative_logits
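
The `language_model` to `text_model` rename above pairs with the adapter plumbing added to `FlashCausalLM` further down, which locates the inner language model by attribute name when mapping LoRA target layers. A condensed sketch of that lookup, copied from the `adapter_target_to_layer` hunk below:

```python
def resolve_text_model(model):
    # VLMs (e.g. LlavaNext, Idefics2) wrap a language model; plain LLMs are used as-is.
    if hasattr(model, "language_model"):
        return model.language_model
    if hasattr(model, "text_model"):
        return model.text_model
    return model
```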

View File

@ -10,7 +10,12 @@ import numpy as np
from loguru import logger from loguru import logger
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import PreTrainedTokenizerBase from transformers import (
PreTrainedTokenizerBase,
AutoConfig,
AutoTokenizer,
GenerationConfig,
)
from typing import Iterable, Optional, Tuple, List, Type, Dict from typing import Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
@ -21,6 +26,12 @@ from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK from text_generation_server.utils.dist import RANK
from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
from text_generation_server.models.types import ( from text_generation_server.models.types import (
Batch, Batch,
Tokens, Tokens,
@ -799,29 +810,110 @@ class FlashCausalLMBatch(Batch):
return len(self.requests) return len(self.requests)
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashCausalLM(Model): class FlashCausalLM(Model):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
model: torch.nn.Module, model_class,
tokenizer: PreTrainedTokenizerBase, revision: Optional[str] = None,
num_layers: int, quantize: Optional[str] = None,
num_kv_heads: int, speculator: Optional[str] = None,
head_size: int, dtype: Optional[torch.dtype] = None,
dtype: torch.dtype, trust_remote_code: bool = False,
device: torch.device, lora_adapter_ids: Optional[list] = [],
rank: int = 0, tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
world_size: int = 1, config_class: PreTrainedTokenizerBase = AutoConfig,
sliding_window: Optional[int] = None, default_dtype=torch.float16,
aliases=None,
# Used for Santacoder override of config
num_kv_heads=None,
skip_special_tokens: bool = True,
): ):
self.num_layers = num_layers self.process_group, rank, world_size = initialize_torch_distributed()
self.num_kv_heads = num_kv_heads if torch.cuda.is_available():
self.head_size = head_size device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError(f"{model_class} is only available on GPU")
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = config_class.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device, dtype, process_group=self.process_group, aliases=aliases
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
# VLM models define the config we care about in their text_config
text_config = getattr(config, "text_config", None)
if text_config is not None:
config = text_config
self.num_layers = config.num_hidden_layers
# Validation is done in the model itself
if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None)
if num_kv_heads is None:
# Final overide for GPT2
num_kv_heads = config.n_head
self.num_kv_heads = num_kv_heads // self.process_group.size()
self.head_size = config.hidden_size // config.num_attention_heads
self.cuda_graphs = {} self.cuda_graphs = {}
self.kv_cache = [] self.kv_cache = []
super(FlashCausalLM, self).__init__( super().__init__(
model_id=model_id, model_id=model_id,
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -830,7 +922,7 @@ class FlashCausalLM(Model):
device=device, device=device,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
sliding_window=sliding_window, sliding_window=config.sliding_window,
) )
@property @property
@ -1578,3 +1670,72 @@ class FlashCausalLM(Model):
forward_ns = start_decode - start forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns) return generations, batch, (forward_ns, decode_ns)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
# TODO: this is a hack to avoid the gate_proj for
# FlashStarcoder2 that doesnt have these layers
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
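
The constructor above now derives the attention geometry from the config rather than receiving it from per-model subclasses, which is where the `num_kv_heads` override for Santacoder and the `config.n_head` fix from the commit notes come in. A condensed sketch of that derivation, pulled out as a standalone function for readability (the real code does this inline and takes `world_size` from the process group):

```python
def attention_geometry(config, world_size: int, num_kv_heads=None):
    # VLM configs (LlavaNext, Idefics2, PaliGemma) nest the relevant values
    # under text_config.
    text_config = getattr(config, "text_config", None)
    if text_config is not None:
        config = text_config

    num_layers = config.num_hidden_layers

    # Explicit override first (Santacoder passes num_kv_heads=1 from get_model),
    # then num_key_value_heads, and finally the GPT-2 style n_head.
    if num_kv_heads is None:
        num_kv_heads = getattr(config, "num_key_value_heads", None)
    if num_kv_heads is None:
        num_kv_heads = config.n_head

    num_kv_heads = num_kv_heads // world_size
    head_size = config.hidden_size // config.num_attention_heads
    return num_layers, num_kv_heads, head_size
```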

View File

@ -1,75 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer, AutoConfig
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashCohere(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashCohere is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashCohereForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCohere, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,100 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashDbrx(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashDBRX is only available on GPU")
try:
tokenizer = GPT2TokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
except:
try:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
except:
# FIXME: change back to model id once the tokenizer.json is merged
tokenizer = GPT2TokenizerFast.from_pretrained(
"Xenova/dbrx-instruct-tokenizer",
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
config = DbrxConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashDbrxForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashDbrx, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)


@ -1,83 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGemma(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded
prefix = ""
model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)


@ -1,83 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import PretrainedConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGemma2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PretrainedConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded
prefix = ""
model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)


@ -1,82 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.gpt2 import GPT2Tokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGPT2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGPT2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashGPT2ForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashGPT2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)


@ -1,171 +0,0 @@
import os
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashLlama(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashLlama is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashLlamaForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
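
ADAPTER_LAYERS, ROW_PARALLEL and adapter_target_to_layer above are the bookkeeping the LoRA loading code consumes when deciding which projections to patch and how they are sharded. The self-contained sketch below illustrates only that selection step; `plan_adapter_targets` is a made-up name, not the repository's loader.

# Illustrative sketch of how the adapter metadata above is typically consumed.
ADAPTER_LAYERS = [
    "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


def plan_adapter_targets(num_layers: int, traced=("q_proj", "v_proj")):
    """Yield (layer_index, layer_type, row_parallel) for every module to patch."""
    for layer_type in traced:
        if layer_type not in ADAPTER_LAYERS and layer_type != "lm_head":
            raise ValueError(f"unsupported adapter target: {layer_type}")
        # lm_head is a single shared projection; everything else exists per layer.
        count = 1 if layer_type == "lm_head" else num_layers
        for i in range(count):
            yield i, layer_type, layer_type in ROW_PARALLEL


# The default traced layers for Llama-style models are the query and value projections.
print(list(plan_adapter_targets(num_layers=2)))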


@ -1,24 +1,7 @@
import torch import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Dict, List from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
MistralConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
ADAPTER_LAYERS = [ ADAPTER_LAYERS = [
@ -33,88 +16,7 @@ ADAPTER_LAYERS = [
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class BaseFlashMistral(FlashCausalLM): class FlashMistral(FlashCausalLM):
def __init__(
self,
model_cls,
model_id: str,
config_cls=AutoConfig,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashMistral is only available on GPU")
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = config_cls.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
# Set context windows
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_cls(prefix, config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.model.layers),
model.model.num_key_value_heads,
model.model.head_size,
)
@property @property
def supports_adapter_loading(self) -> bool: def supports_adapter_loading(self) -> bool:
return True return True
@ -126,9 +28,7 @@ class BaseFlashMistral(FlashCausalLM):
# This accounts for VLMs (e.g. LlavaNext, Idefics2) # This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model. # that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"): if hasattr(self.model, "text_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model _model = self.model.text_model
else: else:
_model = self.model _model = self.model
@ -183,25 +83,3 @@ class BaseFlashMistral(FlashCausalLM):
def is_row_parallel(self, layer_type: str) -> bool: def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL return layer_type in ROW_PARALLEL
class FlashMistral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMistral, self).__init__(
config_cls=MistralConfig,
model_cls=FlashMistralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
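
After this change the only Mistral-specific work left in the deleted BaseFlashMistral constructor, setting up the sliding attention window, moves into the shared FlashCausalLM path, so the subclass disappears entirely. Below is a minimal sketch of that step, reusing the set_sliding_window helper the old code imported from flash_causal_lm; the wrapper name `configure_sliding_window` is illustrative.

# Illustrative sketch; only set_sliding_window comes from the repository.
from text_generation_server.models.flash_causal_lm import set_sliding_window


def configure_sliding_window(config) -> None:
    # Mistral, Qwen2 and Starcoder2 declare a sliding window in their config;
    # other models leave it unset. Normalize both cases before loading weights.
    if getattr(config, "sliding_window", None) is not None:
        set_sliding_window(config.sliding_window)
    else:
        config.sliding_window = None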


@ -1,31 +0,0 @@
import torch
from typing import Optional
from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
MixtralConfig,
FlashMixtralForCausalLM,
)
class FlashMixtral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMixtral, self).__init__(
config_cls=MixtralConfig,
model_cls=FlashMixtralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)


@ -1,82 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashNeoXSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashNeoX is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashGPTNeoXForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashNeoXSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)


@ -1,111 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashPhi(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashPhi is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashPhiForCausalLM(config, weights)
if speculator:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (
Path(speculator).exists() and Path(speculator).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
speculator, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
speculator, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(speculator) / "config.json")
medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group)
super(FlashPhi, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)


@ -1,93 +0,0 @@
import math
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashQwen2(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashQwen2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
# Set context windows
if config.sliding_window is not None:
set_sliding_window(config.sliding_window)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = Qwen2ForCausalLM(config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)


@ -1,91 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashRWSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashRW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
)
config.quantize = quantize
config.speculator = speculator
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashRWSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)


@ -1,99 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
import json
import os
from huggingface_hub import hf_hub_download
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashSantacoderSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=True,
)
config.quantize = quantize
config.speculator = speculator
config.transpose = config.architectures[0].startswith("GPT2")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashSantacoderForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashSantacoderSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
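
num_kv_heads is hardcoded to 1 above because Santacoder uses multi-query attention: every query head shares a single key/value head, so the KV cache only needs one head per layer. The toy tensors below show the broadcast that makes this work; the shapes are illustrative and do not reflect TGI's actual paged-cache layout.

# Illustrative shapes only; not the server's KV-cache tensors.
import torch

batch, num_heads, head_size, seq = 1, 16, 64, 8
num_kv_heads = 1  # multi-query attention: one shared key/value head

q = torch.randn(batch, num_heads, 1, head_size)        # one query per head
k = torch.randn(batch, num_kv_heads, seq, head_size)   # single shared key head
v = torch.randn(batch, num_kv_heads, seq, head_size)   # single shared value head

# Broadcasting over the head dimension lets all 16 query heads attend to the
# same cached keys/values, so the cache stores a single head per layer.
scores = torch.softmax(q @ k.transpose(-1, -2) / head_size**0.5, dim=-1)
out = scores @ v
print(out.shape)  # torch.Size([1, 16, 1, 64])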


@ -1,84 +0,0 @@
import math
import torch
from typing import Optional
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
Starcoder2Config,
FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashStarcoder2 is only available on GPU")
tokenizer = GPT2TokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = Starcoder2Config.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
# Set context windows
if config.sliding_window is not None:
set_sliding_window(config.sliding_window)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashStarcoder2ForCausalLM(config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)


@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch):
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
max_tokens=max_tokens, max_tokens=max_tokens,
) )
class GalacticaSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = OPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return GalacticaCausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values


@ -1,89 +0,0 @@
import torch
import torch.distributed
from typing import Optional
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.neox_modeling import (
GPTNeoxForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class GPTNeoxSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = GPTNeoxForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values


@ -1,51 +0,0 @@
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class Idefics2(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
size={"longest_edge": 448, "shortest_edge": 378},
)
super().__init__(
model_cls=Idefics2ForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)


@ -1,46 +0,0 @@
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class LlavaNext(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
super().__init__(
model_cls=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)


@ -1,105 +0,0 @@
import torch
import torch.distributed
from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download
import json
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class MPTCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
class MPTSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
# If model_id is a local path, load the file directly
local_path = Path(model_id, "config.json")
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(filename, "r") as f:
config = json.load(f)
config = PretrainedConfig(**config)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
config.quantize = quantize
model = MPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return MPTCausalLMBatch
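
Setting keys_head_dim_last = False above tells the shared CausalLM batching code that MPT's cached keys do not keep the head dimension as their trailing axis (the cached values still do), which changes how the batch code pads and slices the past tensors. The toy example below shows that the two layouts are equivalent for computing attention scores; the shapes are illustrative.

# Illustrative example; shapes do not mirror the server's cache tensors.
import torch

batch, heads, seq, head_dim = 1, 8, 4, 64
q = torch.randn(batch, heads, 1, head_dim)

keys_head_dim_last = torch.randn(batch, heads, seq, head_dim)  # (..., seq, head_dim)
keys_seq_last = keys_head_dim_last.transpose(-1, -2)           # (..., head_dim, seq)

# Either layout produces the same attention scores; the flag only matters for
# the padding and slicing applied to cached tensors between batches.
scores_a = q @ keys_head_dim_last.transpose(-1, -2)
scores_b = q @ keys_seq_last
assert torch.allclose(scores_a, scores_b)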


@ -1,86 +0,0 @@
import torch
import torch.distributed
from typing import Optional
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class OPTSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = OPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values


@ -74,45 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch):
else: else:
image_inputs = None image_inputs = None
return batch_tokenized_inputs, image_inputs return batch_tokenized_inputs, image_inputs
class PaliGemma(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
super().__init__(
config_cls=AutoConfig,
model_cls=PaliGemmaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self):
return PaliGemmaBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)


@ -1,69 +0,0 @@
import torch
import torch.distributed
from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class Phi(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, _rank, _world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PhiConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
tokenizer.bos_token_id = config.bos_token_id
tokenizer.eos_token_id = config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
model = PhiForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)


@ -1,84 +0,0 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple
from text_generation_server.models import CausalLM
class RW(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if speculator:
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
# Model Forward
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values


@ -1,77 +0,0 @@
import torch
import torch.distributed
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation_server.models import CausalLM
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"
class SantaCoder(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
EOD,
FIM_PREFIX,
FIM_MIDDLE,
FIM_SUFFIX,
FIM_PAD,
],
"pad_token": EOD,
}
)
with device:
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
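
The special tokens registered above exist so clients can send fill-in-the-middle (FIM) prompts; the server itself only makes sure the tokenizer knows them and never strips them when decoding. The sketch below shows the usual prefix-suffix-middle prompt assembly; it is an illustration of the convention, not code from this repository.

# Illustrative FIM prompt construction (PSM ordering); not repository code.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"


def fim_prompt(prefix: str, suffix: str) -> str:
    # The model is expected to generate the missing middle after <fim-middle>.
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


print(fim_prompt("def add(a, b):\n    return ", "\n\nprint(add(1, 2))"))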


@ -1,11 +1,22 @@
import torch import torch
import torch.distributed
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
PreTrainedTokenizerBase,
AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model from text_generation_server.models import Model
@ -531,6 +542,80 @@ class Seq2SeqLM(Model):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
default_dtype=torch.float16,
trust_remote_code: bool = False,
config_class=AutoConfig,
tokenizer_class=AutoTokenizer,
aliases=None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
config = config_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases=aliases,
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = model_class(config, weights)
torch.distributed.barrier(group=self.process_group)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@classmethod
def fallback(
cls,
model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
speculator: Optional[str] = None, speculator: Optional[str] = None,
@ -574,7 +659,11 @@ class Seq2SeqLM(Model):
) )
tokenizer.bos_token_id = model.config.decoder_start_token_id tokenizer.bos_token_id = model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__( self = cls.__new__(
cls,
)
super().__init__(
self,
model_id=model_id, model_id=model_id,
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -582,16 +671,12 @@ class Seq2SeqLM(Model):
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
return self
@property @property
def batch_type(self) -> Type[Seq2SeqLMBatch]: def batch_type(self) -> Type[Seq2SeqLMBatch]:
return Seq2SeqLMBatch return Seq2SeqLMBatch
def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(
decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward( def forward(
self, self,
input_ids, input_ids,

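The fallback classmethod added to Seq2SeqLM above keeps the old non-sharded path available without duplicating the class: it allocates the instance with cls.__new__ so the new sharded __init__ never runs, then calls the base Model initializer directly on it. Below is a stripped-down sketch of the pattern with stand-in class names.

# Stand-in classes to illustrate the cls.__new__ + base-__init__ pattern above.
class Model:
    def __init__(self, model_id: str, requires_padding: bool):
        self.model_id = model_id
        self.requires_padding = requires_padding


class ShardedSeq2Seq(Model):
    def __init__(self, model_id: str):
        # Normal path: distributed setup, sharded weight loading, etc. (omitted).
        super().__init__(model_id, requires_padding=True)

    @classmethod
    def fallback(cls, model_id: str):
        # Skip ShardedSeq2Seq.__init__ entirely; run only the base initializer.
        self = cls.__new__(cls)
        super().__init__(self, model_id, requires_padding=True)
        return self


lm = ShardedSeq2Seq.fallback("some/seq2seq-model")
print(lm.model_id, lm.requires_padding)
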

@ -1,115 +0,0 @@
import torch
import torch.distributed
from typing import List, Optional, Tuple
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class T5Sharded(Seq2SeqLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
},
)
model = T5ForConditionalGeneration(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def forward(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask: Optional,
encoder_last_hidden_state: Optional,
past_key_values: Optional = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]:
# Model Forward
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_last_hidden_state,
past_key_values=past_key_values,
use_cache=True,
)
return (
outputs.logits,
speculative_logits,
outputs.encoder_last_hidden_state,
outputs.past_key_values,
)


@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.flash_causal_lm import (
from text_generation_server.models.flash_mistral import ( FlashCausalLMBatch,
BaseFlashMistral, FlashCausalLM,
) )
from transformers import AutoProcessor
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
return batch return batch
class VlmCausalLM(BaseFlashMistral): class VlmCausalLM(FlashCausalLM):
def __init__(
self,
model_id: str,
*,
processor_class=AutoProcessor,
processor_kwargs=None,
batch_class=VlmCausalLMBatch,
revision,
trust_remote_code: bool,
**kwargs,
):
if processor_kwargs is None:
processor_kwargs = {}
self.processor = processor_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
**processor_kwargs,
)
self.batch_class = batch_class
super().__init__(model_id=model_id, **kwargs)
@property @property
def batch_type(self) -> Type[VlmCausalLMBatch]: def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch return self.batch_class
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
def forward( def forward(
self, self,