feat: adjust rotary embed and avoid cuda graphs of size 2 and smaller

drbh 2024-12-06 00:54:20 +00:00
parent fd4de85283
commit 3cc82978f9
4 changed files with 39 additions and 2 deletions

@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "The correct answer is: blue",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1733445131,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.4.2-dev0-native",
+  "usage": {
+    "completion_tokens": 7,
+    "prompt_tokens": 27,
+    "total_tokens": 34
+  }
+}

@@ -5,7 +5,7 @@ import pytest
 def flash_qwen2_vl_handle(launcher):
     with launcher(
         "Qwen/Qwen2-VL-2B-Instruct",
-        max_input_tokens=40,
+        max_input_length=40,
         max_batch_prefill_tokens=50,
         max_total_tokens=51,
     ) as handle:
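
For context, a minimal sketch of how this fixture and the JSON snapshot above are typically consumed in TGI's integration tests. The test body is illustrative only (it is not part of this commit), and the `health`, `chat`, and `response_snapshot` conventions are assumed from the surrounding test suite:

import pytest


@pytest.fixture(scope="module")
async def flash_qwen2_vl(flash_qwen2_vl_handle):
    # Wait for the launched server to become healthy, then expose its client.
    await flash_qwen2_vl_handle.health(300)
    return flash_qwen2_vl_handle.client


@pytest.mark.asyncio
async def test_flash_qwen2_vl_simple(flash_qwen2_vl, response_snapshot):
    # The message content is a placeholder; the real test may also send an image.
    response = await flash_qwen2_vl.chat(
        max_tokens=10,
        messages=[{"role": "user", "content": "What color is the sky?"}],
    )
    # Compared against the stored JSON snapshot shown above.
    assert response == response_snapshot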

@@ -29,6 +29,7 @@ from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
 from text_generation_server.models.globals import ATTENTION
+import text_generation_server.models.globals as globals
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -1208,6 +1209,11 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
+        # TODO: remove this edge case once the CUDA graph issue for BS=2 with Qwen2-VL is resolved
+        logger.warning(
+            "Qwen2-VL requires CUDA graph batch sizes greater than 2. Removing all CUDA graphs with a batch size of 2 or less."
+        )
+        globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
         return VlmCausalLM(
             model_id=model_id,
             model_class=Qwen2VLForConditionalGeneration,
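
The guard above filters the module-level CUDA graph batch sizes before the model is built. Accessing the list through the module (`globals.CUDA_GRAPHS`) rather than a `from ... import` binding matters: reassigning through the module object is visible to every other reader of that module, while rebinding an imported name would only change a local alias. A standalone sketch of the filter's effect, assuming the common [1, 2, 4, 8, 16, 32] default (the default list is an assumption here):

CUDA_GRAPHS = [1, 2, 4, 8, 16, 32]  # assumed default CUDA graph batch sizes

# Mirror the commit's filter: drop every graph with batch size <= 2.
CUDA_GRAPHS = list(filter(lambda bs: bs > 2, CUDA_GRAPHS))
assert CUDA_GRAPHS == [4, 8, 16, 32]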

@@ -138,7 +138,12 @@ class Qwen2Attention(torch.nn.Module):
             dim=-1,
         )
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(
+            query,
+            torch.select(kv, dim=1, index=0),
+            cos[: query.shape[0], ...],
+            sin[: query.shape[0], ...],
+        )
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
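
The new call slices the precomputed cos/sin tables down to the number of query rows before applying the rotary embedding, so tables built for more positions than the current flattened batch still broadcast cleanly. A self-contained sketch of the idea, with illustrative names and a simple rotate-half formulation (not the repo's fused kernel):

import torch


def apply_rotary_sliced(query, cos, sin):
    # query: [num_tokens, num_heads, head_dim]
    # cos/sin: [table_len, 1, head_dim] with table_len >= num_tokens
    # Slice the tables to the query length, as the commit does.
    cos = cos[: query.shape[0], ...]
    sin = sin[: query.shape[0], ...]
    x1, x2 = query.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return query * cos + rotated * sin


query = torch.randn(7, 12, 64)  # 7 flattened tokens
cos = torch.randn(34, 1, 64)    # tables built for 34 positions (> 7)
sin = torch.randn(34, 1, 64)
out = apply_rotary_sliced(query, cos, sin)
assert out.shape == query.shape  # broadcasts only because of the slicing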