fix cuda graphs for qwen2-vl (#2708)
* feat: support multidimensional position ids on batch to enable cuda graphs on qwen2-vl
* fix: only check model type if config exists
* fix: adjust sharding and lm head logic
* fix: qwen2 failure on Intel CPU
* fix: return correct shape logits and add streaming test
* fix: remove unused import and refactor test

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent befd9f6735
commit 01dacf8e8f
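Background on the headline change (not part of the commit): Qwen2-VL uses multimodal rotary position embeddings (M-RoPE), so every token carries three position coordinates (temporal, height, width) instead of one scalar. CUDA graphs replay kernels against tensors of fixed shape, which is why the batch's position ids must keep that extra dimension consistently through prefill and decode. A minimal sketch of the shape difference, in plain PyTorch and illustrative only:

```python
import torch

seq_len = 8

# Standard decoder LM: one scalar position per token -> shape (seq_len,)
position_ids_1d = torch.arange(seq_len)

# Qwen2-VL M-RoPE: a (temporal, height, width) triple per token
# -> shape (3, seq_len); for pure-text spans all three rows coincide.
position_ids_3d = torch.arange(seq_len).expand(3, -1).clone()

print(position_ids_1d.shape)  # torch.Size([8])
print(position_ids_3d.shape)  # torch.Size([3, 8])
```

The first hunk below adds the response snapshot that the new streaming test asserts against.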
```diff
@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": "",
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1730416361,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-7B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.4.1-dev0-native",
+  "usage": null
+}
```
```diff
@@ -3,7 +3,7 @@ import pytest


 @pytest.fixture(scope="module")
 def flash_qwen2_vl_handle(launcher):
-    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
+    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
         yield handle
```
```diff
@@ -40,3 +40,41 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
     )

     assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
+    responses = await flash_qwen2.chat(
+        max_tokens=100,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                        },
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        generated += response.choices[0].delta.content
+        last_response = response
+
+    assert (
+        generated
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+    )
+    assert count == 58
+    assert last_response == response_snapshot
```
```diff
@@ -34,7 +34,7 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
     TensorParallelEmbedding,
-    FastLinear,
+    SpeculativeHead,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
```
```diff
@@ -69,7 +69,7 @@ def apply_rotary_pos_emb_vision(
 class Qwen2VLSdpaAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
-        self.embed_dim = config.embed_dim
+        self.embed_dim = config.embed_dim // weights.process_group.size()
         self.head_dim = config.hidden_size // config.num_heads
         self.num_heads = config.num_heads // weights.process_group.size()
```
```diff
@@ -82,7 +82,7 @@ class Qwen2VLSdpaAttention(nn.Module):
             num_key_value_heads=self.num_heads,
         )
         self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
-        self.proj = TensorParallelColumnLinear.load(
+        self.proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.proj",
             weights=weights,
```
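Why `proj` moves from column- to row-parallel (a sketch of the general tensor-parallel algebra, not the TGI classes): under tensor parallelism the attention heads are split across ranks, so each rank's attention output covers only a shard of `proj`'s input features. A row-parallel linear is the layout that accepts a feature-sharded input and sums partial products across ranks:

```python
import torch

torch.manual_seed(0)
n, d_in, d_out = 4, 8, 6
x = torch.randn(n, d_in)      # full activation
w = torch.randn(d_in, d_out)  # full weight

# Row-parallel: shard w along d_in. Each "rank" sees only its slice of x
# (the heads it owns) and computes a partial product; the sum is what an
# all-reduce would produce in a real tensor-parallel setup.
x1, x2 = x[:, :4], x[:, 4:]
w1, w2 = w[:4, :], w[4:, :]
assert torch.allclose(x1 @ w1 + x2 @ w2, x @ w, atol=1e-5)
```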
```diff
@@ -364,8 +364,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix="visual", config=config.vision_config, weights=weights
         )
         self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
-        self.lm_head = FastLinear.load(
-            prefix="lm_head", weights=weights, config=config, bias=False
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
+        self.lm_head = SpeculativeHead.load(
+            config,
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
+            weights=weights,
         )
         self.norm = FastRMSNorm.load(
             prefix="model.norm",
```
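Context for the `tie_word_embeddings` branch (the usual Hugging Face convention, sketched with hypothetical sizes): a tied checkpoint stores no separate `lm_head` tensor, so the head must be loaded from the embedding weights, hence the suffix switch above.

```python
import torch.nn as nn

vocab_size, hidden_size = 32000, 1024
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tied weights: one tensor serves as both the input embedding matrix and
# the output projection, so only "model.embed_tokens" exists on disk.
lm_head.weight = embed_tokens.weight
```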
```diff
@@ -377,9 +384,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_position_ids(
         self,
         batch_input_ids: torch.Tensor,
-        image_grid_thw: Optional[torch.LongTensor],
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if batch_input_ids.dim() == 1:
+            batch_input_ids = batch_input_ids.unsqueeze(0)
+
         position_ids = torch.ones(
             3,
             batch_input_ids.shape[0],
```
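The new `dim() == 1` guard lets callers pass either TGI's flattened `(seq_len,)` input ids or an already-batched `(batch, seq_len)` tensor. A standalone sketch of that normalization, with a hypothetical helper name rather than the real method body:

```python
import torch

def position_id_buffer(batch_input_ids: torch.Tensor) -> torch.Tensor:
    # Accept both flat (seq_len,) and batched (batch, seq_len) inputs.
    if batch_input_ids.dim() == 1:
        batch_input_ids = batch_input_ids.unsqueeze(0)
    # One (t, h, w) coordinate triple per token, to be filled in later.
    return torch.ones(
        3, batch_input_ids.shape[0], batch_input_ids.shape[1],
        dtype=torch.long,
    )

assert position_id_buffer(torch.arange(5)).shape == (3, 1, 5)
assert position_id_buffer(torch.zeros(2, 5).long()).shape == (3, 2, 5)
```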
```diff
@@ -505,5 +515,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefill_cache_indices=prefill_cache_indices,
         )
         hidden_states, _ = self.norm(hidden_states)
-        logits = self.lm_head(hidden_states)
-        return logits, None
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
```
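For orientation (a sketch, not TGI code): prefill produces hidden states for every packed input token, but only each sequence's last token feeds sampling. `lm_head_indices` selects those rows before the head runs, which is where the "return correct shape logits" fix in the commit message comes from. With a hypothetical ragged batch:

```python
import torch

# Two packed sequences of lengths 3 and 4 -> cu_seqlen_prefill = [0, 3, 7].
cu_seqlen_prefill = torch.tensor([0, 3, 7])
hidden_states = torch.randn(7, 16)

# Last-token row of each sequence (same `[1:] - 1` pattern as later in
# this diff).
lm_head_indices = cu_seqlen_prefill[1:] - 1  # tensor([2, 6])
selected = hidden_states[lm_head_indices]    # shape (2, 16)
assert selected.shape == (2, 16)
```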
```diff
@@ -1430,6 +1430,14 @@ class FlashCausalLM(Model):
         else:
             state = None

+        if (
+            hasattr(self.model, "config")
+            and hasattr(self.model.config, "model_type")
+            and self.model.config.model_type == "qwen2_vl"
+        ):
+            if position_ids.dim() == 1:
+                position_ids = self.model.get_position_ids(input_ids)
+
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
```
```diff
@@ -1806,7 +1814,7 @@ class FlashCausalLM(Model):
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
         if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
```
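Why these copies exist at all (a minimal capture/replay sketch in plain PyTorch, assuming a CUDA device; not TGI's warmup code): a captured graph re-executes kernels against fixed memory addresses, so fresh inputs must first be copied into the padded static tensors the graph was captured with. With multidimensional position ids the token count sits in the last dimension, hence `shape[-1]` above.

```python
import torch

if torch.cuda.is_available():
    static_in = torch.zeros(8, device="cuda")

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):  # record kernels at fixed addresses
        static_out = static_in * 2

    # Replay: copy live data into the captured buffer, then re-run.
    static_in.copy_(torch.arange(8.0, device="cuda"))
    g.replay()
    print(static_out)  # [0., 2., 4., ..., 14.] on cuda
```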
```diff
@@ -1981,7 +1989,7 @@ class FlashCausalLM(Model):
             # instantly become of shape [BATCH_SIZE]
             if prefill and finished_prefilling:
                 indices = batch.cu_seqlen_prefill[1:] - 1
-                batch.position_ids = batch.position_ids[indices]
+                batch.position_ids = batch.position_ids[(..., indices)]
                 batch.slot_indices = batch.slot_indices[indices]
                 batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                     indices
```
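The `(..., indices)` form is what makes this line layout-agnostic (plain PyTorch illustration): the ellipsis skips any leading coordinate dimensions, so the same gather works for the usual `(seq,)` position ids and for qwen2-vl's `(3, seq)`:

```python
import torch

indices = torch.tensor([2, 5])

pos_1d = torch.arange(6)                # (seq,)
pos_2d = torch.arange(6).expand(3, -1)  # (3, seq), as in M-RoPE

print(pos_1d[(..., indices)])  # tensor([2, 5])
print(pos_2d[(..., indices)])  # shape (3, 2): columns 2 and 5 of each row
```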
```diff
@@ -363,15 +363,12 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices

-        if hasattr(self.model, "get_position_ids"):
-            if position_ids.shape[0] != 1:
+        if self.model.config.model_type == "qwen2_vl":
+            if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
-                    input_ids.unsqueeze(0), batch.image_grid_thw
+                    input_ids, batch.image_grid_thw
                 )
-                batch.position_ids = position_ids[0, 0, :]
-            else:
-                position_ids = position_ids.repeat(3, 1, 1).clone()
-                batch.position_ids = position_ids[0, 0, :]
+                batch.position_ids = position_ids

         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
```