Refactor dead code - Removing all `flash_xxx.py` files. (#2166)
* Refactor dead code.
* First working step.
* Remove a lot of duplicated code.
* More dead code.
* More cleanup.
* Fix Santacoder test.
* Fixing the simple tests.
* Fixing sharding.
* Fixes for VLM.
* Fixing santacoder (num_kv_heads hardcoded).
* Removing more dead code.
* Fixing `config.n_head`.
* Stopping earlier because of `<end_of_utterance>` in idefics2.
* Addresses comments.
* Removing the dead code.
* Fuse back mistral into FlashCausalLM.
* Finish removal.
* Fixing docs + causal_lm `batch_class`.
* Fixing docs + causal_lm.
* Add default to Gemma Causality.
* Default value for gemma/gemma2.
* Wrong default.
parent c6bcadf883
commit fb2f74e2b9
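For orientation, the pattern the diff below converges on: the per-model wrapper classes (`FlashMistral`, `FlashLlama`, `FlashSantacoderSharded`, `BLOOMSharded`, `SantaCoder`, ...) are deleted and `get_model` now instantiates the generic runtimes (`FlashCausalLM`, `CausalLM`, `Seq2SeqLM`, `VlmCausalLM`) directly, parameterized by a `model_class` plus a few optional knobs (`batch_class`, `config_class`, `default_dtype`, `aliases`, `num_kv_heads`), with `CausalLM.fallback(...)` / `Seq2SeqLM.fallback(...)` as the non-flash transformers path. The sketch below is illustrative only: the `load_mistral` helper and the explicit `flash_attention` argument are hypothetical stand-ins for the real `get_model`, which branches on the module-level `FLASH_ATTENTION` flag and handles many more model types.

```python
# Minimal sketch of the dispatch pattern after this refactor (Mistral branch only).
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
)


def load_mistral(model_id, revision, quantize, speculator, dtype,
                 trust_remote_code, lora_adapter_ids, flash_attention: bool):
    if flash_attention:
        # One generic runtime class, parameterized by the modeling class,
        # replaces the old per-model wrapper (FlashMistral).
        return FlashCausalLM(
            model_id=model_id,
            model_class=FlashMistralForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
    # Non-flash path: plain transformers AutoModel loading.
    return CausalLM.fallback(
        model_id,
        revision,
        quantize=quantize,
        speculator=speculator,
        dtype=dtype,
        trust_remote_code=trust_remote_code,
    )
```

Because the per-architecture behaviour now lives entirely in the modeling class and these keyword arguments, the dedicated `flash_xxx.py` wrapper files become dead code and are removed.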
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "2.1.1-dev0"
+    "version": "2.1.2-dev0"
   },
   "paths": {
     "/": {
@@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
 - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
 - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
 - [Gemma](https://huggingface.co/google/gemma-7b)
+- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
 - [Gemma2](https://huggingface.co/google/gemma2-9b)
 - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 - [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
@@ -1,130 +1,124 @@
 {
   "details": {
     "best_of_sequences": null,
-    "finish_reason": "length",
-    "generated_tokens": 20,
+    "finish_reason": "eos_token",
+    "generated_tokens": 19,
     "prefill": [],
     "seed": null,
     "tokens": [
       {
         "id": 415,
-        "logprob": -0.039886475,
+        "logprob": -0.03665161,
         "special": false,
         "text": " The"
       },
       {
         "id": 12072,
-        "logprob": -0.1430664,
+        "logprob": -0.13549805,
         "special": false,
         "text": " cow"
       },
       {
         "id": 349,
-        "logprob": -0.056488037,
+        "logprob": -0.05819702,
         "special": false,
         "text": " is"
       },
       {
         "id": 6328,
-        "logprob": -0.6855469,
+        "logprob": -0.6826172,
         "special": false,
         "text": " standing"
       },
       {
         "id": 356,
-        "logprob": -0.1685791,
+        "logprob": -0.1607666,
         "special": false,
         "text": " on"
       },
       {
         "id": 272,
-        "logprob": -0.50097656,
+        "logprob": -0.5073242,
         "special": false,
         "text": " the"
       },
       {
         "id": 10305,
-        "logprob": -0.017303467,
+        "logprob": -0.016418457,
         "special": false,
         "text": " beach"
       },
       {
         "id": 304,
-        "logprob": -1.3564453,
+        "logprob": -1.3916016,
         "special": false,
         "text": " and"
       },
       {
         "id": 272,
-        "logprob": -0.017868042,
+        "logprob": -0.020217896,
         "special": false,
         "text": " the"
       },
       {
         "id": 13088,
-        "logprob": -0.0027103424,
+        "logprob": -0.0028133392,
         "special": false,
         "text": " chicken"
       },
       {
         "id": 349,
-        "logprob": -0.003156662,
+        "logprob": -0.003145218,
         "special": false,
         "text": " is"
       },
       {
         "id": 6398,
-        "logprob": -0.37304688,
+        "logprob": -0.37060547,
         "special": false,
         "text": " sitting"
       },
       {
         "id": 356,
-        "logprob": -0.034576416,
+        "logprob": -0.034851074,
         "special": false,
         "text": " on"
       },
       {
         "id": 264,
-        "logprob": -0.29418945,
+        "logprob": -0.2878418,
         "special": false,
         "text": " a"
       },
       {
         "id": 17972,
-        "logprob": -0.042877197,
+        "logprob": -0.046051025,
         "special": false,
         "text": " pile"
       },
       {
         "id": 302,
-        "logprob": -0.00028443336,
+        "logprob": -0.00028848648,
         "special": false,
         "text": " of"
       },
       {
         "id": 2445,
-        "logprob": -0.023223877,
+        "logprob": -0.025772095,
         "special": false,
         "text": " money"
       },
       {
         "id": 28723,
-        "logprob": -0.018157959,
+        "logprob": -0.018127441,
         "special": false,
         "text": "."
       },
       {
         "id": 32002,
-        "logprob": -0.00018393993,
+        "logprob": -0.00019824505,
         "special": true,
         "text": "<end_of_utterance>"
-      },
-      {
-        "id": 2,
-        "logprob": -1.1920929e-07,
-        "special": true,
-        "text": "</s>"
       }
     ],
     "top_tokens": null
@@ -57,7 +57,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot)
         response.generated_text
         == " The cow is standing on the beach and the chicken is sitting on a pile of money."
     ), f"{repr(response.generated_text)}"
-    assert response.details.generated_tokens == 20
+    assert response.details.generated_tokens == 19
     assert response == response_snapshot
@@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import CausalLMBatch
 from text_generation_server.utils import weight_hub_files, download_weights
 from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
+from text_generation_server.models.custom_modeling.bloom_modeling import (
+    BloomForCausalLM,
+)
 
 
 @pytest.fixture(scope="session")
@@ -16,7 +19,10 @@ def default_bloom():
     revision = "main"
     filenames = weight_hub_files(model_id, revision, ".safetensors")
     download_weights(filenames, model_id, revision)
-    return BLOOMSharded(model_id)
+    return BLOOMSharded(
+        model_id,
+        model_class=BloomForCausalLM,
+    )
 
 
 @pytest.fixture(scope="session")
@@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
 
 @pytest.fixture(scope="session")
 def default_causal_lm():
-    return CausalLM("gpt2")
+    return CausalLM.fallback("gpt2")
 
 
 @pytest.fixture(scope="session")
@@ -1,13 +1,12 @@
 import pytest
 
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLMBatch
-from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM
 
 
 @pytest.fixture(scope="session")
 def default_santacoder():
-    return SantaCoder("bigcode/santacoder")
+    return CausalLM.fallback(model_id="bigcode/santacoder")
 
 
 @pytest.fixture
@@ -20,7 +20,7 @@ def mt0_small_tokenizer():
 
 @pytest.fixture(scope="session")
 def default_seq2seq_lm():
-    return Seq2SeqLM("bigscience/mt0-small")
+    return Seq2SeqLM.fallback("bigscience/mt0-small")
 
 
 @pytest.fixture
@@ -11,17 +11,26 @@ from pathlib import Path
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
-from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.bloom import BLOOMSharded
-from text_generation_server.models.mpt import MPTSharded
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
+from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
+from text_generation_server.models.custom_modeling.mpt_modeling import (
+    MPTForCausalLM,
+)
+from text_generation_server.models.custom_modeling.bloom_modeling import (
+    BloomForCausalLM,
+)
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
-from text_generation_server.models.rw import RW
-from text_generation_server.models.opt import OPTSharded
-from text_generation_server.models.galactica import GalacticaSharded
-from text_generation_server.models.santacoder import SantaCoder
-from text_generation_server.models.t5 import T5Sharded
-from text_generation_server.models.gpt_neox import GPTNeoxSharded
-from text_generation_server.models.phi import Phi
+from text_generation_server.models.galactica import GalacticaCausalLMBatch
+from text_generation_server.models.custom_modeling.neox_modeling import (
+    GPTNeoxForCausalLM,
+)
+from text_generation_server.models.custom_modeling.phi_modeling import (
+    PhiConfig,
+    PhiForCausalLM,
+)
+from text_generation_server.models.custom_modeling.t5_modeling import (
+    T5ForConditionalGeneration,
+)
 
 from text_generation_server.utils.import_utils import SYSTEM
 
@@ -41,9 +50,6 @@ __all__ = [
     "CausalLM",
     "GalacticaSharded",
     "Seq2SeqLM",
-    "SantaCoder",
-    "OPTSharded",
-    "T5Sharded",
     "get_model",
 ]
 
@@ -53,38 +59,65 @@ FLASH_ATTENTION = True
 
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
-    from text_generation_server.models.flash_rw import FlashRWSharded
-    from text_generation_server.models.flash_gpt2 import FlashGPT2
-    from text_generation_server.models.flash_neox import FlashNeoXSharded
-    from text_generation_server.models.flash_llama import (
-        FlashLlama,
+    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+        FlashLlamaForCausalLM,
     )
-    from text_generation_server.models.flash_qwen2 import (
-        FlashQwen2,
+    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
+        FlashCohereForCausalLM,
     )
-    from text_generation_server.models.flash_cohere import (
-        FlashCohere,
+    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+        FlashGemmaForCausalLM,
     )
-    from text_generation_server.models.flash_gemma import (
-        FlashGemma,
+    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
+        FlashGemma2ForCausalLM,
     )
-    from text_generation_server.models.flash_gemma2 import (
-        FlashGemma2,
+    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
+        FlashDbrxForCausalLM,
+        DbrxConfig,
+    )
+    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
+        RWConfig,
+        FlashRWForCausalLM,
+    )
+    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
+        FlashGPTNeoXForCausalLM,
     )
     from text_generation_server.models.pali_gemma import (
-        PaliGemma,
+        PaliGemmaBatch,
     )
-    from text_generation_server.models.flash_santacoder import (
-        FlashSantacoderSharded,
+    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
+        PaliGemmaForConditionalGeneration,
+    )
+    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
+        FlashPhiForCausalLM,
     )
     from text_generation_server.models.idefics import IDEFICSSharded
-    from text_generation_server.models.llava_next import LlavaNext
-    from text_generation_server.models.idefics2 import Idefics2
-    from text_generation_server.models.flash_mistral import FlashMistral
-    from text_generation_server.models.flash_mixtral import FlashMixtral
-    from text_generation_server.models.flash_phi import FlashPhi
-    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
-    from text_generation_server.models.flash_dbrx import FlashDbrx
+    from text_generation_server.models.custom_modeling.llava_next import (
+        LlavaNextForConditionalGeneration,
+    )
+
+    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
+        FlashSantacoderForCausalLM,
+    )
+    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
+        FlashStarcoder2ForCausalLM,
+    )
+    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
+        Qwen2ForCausalLM,
+    )
+    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
+        FlashMistralForCausalLM,
+    )
+    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
+        FlashMixtralForCausalLM,
+    )
+    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
+        FlashGPT2ForCausalLM,
+    )
+    from text_generation_server.models.custom_modeling.idefics2 import (
+        Idefics2ForConditionalGeneration,
+    )
     from text_generation_server.layers.attention import SUPPORTS_WINDOWING
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
@@ -93,21 +126,7 @@ except ImportError as e:
 
 if FLASH_ATTENTION:
     __all__.append(FlashCausalLM)
-    __all__.append(FlashGPT2)
-    __all__.append(FlashNeoXSharded)
-    __all__.append(FlashRWSharded)
-    __all__.append(FlashSantacoderSharded)
-    __all__.append(FlashLlama)
    __all__.append(IDEFICSSharded)
-    __all__.append(FlashMistral)
-    __all__.append(FlashMixtral)
-    __all__.append(FlashDbrx)
-    __all__.append(FlashPhi)
-    __all__.append(FlashQwen2)
-    __all__.append(FlashStarcoder2)
-    __all__.append(FlashGemma)
-    __all__.append(FlashGemma2)
-    __all__.append(FlashCohere)
 
 MAMBA_AVAILABLE = True
 try:
@@ -148,6 +167,11 @@ class ModelType(enum.Enum):
         "name": "Gemma",
         "url": "https://huggingface.co/google/gemma-7b",
     }
+    PALIGEMMA = {
+        "type": "paligemma",
+        "name": "PaliGemma",
+        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
+    }
     GEMMA2 = {
         "type": "gemma2",
         "name": "Gemma2",
@@ -445,13 +469,16 @@ def get_model(
         )
 
     if model_id.startswith("facebook/galactica"):
-        return GalacticaSharded(
-            model_id,
-            revision,
+        return CausalLM(
+            model_id=model_id,
+            # Yes galactica is just an OPT model.
+            model_class=OPTForCausalLM,
+            revision=revision,
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            batch_class=GalacticaCausalLMBatch,
         )
 
     if (
@@ -460,22 +487,26 @@ def get_model(
         and model_id.startswith("bigcode/")
     ):
         if FLASH_ATTENTION:
-            return FlashSantacoderSharded(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashSantacoderForCausalLM,
+                revision=revision,
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                aliases={"transformer.wte.weight": ["lm_head.weight"]},
+                num_kv_heads=1,
             )
         elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
         else:
-            return SantaCoder(
-                model_id,
-                revision,
+            return CausalLM.fallback(
+                model_id=model_id,
+                revision=revision,
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
@@ -483,38 +514,44 @@ def get_model(
             )
 
     if model_type == BLOOM:
-        return BLOOMSharded(
-            model_id,
-            revision,
+        return CausalLM(
+            model_id=model_id,
+            model_class=BloomForCausalLM,
+            revision=revision,
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            batch_class=CausalLMBatchKeysLast,
         )
     elif model_type == MPT:
-        return MPTSharded(
-            model_id,
-            revision,
+        return CausalLM(
+            model_id=model_id,
+            model_class=MPTForCausalLM,
+            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
+            batch_class=CausalLMBatchKeysLast,
        )
    elif model_type == GPT2:
        if FLASH_ATTENTION:
            try:
-                return FlashGPT2(
-                    model_id,
-                    revision,
+                return FlashCausalLM(
+                    model_id=model_id,
+                    model_class=FlashGPT2ForCausalLM,
+                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
+                    lora_adapter_ids=lora_adapter_ids,
                )
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
-                return CausalLM(
+                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
@@ -525,7 +562,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -535,25 +572,28 @@ def get_model(
            )
    elif model_type == GPT_NEOX:
        if FLASH_ATTENTION:
-            return FlashNeoXSharded(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashGPTNeoXForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
-            return GPTNeoxSharded(
-                model_id,
-                revision,
+            return CausalLM(
+                model_id=model_id,
+                model_class=GPTNeoxForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -564,16 +604,18 @@ def get_model(
 
    elif model_type == PHI:
        if FLASH_ATTENTION:
-            return FlashPhi(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashPhiForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -588,9 +630,11 @@ def get_model(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
-            return Phi(
-                model_id,
-                revision,
+            return CausalLM(
+                model_id=model_id,
+                model_class=PhiForCausalLM,
+                config_class=PhiConfig,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
@@ -599,9 +643,10 @@ def get_model(
 
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
        if FLASH_ATTENTION:
-            return FlashLlama(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashLlamaForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
@@ -611,7 +656,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -621,18 +666,22 @@ def get_model(
            )
    if model_type == GEMMA:
        if FLASH_ATTENTION:
-            return FlashGemma(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashGemmaForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
+                # Works better for these models
+                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -642,18 +691,22 @@ def get_model(
            )
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
-            return FlashGemma2(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashGemma2ForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
+                # Works better for these models
+                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -664,18 +717,20 @@ def get_model(
 
    if model_type == COHERE:
        if FLASH_ATTENTION:
-            return FlashCohere(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashCohereForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -686,18 +741,23 @@ def get_model(
 
    if model_type == DBRX:
        if FLASH_ATTENTION:
-            return FlashDbrx(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashDbrxForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
+                # Dbrx works better in bfloat16.
+                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                config_class=DbrxConfig,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -711,27 +771,37 @@ def get_model(
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
-                return FlashRWSharded(
-                    model_id,
-                    revision,
+                return FlashCausalLM(
+                    model_id=model_id,
+                    model_class=FlashRWForCausalLM,
+                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
+                    aliases={
+                        "lm_head.weight": ["transformer.word_embeddings.weight"],
+                        "transformer.word_embeddings.weight": ["lm_head.weight"],
+                    },
                    trust_remote_code=trust_remote_code,
+                    lora_adapter_ids=lora_adapter_ids,
+                    config_class=RWConfig,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
-                return FlashRWSharded(
-                    model_id,
-                    revision,
+                return FlashCausalLM(
+                    model_id=model_id,
+                    model_class=FlashRWForCausalLM,
+                    revision=revision,
                    quantize=quantize,
                    speculator=speculator,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
+                    lora_adapter_ids=lora_adapter_ids,
+                    config_class=RWConfig,
                )
            else:
-                return RW(
+                return CausalLM.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
@@ -742,18 +812,20 @@ def get_model(
 
    if model_type == MISTRAL:
        if FLASH_ATTENTION:
-            return FlashMistral(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashMistralForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -764,18 +836,20 @@ def get_model(
 
    if model_type == MIXTRAL:
        if FLASH_ATTENTION:
-            return FlashMixtral(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashMixtralForCausalLM,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -786,19 +860,22 @@ def get_model(
 
    if model_type == STARCODER2:
        if FLASH_ATTENTION:
-            return FlashStarcoder2(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashStarcoder2ForCausalLM,
+                revision=revision,
                quantize=quantize,
+                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -809,17 +886,20 @@ def get_model(
 
    if model_type == QWEN2:
        if FLASH_ATTENTION:
-            return FlashQwen2(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=Qwen2ForCausalLM,
+                revision=revision,
                quantize=quantize,
+                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -829,9 +909,10 @@ def get_model(
            )
 
    if model_type == OPT:
-        return OPTSharded(
-            model_id,
-            revision,
+        return CausalLM(
+            model_id=model_id,
+            model_class=OPTForCausalLM,
+            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
@@ -839,13 +920,20 @@ def get_model(
        )
 
    if model_type == T5:
-        return T5Sharded(
-            model_id,
-            revision,
+        return Seq2SeqLM(
+            model_id=model_id,
+            model_class=T5ForConditionalGeneration,
+            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
+            aliases={
+                "shared.weight": [
+                    "encoder.embed_tokens.weight",
+                    "decoder.embed_tokens.weight",
+                ]
+            },
        )
    if model_type == IDEFICS:
        if FLASH_ATTENTION:
@@ -861,34 +949,45 @@ def get_model(
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == IDEFICS2:
        if FLASH_ATTENTION:
-            return Idefics2(
-                model_id,
-                revision,
+            return VlmCausalLM(
+                model_id=model_id,
+                model_class=Idefics2ForConditionalGeneration,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                # XXX: Extremely important to cap resolution in order to limit
+                # VRAM usage.
+                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == "paligemma":
+    if model_type == PALIGEMMA:
        if FLASH_ATTENTION:
-            return PaliGemma(
-                model_id,
-                revision,
+            return VlmCausalLM(
+                model_id=model_id,
+                model_class=PaliGemmaForConditionalGeneration,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
+                # Works better for these models
+                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                batch_class=PaliGemmaBatch,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
    if model_type == LLAVA_NEXT:
        if FLASH_ATTENTION:
-            return LlavaNext(
-                model_id,
-                revision,
+            return VlmCausalLM(
+                model_class=LlavaNextForConditionalGeneration,
+                model_id=model_id,
+                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
@@ -912,7 +1011,7 @@ def get_model(
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(
+        return CausalLM.fallback(
            model_id,
            revision,
            quantize=quantize,
@@ -921,7 +1020,7 @@ def get_model(
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        return Seq2SeqLM(
+        return Seq2SeqLM.fallback(
            model_id,
            revision,
            quantize=quantize,
@@ -933,7 +1032,7 @@ def get_model(
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM(
+            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -942,7 +1041,7 @@ def get_model(
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
-            return Seq2SeqLM(
+            return Seq2SeqLM.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -4,22 +4,12 @@ import torch.distributed
 from typing import Optional, Type
 
 from transformers import (
-    AutoTokenizer,
-    AutoConfig,
     PreTrainedTokenizerBase,
 )
 
-from text_generation_server.models.custom_modeling.bloom_modeling import (
-    BloomForCausalLM,
-)
 from text_generation_server.models import CausalLM
 from text_generation_server.models.causal_lm import CausalLMBatch
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
 
 
 class BloomCausalLMBatch(CausalLMBatch):
@@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch):
 
 
 class BLOOMSharded(CausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
-            slow_but_exact=False,
-            tp_parallel=True,
-            trust_remote_code=trust_remote_code,
-        )
-        config.pad_token_id = 3
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(
-            filenames,
-            device=device,
-            dtype=dtype,
-            process_group=self.process_group,
-            prefix="transformer",
-        )
-        if config.quantize in ["gptq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = BloomForCausalLM(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
-
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return BloomCausalLMBatch
@@ -1,13 +1,25 @@
 import torch
 import time
+import torch.distributed
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    PreTrainedTokenizerBase,
+)
 from typing import Optional, Tuple, List, Type, Dict
 
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+)
 from text_generation_server.models import Model
 from text_generation_server.utils.chunks import concat_text_chunks
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
@@ -478,10 +490,87 @@ class CausalLMBatch(Batch):
         return len(self.requests)
 
 
+@dataclass
+class CausalLMBatchKeysLast(Batch):
+    keys_head_dim_last: bool = False
+
+
 class CausalLM(Model):
     def __init__(
         self,
         model_id: str,
+        model_class,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        default_dtype=torch.float16,
+        trust_remote_code: bool = False,
+        tokenizer_class=AutoTokenizer,
+        config_class=AutoConfig,
+        batch_class=CausalLMBatch,
+    ):
+        self.batch_class = batch_class
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = default_dtype if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
+        else:
+            device = torch.device("cpu")
+            dtype = torch.float32 if dtype is None else dtype
+
+        tokenizer = tokenizer_class.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+
+        config = config_class.from_pretrained(
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+        )
+        config.quantize = quantize
+        config.speculator = speculator
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token_id = config.pad_token_id
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(
+            filenames, device=device, dtype=dtype, process_group=self.process_group
+        )
+        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
+            weights._set_gptq_params(model_id, revision)
+
+        model = model_class(config, weights)
+
+        torch.distributed.barrier(group=self.process_group)
+        super().__init__(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            world_size=world_size,
+        )
+
+    @classmethod
+    def fallback(
+        cls,
+        model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
@@ -537,7 +626,12 @@ class CausalLM(Model):
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
-        super(CausalLM, self).__init__(
+        self = cls.__new__(
+            cls,
+        )
+        self.batch_class = CausalLMBatch
+        super().__init__(
+            self,
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
@@ -545,15 +639,11 @@ class CausalLM(Model):
            dtype=dtype,
            device=device,
        )
+        return self
 
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
-        return CausalLMBatch
+        return self.batch_class
 
-    def decode(self, generated_ids: List[int]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
@@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):
 
 
 class FlashGemma2ForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix, config, weights, *, causal: bool = True):
         super().__init__()
 
         embed_norm = config.hidden_size**0.5
@@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):
 
 
 class FlashGemmaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix, config, weights, *, causal: bool = True):
         super().__init__()
 
         embed_norm = config.hidden_size**0.5
@@ -464,8 +464,9 @@ class FlashSantacoderModel(nn.Module):
 
 
 class FlashSantacoderForCausalLM(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
+        config.transpose = config.architectures[0].startswith("GPT2")
         self.transformer = FlashSantacoderModel(config, weights)
         self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||||
self.config = config
|
self.config = config
|
||||||
config.text_config.quantize = config.quantize
|
config.text_config.quantize = config.quantize
|
||||||
config.text_config.speculator = config.speculator
|
config.text_config.speculator = config.speculator
|
||||||
self.language_model = load_text_model(
|
self.text_model = load_text_model(
|
||||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||||
config=config.text_config,
|
config=config.text_config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||||
image_sizes: Optional[torch.LongTensor] = None,
|
image_sizes: Optional[torch.LongTensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
inputs_embeds = self.language_model.embed_tokens(input_ids)
|
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||||
if pixel_values is not None and len(pixel_values) > 0:
|
if pixel_values is not None and len(pixel_values) > 0:
|
||||||
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||||
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||||
|
@@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
                 input_ids, inputs_embeds, image_features
             )

-        hidden_states = self.language_model.model(
+        hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
@@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits, speculative_logits = self.language_model.lm_head(hidden_states)
+        logits, speculative_logits = self.text_model.lm_head(hidden_states)
         return logits, speculative_logits
@@ -10,7 +10,12 @@ import numpy as np
 from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import (
+    PreTrainedTokenizerBase,
+    AutoConfig,
+    AutoTokenizer,
+    GenerationConfig,
+)
 from typing import Iterable, Optional, Tuple, List, Type, Dict

 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
@@ -21,6 +26,12 @@ from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.dist import RANK
 from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+    hub,
+)
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -799,29 +810,110 @@ class FlashCausalLMBatch(Batch):
         return len(self.requests)


+ADAPTER_LAYERS = [
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
+
+
 class FlashCausalLM(Model):
     def __init__(
         self,
         model_id: str,
-        model: torch.nn.Module,
-        tokenizer: PreTrainedTokenizerBase,
-        num_layers: int,
-        num_kv_heads: int,
-        head_size: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        rank: int = 0,
-        world_size: int = 1,
-        sliding_window: Optional[int] = None,
+        model_class,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+        lora_adapter_ids: Optional[list] = [],
+        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
+        config_class: PreTrainedTokenizerBase = AutoConfig,
+        default_dtype=torch.float16,
+        aliases=None,
+        # Used for Santacoder override of config
+        num_kv_heads=None,
+        skip_special_tokens: bool = True,
     ):
-        self.num_layers = num_layers
-        self.num_kv_heads = num_kv_heads
-        self.head_size = head_size
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = default_dtype if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
+        else:
+            raise NotImplementedError(f"{model_class} is only available on GPU")
+
+        tokenizer = tokenizer_class.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id, revision=revision, trust_remote_code=trust_remote_code
+            )
+            if isinstance(generation_config.eos_token_id, (list, set)):
+                # TODO Huge hack
+                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+        except Exception:
+            pass
+
+        config = config_class.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        config.quantize = quantize
+        config.speculator = speculator
+        if getattr(config, "sliding_window", None) is not None:
+            set_sliding_window(config.sliding_window)
+        else:
+            config.sliding_window = None
+
+        torch.distributed.barrier(group=self.process_group)
+
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(
+            filenames, device, dtype, process_group=self.process_group, aliases=aliases
+        )
+        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
+            weights._set_gptq_params(model_id, revision)
+
+        prefix = ""
+        model = model_class(prefix, config, weights)
+        torch.distributed.barrier(group=self.process_group)
+
+        # VLM models define the config we care about in their text_config
+        text_config = getattr(config, "text_config", None)
+        if text_config is not None:
+            config = text_config
+        self.num_layers = config.num_hidden_layers
+        # Validation is done in the model itself
+        if num_kv_heads is None:
+            num_kv_heads = getattr(config, "num_key_value_heads", None)
+            if num_kv_heads is None:
+                # Final overide for GPT2
+                num_kv_heads = config.n_head
+        self.num_kv_heads = num_kv_heads // self.process_group.size()
+        self.head_size = config.hidden_size // config.num_attention_heads

         self.cuda_graphs = {}
         self.kv_cache = []

-        super(FlashCausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
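With the constructor above, `FlashCausalLM` now owns tokenizer, config, and weight loading as well as model instantiation, so a model family only needs to supply its modeling class. A rough sketch of a call site under that assumption (illustrative only; the real dispatch lives elsewhere in the repo and may pass additional arguments):

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
)

# model_class replaces the old per-model flash_xxx.py wrapper; dtype is
# resolved from default_dtype when left as None.
model = FlashCausalLM(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",
    model_class=FlashLlamaForCausalLM,
)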
@@ -830,7 +922,7 @@ class FlashCausalLM(Model):
             device=device,
             rank=rank,
             world_size=world_size,
-            sliding_window=sliding_window,
+            sliding_window=config.sliding_window,
         )

     @property
|
@ -1578,3 +1670,72 @@ class FlashCausalLM(Model):
|
||||||
forward_ns = start_decode - start
|
forward_ns = start_decode - start
|
||||||
decode_ns = time.time_ns() - start_decode
|
decode_ns = time.time_ns() - start_decode
|
||||||
return generations, batch, (forward_ns, decode_ns)
|
return generations, batch, (forward_ns, decode_ns)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_adapter_loading(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
|
||||||
|
layer_weights = {}
|
||||||
|
|
||||||
|
prefix = "model.layers"
|
||||||
|
|
||||||
|
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
|
||||||
|
# that have a language_model inside of the larger model.
|
||||||
|
if hasattr(self.model, "language_model"):
|
||||||
|
_model = self.model.language_model
|
||||||
|
elif hasattr(self.model, "text_model"):
|
||||||
|
_model = self.model.text_model
|
||||||
|
else:
|
||||||
|
_model = self.model
|
||||||
|
|
||||||
|
for i, layer in enumerate(_model.model.layers):
|
||||||
|
layer_weights[(i, "q_proj")] = (
|
||||||
|
f"{prefix}.{i}.self_attn.q_proj",
|
||||||
|
layer.self_attn.query_key_value,
|
||||||
|
)
|
||||||
|
layer_weights[(i, "k_proj")] = (
|
||||||
|
f"{prefix}.{i}.self_attn.k_proj",
|
||||||
|
layer.self_attn.query_key_value,
|
||||||
|
)
|
||||||
|
layer_weights[(i, "v_proj")] = (
|
||||||
|
f"{prefix}.{i}.self_attn.v_proj",
|
||||||
|
layer.self_attn.query_key_value,
|
||||||
|
)
|
||||||
|
layer_weights[(i, "o_proj")] = (
|
||||||
|
f"{prefix}.{i}.self_attn.o_proj",
|
||||||
|
layer.self_attn.o_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: this is a hack to avoid the gate_proj for
|
||||||
|
# FlashStarcoder2 that doesnt have these layers
|
||||||
|
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
|
||||||
|
layer_weights[(i, "gate_proj")] = (
|
||||||
|
f"{prefix}.{i}.mlp.gate_proj",
|
||||||
|
layer.mlp.gate_up_proj,
|
||||||
|
)
|
||||||
|
layer_weights[(i, "up_proj")] = (
|
||||||
|
f"{prefix}.{i}.mlp.up_proj",
|
||||||
|
layer.mlp.gate_up_proj,
|
||||||
|
)
|
||||||
|
layer_weights[(i, "down_proj")] = (
|
||||||
|
f"{prefix}.{i}.mlp.down_proj",
|
||||||
|
layer.mlp.down_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
|
||||||
|
return layer_weights
|
||||||
|
|
||||||
|
@property
|
||||||
|
def adapter_layers(self) -> List[str]:
|
||||||
|
return ADAPTER_LAYERS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_traced_adapter_layers(self) -> List[str]:
|
||||||
|
return ["q_proj", "v_proj"]
|
||||||
|
|
||||||
|
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||||
|
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
|
||||||
|
|
||||||
|
def is_row_parallel(self, layer_type: str) -> bool:
|
||||||
|
return layer_type in ROW_PARALLEL
|
||||||
|
|
|
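The adapter helpers added above key `layer_weights` on `(layer_index, proj_name)` and map each entry to `(weight_name, target_module)`, which is what the LoRA loading path consumes. A hypothetical inspection loop, assuming `model` is a loaded `FlashCausalLM` (illustrative only, not part of this diff):

# Print which modules each traced adapter projection targets.
layer_weights = model.adapter_target_to_layer()
for (layer_idx, proj), (weight_name, module) in sorted(layer_weights.items()):
    if proj in model.default_traced_adapter_layers:
        print(layer_idx, proj, weight_name, type(module).__name__)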
@@ -1,75 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from typing import Optional
-from transformers import AutoTokenizer, AutoConfig
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
-    FlashCohereForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashCohere(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashCohere is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-            use_fast=True,
-            from_slow=False,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = FlashCohereForCausalLM(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashCohere, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@@ -1,100 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from typing import Optional
-from transformers import AutoTokenizer
-from transformers.models.gpt2 import GPT2TokenizerFast
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
-    FlashDbrxForCausalLM,
-    DbrxConfig,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashDbrx(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashDBRX is only available on GPU")
-
-        try:
-            tokenizer = GPT2TokenizerFast.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-                use_fast=True,
-                from_slow=False,
-            )
-        except:
-            try:
-                tokenizer = AutoTokenizer.from_pretrained(
-                    model_id,
-                    revision=revision,
-                    padding_side="left",
-                    truncation_side="left",
-                    trust_remote_code=trust_remote_code,
-                    use_fast=True,
-                    from_slow=False,
-                )
-            except:
-                # FIXME: change back to model id once the tokenizer.json is merged
-                tokenizer = GPT2TokenizerFast.from_pretrained(
-                    "Xenova/dbrx-instruct-tokenizer",
-                    revision=revision,
-                    padding_side="left",
-                    truncation_side="left",
-                    trust_remote_code=trust_remote_code,
-                    use_fast=True,
-                    from_slow=False,
-                )
-
-        config = DbrxConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = FlashDbrxForCausalLM(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashDbrx, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@@ -1,83 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from typing import Optional
-from transformers import AutoConfig, AutoTokenizer
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
-    FlashGemmaForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashGemma(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashGemma is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        # TODO hardcoded
-        prefix = ""
-        model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashGemma, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@@ -1,83 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from typing import Optional
-from transformers import PretrainedConfig, AutoTokenizer
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
-    FlashGemma2ForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashGemma2(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashGemma2 is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = PretrainedConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        # TODO hardcoded
-        prefix = ""
-        model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashGemma2, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@@ -1,82 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer, GenerationConfig
-from transformers.models.gpt2 import GPT2Tokenizer
-from typing import Optional
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
-    FlashGPT2ForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashGPT2(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashGPT2 is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        prefix = ""
-        model = FlashGPT2ForCausalLM(prefix, config, weights)
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashGPT2, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@@ -1,171 +0,0 @@
-import os
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer, GenerationConfig
-from typing import Optional, Tuple, Dict, List
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_llama_modeling import (
-    FlashLlamaForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-    hub,
-)
-
-tracer = trace.get_tracer(__name__)
-
-from text_generation_server.utils.import_utils import SYSTEM
-
-ADAPTER_LAYERS = [
-    "q_proj",
-    "k_proj",
-    "v_proj",
-    "o_proj",
-    "gate_proj",
-    "up_proj",
-    "down_proj",
-]
-ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
-
-
-class FlashLlama(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-        lora_adapter_ids: Optional[list] = [],
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashLlama is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        try:
-            generation_config = GenerationConfig.from_pretrained(
-                model_id, revision=revision, trust_remote_code=trust_remote_code
-            )
-            if isinstance(generation_config.eos_token_id, (list, set)):
-                # TODO Huge hack
-                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
-        except Exception:
-            pass
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        prefix = ""
-        model = FlashLlamaForCausalLM(prefix, config, weights)
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashLlama, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
-
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return True
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        layer_weights = {}
-
-        prefix = "model.layers"
-
-        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
-        # that have a language_model inside of the larger model.
-        if hasattr(self.model, "language_model"):
-            _model = self.model.language_model
-        elif hasattr(self.model, "text_model"):
-            _model = self.model.text_model
-        else:
-            _model = self.model
-
-        for i, layer in enumerate(_model.model.layers):
-            layer_weights[(i, "q_proj")] = (
-                f"{prefix}.{i}.self_attn.q_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "k_proj")] = (
-                f"{prefix}.{i}.self_attn.k_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "v_proj")] = (
-                f"{prefix}.{i}.self_attn.v_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "o_proj")] = (
-                f"{prefix}.{i}.self_attn.o_proj",
-                layer.self_attn.o_proj,
-            )
-
-            layer_weights[(i, "gate_proj")] = (
-                f"{prefix}.{i}.mlp.gate_proj",
-                layer.mlp.gate_up_proj,
-            )
-            layer_weights[(i, "up_proj")] = (
-                f"{prefix}.{i}.mlp.up_proj",
-                layer.mlp.gate_up_proj,
-            )
-            layer_weights[(i, "down_proj")] = (
-                f"{prefix}.{i}.mlp.down_proj",
-                layer.mlp.down_proj,
-            )
-
-        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
-        return layer_weights
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return ADAPTER_LAYERS
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return ["q_proj", "v_proj"]
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return layer_type in ROW_PARALLEL
@@ -1,24 +1,7 @@
 import torch
-import torch.distributed

-from opentelemetry import trace
-from transformers import AutoTokenizer, AutoConfig
 from typing import Optional, Tuple, Dict, List

 from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.flash_causal_lm import set_sliding_window
-from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
-    FlashMistralForCausalLM,
-    MistralConfig,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
 ADAPTER_LAYERS = [
@@ -33,88 +16,7 @@ ADAPTER_LAYERS = [
 ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


-class BaseFlashMistral(FlashCausalLM):
+class FlashMistral(FlashCausalLM):
-    def __init__(
-        self,
-        model_cls,
-        model_id: str,
-        config_cls=AutoConfig,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-        tokenizer_class=AutoTokenizer,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashMistral is only available on GPU")
-
-        tokenizer = tokenizer_class.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = config_cls.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        # Set context windows
-        if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(config.sliding_window)
-        else:
-            config.sliding_window = None
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        prefix = ""
-        model = model_cls(prefix, config, weights)
-
-        self.cuda_graphs = {}
-
-        torch.distributed.barrier(group=self.process_group)
-        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
-        super().__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=num_layers,
-            num_kv_heads=num_kv_heads,
-            head_size=head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-            sliding_window=config.sliding_window,
-        )
-
-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.model.layers),
-            model.model.num_key_value_heads,
-            model.model.head_size,
-        )
-
     @property
     def supports_adapter_loading(self) -> bool:
         return True
@@ -126,9 +28,7 @@ class BaseFlashMistral(FlashCausalLM):

         # This accounts for VLMs (e.g. LlavaNext, Idefics2)
         # that have a language_model inside of the larger model.
-        if hasattr(self.model, "language_model"):
+        if hasattr(self.model, "text_model"):
-            _model = self.model.language_model
-        elif hasattr(self.model, "text_model"):
             _model = self.model.text_model
         else:
             _model = self.model
@@ -183,25 +83,3 @@ class BaseFlashMistral(FlashCausalLM):

     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
-
-
-class FlashMistral(BaseFlashMistral):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(FlashMistral, self).__init__(
-            config_cls=MistralConfig,
-            model_cls=FlashMistralForCausalLM,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
|
||||||
import torch
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from text_generation_server.models.flash_mistral import BaseFlashMistral
|
|
||||||
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
|
|
||||||
MixtralConfig,
|
|
||||||
FlashMixtralForCausalLM,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashMixtral(BaseFlashMistral):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
super(FlashMixtral, self).__init__(
|
|
||||||
config_cls=MixtralConfig,
|
|
||||||
model_cls=FlashMixtralForCausalLM,
|
|
||||||
model_id=model_id,
|
|
||||||
revision=revision,
|
|
||||||
quantize=quantize,
|
|
||||||
speculator=speculator,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
|
@ -1,82 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from transformers import AutoTokenizer, AutoConfig
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
|
||||||
FlashGPTNeoXForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashNeoXSharded(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashNeoX is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(
|
|
||||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
|
||||||
)
|
|
||||||
if config.quantize in ["gptq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = FlashGPTNeoXForCausalLM(config, weights)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashNeoXSharded, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model.to(device),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.gpt_neox.layers),
|
|
||||||
num_kv_heads=model.gpt_neox.num_heads,
|
|
||||||
head_size=model.gpt_neox.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
|
@@ -1,111 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer
-from typing import Optional
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_phi_modeling import (
-    FlashPhiForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashPhi(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashPhi is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = FlashPhiForCausalLM(config, weights)
-        if speculator:
-            from text_generation_server.utils.medusa import MedusaModel
-            from huggingface_hub import hf_hub_download
-            import json
-            import os
-            from pathlib import Path
-
-            is_local_model = (
-                Path(speculator).exists() and Path(speculator).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    speculator, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    speculator, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(speculator) / "config.json")
-                medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
-            weights = Weights(
-                [medusa_sf], device, dtype, process_group=self.process_group
-            )
-            lm_head = model.lm_head
-            model.lm_head = MedusaModel(config, weights, lm_head)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashPhi, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from transformers import AutoTokenizer, AutoConfig
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from text_generation_server.models.flash_mistral import (
|
|
||||||
BaseFlashMistral,
|
|
||||||
set_sliding_window,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
|
||||||
Qwen2ForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashQwen2(BaseFlashMistral):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashQwen2 is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
|
|
||||||
# Set context windows
|
|
||||||
if config.sliding_window is not None:
|
|
||||||
set_sliding_window(config.sliding_window)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
|
||||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = Qwen2ForCausalLM(config, weights)
|
|
||||||
|
|
||||||
self.cuda_graphs = {}
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(BaseFlashMistral, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.model.layers),
|
|
||||||
num_kv_heads=model.model.num_key_value_heads,
|
|
||||||
head_size=model.model.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
sliding_window=config.sliding_window,
|
|
||||||
)
|
|
|
@ -1,91 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
|
||||||
RWConfig,
|
|
||||||
FlashRWForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashRWSharded(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashRW is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = RWConfig.from_pretrained(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(
|
|
||||||
filenames,
|
|
||||||
device,
|
|
||||||
dtype,
|
|
||||||
process_group=self.process_group,
|
|
||||||
aliases={
|
|
||||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
|
||||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
if config.quantize in ["gptq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = FlashRWForCausalLM(config, weights)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashRWSharded, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model.to(device),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.transformer.h),
|
|
||||||
num_kv_heads=model.transformer.cache_size,
|
|
||||||
head_size=model.transformer.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
|
@ -1,99 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from transformers import AutoTokenizer, AutoConfig
|
|
||||||
from typing import Optional, List
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
|
||||||
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
|
||||||
FlashSantacoderForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashSantacoderSharded(FlashCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
config.transpose = config.architectures[0].startswith("GPT2")
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(
|
|
||||||
filenames,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
process_group=self.process_group,
|
|
||||||
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
|
||||||
)
|
|
||||||
if config.quantize in ["gptq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = FlashSantacoderForCausalLM(config, weights)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(FlashSantacoderSharded, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model.to(device),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_layers=len(model.transformer.h),
|
|
||||||
num_kv_heads=1,
|
|
||||||
head_size=model.transformer.head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
def decode(self, generated_ids: List[int]) -> str:
|
|
||||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
|
||||||
return self.tokenizer.decode(
|
|
||||||
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
|
||||||
)
|
|
|

@@ -1,84 +0,0 @@
import math

import torch

from typing import Optional

from transformers.models.gpt2 import GPT2TokenizerFast

from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
    set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
    Starcoder2Config,
    FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashStarcoder2 is only available on GPU")

        tokenizer = GPT2TokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = Starcoder2Config.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        # Set context windows
        if config.sliding_window is not None:
            set_sliding_window(config.sliding_window)

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashStarcoder2ForCausalLM(config, weights)

        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        super(BaseFlashMistral, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )
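
For context on the `set_sliding_window` call above: it only records `config.sliding_window` so the attention kernels can cap how far back each token may look. A minimal sketch of the mask this implies (the helper below is illustrative, not part of the server):

# Illustrative sketch: the attention pattern a sliding window of size W allows.
import torch


def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where query position i may attend to key position j:
    # causal (j <= i) and within the last `window` positions (i - j < window).
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (i - j < window)


print(sliding_window_mask(seq_len=6, window=3).int())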

@@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch):
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )


class GalacticaSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        tokenizer.pad_token_id = config.pad_token_id
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = OPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, speculative_logits, outputs.past_key_values

@@ -1,89 +0,0 @@
import torch
import torch.distributed

from typing import Optional

from transformers import (
    AutoTokenizer,
    AutoConfig,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class GPTNeoxSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = GPTNeoxForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return outputs.logits, speculative_logits, outputs.past_key_values

@@ -1,51 +0,0 @@
import torch

from typing import Optional, Tuple

from transformers import (
    AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
    Idefics2ForConditionalGeneration,
)

from text_generation_server.models.vlm_causal_lm import VlmCausalLM


class Idefics2(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            # XXX: Extremely important to cap resolution in order to limit
            # VRAM usage.
            size={"longest_edge": 448, "shortest_edge": 378},
        )
        super().__init__(
            model_cls=Idefics2ForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)

@@ -1,46 +0,0 @@
import torch

from typing import Optional, Tuple

from transformers import (
    AutoProcessor,
)
from text_generation_server.models.custom_modeling.llava_next import (
    LlavaNextForConditionalGeneration,
)

from text_generation_server.models.vlm_causal_lm import VlmCausalLM


class LlavaNext(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        super().__init__(
            model_cls=LlavaNextForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.language_model.model.layers),
            model.language_model.model.num_key_value_heads,
            model.language_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.language_model, "max_past", None)

@@ -1,105 +0,0 @@
import torch
import torch.distributed

from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download
import json

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class MPTCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
        batch.keys_head_dim_last = False
        return batch


class MPTSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token = tokenizer.eos_token

        # If model_id is a local path, load the file directly
        local_path = Path(model_id, "config.json")
        if local_path.exists():
            filename = str(local_path.resolve())
        else:
            filename = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
        with open(filename, "r") as f:
            config = json.load(f)
        config = PretrainedConfig(**config)
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        config.quantize = quantize
        model = MPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return MPTCausalLMBatch
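
The `keys_head_dim_last = False` flag above exists because MPT caches keys with the head dimension ahead of the sequence dimension. A small sketch of the two layouts (shapes chosen arbitrarily for illustration):

# Illustrative sketch: both key layouts give identical attention scores.
import torch

bsz, heads, seq, head_dim = 1, 2, 4, 8
query = torch.randn(bsz, heads, seq, head_dim)

key_head_dim_last = torch.randn(bsz, heads, seq, head_dim)  # default layout
key_head_dim_first = key_head_dim_last.transpose(-1, -2)    # MPT-style layout

scores_default = query @ key_head_dim_last.transpose(-1, -2)
scores_mpt = query @ key_head_dim_first
assert torch.allclose(scores_default, scores_mpt)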

@@ -1,86 +0,0 @@
import torch
import torch.distributed

from typing import Optional

from transformers import (
    AutoTokenizer,
    AutoConfig,
)
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class OPTSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator
        tokenizer.pad_token_id = config.pad_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = OPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return outputs.logits, speculative_logits, outputs.past_key_values

@@ -74,45 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch):
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs


class PaliGemma(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        super().__init__(
            config_cls=AutoConfig,
            model_cls=PaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self):
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)

@@ -1,69 +0,0 @@
import torch
import torch.distributed

from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple

from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class Phi(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, _rank, _world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        config = PhiConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        tokenizer.bos_token_id = config.bos_token_id
        tokenizer.eos_token_id = config.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token

        config.quantize = quantize
        config.speculator = speculator
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        model = PhiForCausalLM(config, weights)
        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

@@ -1,84 +0,0 @@
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple

from text_generation_server.models import CausalLM


class RW(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if speculator:
            raise RuntimeError("Medusa decoding is not enabled for AutoModel")

        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=(
                "auto"
                if torch.cuda.is_available() and torch.cuda.device_count() > 1
                else None
            ),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        # Model Forward
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return outputs.logits, speculative_logits, outputs.past_key_values

@@ -1,77 +0,0 @@
import torch
import torch.distributed

from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM

from text_generation_server.models import CausalLM

FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"


class SantaCoder(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    EOD,
                    FIM_PREFIX,
                    FIM_MIDDLE,
                    FIM_SUFFIX,
                    FIM_PAD,
                ],
                "pad_token": EOD,
            }
        )
        with device:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize == "bitsandbytes",
                trust_remote_code=trust_remote_code,
            )

        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
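
The FIM_* sentinels above are also why `decode` keeps special tokens: clients build fill-in-the-middle prompts out of them and parse them back from the output. A rough sketch of that prompt layout (the helper name and example strings are assumptions for illustration):

# Illustrative sketch of a prefix/suffix/middle prompt built from the FIM sentinels.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"


def build_fim_prompt(prefix: str, suffix: str) -> str:
    # The model is expected to generate the missing "middle" after the last sentinel,
    # so the sentinels must survive decoding (skip_special_tokens=False above).
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


print(build_fim_prompt("def add(a, b):\n    ", "\n    return result\n"))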

@@ -1,11 +1,22 @@
 import torch
+import torch.distributed
 import time

 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM,
+    PreTrainedTokenizerBase,
+    AutoConfig,
+)
 from typing import Optional, Tuple, List, Type, Dict

+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+)
 from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
@@ -531,6 +542,80 @@ class Seq2SeqLM(Model):
     def __init__(
         self,
         model_id: str,
+        model_class,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        default_dtype=torch.float16,
+        trust_remote_code: bool = False,
+        config_class=AutoConfig,
+        tokenizer_class=AutoTokenizer,
+        aliases=None,
+    ):
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = default_dtype if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
+        else:
+            device = torch.device("cpu")
+            dtype = torch.float32 if dtype is None else dtype
+
+        config = config_class.from_pretrained(
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+        )
+        config.quantize = quantize
+        config.speculator = speculator
+
+        tokenizer = tokenizer_class.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        tokenizer.bos_token_id = config.decoder_start_token_id
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            aliases=aliases,
+        )
+        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
+            weights._set_gptq_params(model_id, revision)
+
+        model = model_class(config, weights)
+
+        torch.distributed.barrier(group=self.process_group)
+        super().__init__(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            world_size=world_size,
+        )
+
+    @classmethod
+    def fallback(
+        cls,
+        model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
@@ -574,7 +659,11 @@ class Seq2SeqLM(Model):
         )
         tokenizer.bos_token_id = model.config.decoder_start_token_id

-        super(Seq2SeqLM, self).__init__(
+        self = cls.__new__(
+            cls,
+        )
+        super().__init__(
+            self,
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
@@ -582,16 +671,12 @@ class Seq2SeqLM(Model):
             dtype=dtype,
             device=device,
         )
+        return self

     @property
     def batch_type(self) -> Type[Seq2SeqLMBatch]:
         return Seq2SeqLMBatch

-    def decode(self, decoder_ids: List[int]) -> str:
-        return self.tokenizer.decode(
-            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
     def forward(
         self,
         input_ids,
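
With the generalized `__init__` above, a sharded seq2seq model no longer needs a dedicated wrapper class; roughly, the caller passes the modeling class and weight aliases directly. A sketch of how the `T5Sharded` wrapper removed below would be expressed (the model id is an example, and the call assumes it runs inside the server's sharded launch where torch.distributed is already set up):

# Sketch only: the deleted T5Sharded wrapper expressed through the generic Seq2SeqLM.__init__.
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

model = Seq2SeqLM(
    model_id="google/flan-t5-xxl",  # example id, not prescribed by this PR
    model_class=T5ForConditionalGeneration,
    aliases={
        "shared.weight": [
            "encoder.embed_tokens.weight",
            "decoder.embed_tokens.weight",
        ]
    },
)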

@@ -1,115 +0,0 @@
import torch
import torch.distributed

from typing import List, Optional, Tuple

from transformers import (
    AutoTokenizer,
    AutoConfig,
)

from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class T5Sharded(Seq2SeqLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )

        model = T5ForConditionalGeneration(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(Seq2SeqLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return (
            outputs.logits,
            speculative_logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

@@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
-from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
+from text_generation_server.models.flash_causal_lm import (
+    FlashCausalLMBatch,
+    FlashCausalLM,
 )
+from transformers import AutoProcessor

 tracer = trace.get_tracer(__name__)

@@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch


-class VlmCausalLM(BaseFlashMistral):
+class VlmCausalLM(FlashCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        *,
+        processor_class=AutoProcessor,
+        processor_kwargs=None,
+        batch_class=VlmCausalLMBatch,
+        revision,
+        trust_remote_code: bool,
+        **kwargs,
+    ):
+        if processor_kwargs is None:
+            processor_kwargs = {}
+        self.processor = processor_class.from_pretrained(
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            **processor_kwargs,
+        )
+        self.batch_class = batch_class
+        super().__init__(model_id=model_id, **kwargs)
+
     @property
     def batch_type(self) -> Type[VlmCausalLMBatch]:
-        return VlmCausalLMBatch
+        return self.batch_class
+
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.text_model, "max_past", None)

     def forward(
         self,
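
After this change the per-model wrappers deleted above (Idefics2, LlavaNext, PaliGemma) reduce to direct `VlmCausalLM` calls. A sketch of how Idefics2 might now be wired up; the `model_class` keyword is assumed to be forwarded through `**kwargs` to the refactored `FlashCausalLM`, and the model id and resolution cap mirror the deleted `idefics2.py` rather than anything defined in this hunk:

# Sketch only: the deleted Idefics2 wrapper expressed through the new VlmCausalLM.__init__.
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.idefics2 import (
    Idefics2ForConditionalGeneration,
)

model = VlmCausalLM(
    model_id="HuggingFaceM4/idefics2-8b",  # example id
    revision=None,
    trust_remote_code=False,
    # Same resolution cap the old wrapper hard-coded to keep VRAM bounded.
    processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
    # Assumed to be accepted by the refactored FlashCausalLM via **kwargs.
    model_class=Idefics2ForConditionalGeneration,
)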