fix cuda graphs for qwen2-vl (#2708)

* feat: support multidimensional position ids on the batch to enable cuda graphs on qwen2-vl

* fix: only check model type if config exists

* fix: adjust sharding and lm head logic

* fix qwen2 failure on Intel CPU

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix: return correctly shaped logits and add streaming test

* fix: remove unused import and refactor test

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
drbh 2024-10-31 22:05:34 -04:00 committed by GitHub
parent befd9f6735
commit 01dacf8e8f
5 changed files with 93 additions and 18 deletions
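Context for the change: Qwen2-VL uses multimodal rotary embeddings (mrope), so its position ids carry three coordinates per token (temporal, height, width) rather than one, while a replayed CUDA graph can only reuse buffers whose shapes were fixed at capture time. A minimal sketch of the layout this PR threads through the batch (shapes chosen for illustration only):

```python
import torch

seq_len = 5
# Text-only tokens: the temporal, height, and width coordinates all advance
# together, one step per token, giving shape [3, batch, seq].
text_positions = torch.arange(seq_len).view(1, 1, -1).expand(3, 1, seq_len)
print(text_positions.shape)  # torch.Size([3, 1, 5])

# A replayed CUDA graph reuses static buffers, so decode-time position ids
# must keep this same leading dimension to be copied into the graph's buffer.
```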


@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": "",
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1730416361,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-7B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.4.1-dev0-native",
+  "usage": null
+}


@@ -3,7 +3,7 @@ import pytest

 @pytest.fixture(scope="module")
 def flash_qwen2_vl_handle(launcher):
-    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
+    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
         yield handle
@@ -40,3 +40,41 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
     )
     assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
+    responses = await flash_qwen2.chat(
+        max_tokens=100,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                        },
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        generated += response.choices[0].delta.content
+        last_response = response
+
+    assert (
+        generated
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+    )
+    assert count == 58
+    assert last_response == response_snapshot


@@ -34,7 +34,7 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
     TensorParallelEmbedding,
-    FastLinear,
+    SpeculativeHead,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
@@ -69,7 +69,7 @@ def apply_rotary_pos_emb_vision(
 class Qwen2VLSdpaAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
-        self.embed_dim = config.embed_dim
+        self.embed_dim = config.embed_dim // weights.process_group.size()
         self.head_dim = config.hidden_size // config.num_heads
         self.num_heads = config.num_heads // weights.process_group.size()
@ -82,7 +82,7 @@ class Qwen2VLSdpaAttention(nn.Module):
num_key_value_heads=self.num_heads, num_key_value_heads=self.num_heads,
) )
self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0) self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
self.proj = TensorParallelColumnLinear.load( self.proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.proj", prefix=f"{prefix}.proj",
weights=weights, weights=weights,
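The output-projection switch above (TensorParallelColumnLinear to TensorParallelRowLinear) follows from the qkv sharding: each rank computes attention for its slice of heads, so the projection must split along its *input* dimension and sum partial results across ranks. A minimal single-process sketch of that identity; names and shapes are illustrative, not TGI's API:

```python
import torch

torch.manual_seed(0)
embed_dim, world_size = 8, 2
shard = embed_dim // world_size
x = torch.randn(1, embed_dim)          # full attention output for one token
w = torch.randn(embed_dim, embed_dim)  # projection weight, (in, out) layout

full = x @ w
# Each "rank" sees only its slice of the activations and the matching rows
# of the weight; summing the partial products is what the all-reduce in a
# row-parallel linear computes.
partials = [
    x[:, r * shard : (r + 1) * shard] @ w[r * shard : (r + 1) * shard, :]
    for r in range(world_size)
]
assert torch.allclose(full, sum(partials), atol=1e-5)
```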
@@ -364,8 +364,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix="visual", config=config.vision_config, weights=weights
         )
         self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
-        self.lm_head = FastLinear.load(
-            prefix="lm_head", weights=weights, config=config, bias=False
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+        self.lm_head = SpeculativeHead.load(
+            config,
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
+            weights=weights,
         )
         self.norm = FastRMSNorm.load(
             prefix="model.norm",
@@ -377,9 +384,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_position_ids(
         self,
         batch_input_ids: torch.Tensor,
-        image_grid_thw: Optional[torch.LongTensor],
+        image_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if batch_input_ids.dim() == 1:
+            batch_input_ids = batch_input_ids.unsqueeze(0)
         position_ids = torch.ones(
             3,
             batch_input_ids.shape[0],
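Illustrative only: the new dim() check lets the same code serve decode, where CUDA-graph replay passes a flat 1D tensor, by normalizing to a batch dimension before building the [3, batch, seq] grid (token values assumed):

```python
import torch

batch_input_ids = torch.tensor([101, 102, 103])  # flat 1D decode input
if batch_input_ids.dim() == 1:
    batch_input_ids = batch_input_ids.unsqueeze(0)  # -> shape [1, 3]
position_ids = torch.ones(
    3, batch_input_ids.shape[0], batch_input_ids.shape[1], dtype=torch.long
)
print(position_ids.shape)  # torch.Size([3, 1, 3])
```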
@@ -505,5 +515,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefill_cache_indices=prefill_cache_indices,
         )
         hidden_states, _ = self.norm(hidden_states)
-        logits = self.lm_head(hidden_states)
-        return logits, None
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
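Why slicing by lm_head_indices fixes the logits shape: during prefill the hidden states are packed across all tokens of all sequences, but only the last position of each sequence needs logits. A small stand-alone sketch with assumed indices and sizes:

```python
import torch

hidden_states = torch.randn(10, 16)     # 10 packed prefill tokens, hidden=16
lm_head_indices = torch.tensor([3, 9])  # last-token index of each sequence
hidden_states = hidden_states[lm_head_indices]
print(hidden_states.shape)              # torch.Size([2, 16])
```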


@ -1430,6 +1430,14 @@ class FlashCausalLM(Model):
else: else:
state = None state = None
if (
hasattr(self.model, "config")
and hasattr(self.model.config, "model_type")
and self.model.config.model_type == "qwen2_vl"
):
if position_ids.dim() == 1:
position_ids = self.model.get_position_ids(input_ids)
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs] = { self.cuda_graphs[bs] = {
"input_ids": input_ids, "input_ids": input_ids,
@@ -1806,7 +1814,7 @@ class FlashCausalLM(Model):
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
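Illustrative note on the shape[0] to shape[-1] switch: the last dimension is the token count in both the classic and the mrope layouts, so it is the right bound for the padded copy either way:

```python
import torch

flat = torch.arange(4)                # classic decode layout, shape [4]
mrope = torch.arange(4).expand(3, 4)  # qwen2-vl layout, shape [3, 4]
assert flat.shape[-1] == mrope.shape[-1] == 4  # same padded-copy bound
```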
@@ -1981,7 +1989,7 @@ class FlashCausalLM(Model):
         # instantly become of shape [BATCH_SIZE]
         if prefill and finished_prefilling:
             indices = batch.cu_seqlen_prefill[1:] - 1
-            batch.position_ids = batch.position_ids[indices]
+            batch.position_ids = batch.position_ids[(..., indices)]
             batch.slot_indices = batch.slot_indices[indices]
             batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                 indices
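The `(..., indices)` form above indexes the last dimension regardless of rank, so one line covers both 1D position ids and the [3, seq] layout. A toy demonstration:

```python
import torch

indices = torch.tensor([1, 3])
flat = torch.arange(5)                # shape [5]
mrope = torch.arange(5).expand(3, 5)  # shape [3, 5]
print(flat[(..., indices)])           # tensor([1, 3])
print(mrope[(..., indices)].shape)    # torch.Size([3, 2])
```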


@@ -363,15 +363,12 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices

-        if hasattr(self.model, "get_position_ids"):
-            if position_ids.shape[0] != 1:
-                position_ids = self.model.get_position_ids(
-                    input_ids.unsqueeze(0), batch.image_grid_thw
-                )
-                batch.position_ids = position_ids[0, 0, :]
-            else:
-                position_ids = position_ids.repeat(3, 1, 1).clone()
-                batch.position_ids = position_ids[0, 0, :]
+        if self.model.config.model_type == "qwen2_vl":
+            if position_ids.dim() == 1 and batch.prefilling:
+                position_ids = self.model.get_position_ids(
+                    input_ids, batch.image_grid_thw
+                )
+                batch.position_ids = position_ids

         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
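A hedged restatement of the guard above: multidimensional position ids are rebuilt only while the batch is still prefilling and the ids are still flat; once stored on the batch they are carried forward for decode, including CUDA-graph replay. A hypothetical predicate capturing the condition:

```python
def needs_mrope_rebuild(model_type, position_ids, prefilling):
    # Rebuild only while prefilling and while the ids are still 1D; once the
    # [3, batch, seq] ids are stored on the batch, decode reuses them.
    return model_type == "qwen2_vl" and position_ids.dim() == 1 and prefilling
```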