feat: adjust rotary embed and avoid cuda graphs of size 2 and smaller
This commit is contained in:
parent
fd4de85283
commit
3cc82978f9
|
@ -0,0 +1,26 @@
|
|||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "The correct answer is: blue",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1733445131,
|
||||
"id": "",
|
||||
"model": "Qwen/Qwen2-VL-2B-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.4.2-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 7,
|
||||
"prompt_tokens": 27,
|
||||
"total_tokens": 34
|
||||
}
|
||||
}
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
def flash_qwen2_vl_handle(launcher):
|
||||
with launcher(
|
||||
"Qwen/Qwen2-VL-2B-Instruct",
|
||||
max_input_tokens=40,
|
||||
max_input_length=40,
|
||||
max_batch_prefill_tokens=50,
|
||||
max_total_tokens=51,
|
||||
) as handle:
|
||||
|
|
|
@ -29,6 +29,7 @@ from text_generation_server.models.custom_modeling.bloom_modeling import (
|
|||
BloomForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
import text_generation_server.models.globals as globals
|
||||
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
||||
from text_generation_server.models.galactica import GalacticaCausalLMBatch
|
||||
from text_generation_server.models.custom_modeling.neox_modeling import (
|
||||
|
@ -1208,6 +1209,11 @@ def get_model(
|
|||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == QWEN2_VL:
|
||||
# TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
|
||||
logger.warning(
|
||||
"Qwen2-VL requires cuda graphs to be greater than 2. Removing all cuda graphs with a batch size equal or less than 2."
|
||||
)
|
||||
globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Qwen2VLForConditionalGeneration,
|
||||
|
|
|
@ -138,7 +138,12 @@ class Qwen2Attention(torch.nn.Module):
|
|||
dim=-1,
|
||||
)
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
self.rotary_emb(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
cos[: query.shape[0], ...],
|
||||
sin[: query.shape[0], ...],
|
||||
)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
|
|
Loading…
Reference in New Issue