fix cuda graphs for qwen2-vl (#2708)
* feat: support multidimensional position ids on batch to enable cuda graphs on qwen2-vl
* fix: only check model type if config exists
* fix: adjust sharding and lm head logic
* fix: qwen2 failure on Intel CPU
* fix: return correct shape logits and add streaming test
* fix: remove unused import and refactor test

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Parent: befd9f6735
Commit: 01dacf8e8f
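For context, a hedged sketch of the data shape this commit standardizes on. Qwen2-VL's M-RoPE assigns each token a (temporal, height, width) position triple, so position ids for a sequence form a (3, seq) tensor rather than the usual (seq,). CUDA graphs replay kernels against fixed shapes, so the batch has to carry that layout from the start. The helper below is illustrative only, not the repository's code:

import torch

def text_only_position_ids(seq_len: int) -> torch.Tensor:
    # For pure text, all three M-RoPE components advance together,
    # so each of the 3 rows is just 0..seq_len-1.
    base = torch.arange(seq_len)
    return base.unsqueeze(0).expand(3, seq_len)

ids = text_only_position_ids(5)
print(ids.shape)  # torch.Size([3, 5])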
@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": "",
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1730416361,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-7B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.4.1-dev0-native",
+  "usage": null
+}
@@ -3,7 +3,7 @@ import pytest


 @pytest.fixture(scope="module")
 def flash_qwen2_vl_handle(launcher):
-    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
+    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
         yield handle

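Note: `cuda_graphs=[0]` disables graph capture in the TGI launcher, so dropping it means the fixture now exercises the server's default CUDA-graph batch sizes, which is exactly the code path this commit fixes.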
@@ -40,3 +40,41 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
     )

     assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
+    responses = await flash_qwen2.chat(
+        max_tokens=100,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                        },
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        generated += response.choices[0].delta.content
+        last_response = response
+
+    assert (
+        generated
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+    )
+    assert count == 58
+    assert last_response == response_snapshot
@@ -34,7 +34,7 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
     TensorParallelEmbedding,
-    FastLinear,
+    SpeculativeHead,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
@@ -69,7 +69,7 @@ def apply_rotary_pos_emb_vision(
 class Qwen2VLSdpaAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
-        self.embed_dim = config.embed_dim
+        self.embed_dim = config.embed_dim // weights.process_group.size()
         self.head_dim = config.hidden_size // config.num_heads
         self.num_heads = config.num_heads // weights.process_group.size()

@@ -82,7 +82,7 @@ class Qwen2VLSdpaAttention(nn.Module):
             num_key_value_heads=self.num_heads,
         )
         self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
-        self.proj = TensorParallelColumnLinear.load(
+        self.proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.proj",
             weights=weights,
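Context on the sharding fix in the two hunks above: with tensor parallelism, the qkv projection is column-parallel, so each rank holds only `embed_dim // world_size` of the attention output. That is why `self.embed_dim` is now divided by the process-group size, and why the output projection must be row-parallel (each rank consumes its local slice, then partial results are summed) rather than column-parallel. A standalone sketch of the algebra, with made-up sizes:

import torch

hidden, ranks = 8, 2
w = torch.randn(hidden, hidden)   # full projection weight
x = torch.randn(1, hidden)        # attention output, sharded below

x_shards = x.chunk(ranks, dim=1)  # each rank holds hidden/ranks features
w_rows = w.chunk(ranks, dim=0)    # row-parallel: split the input dim

partial = [xs @ wr for xs, wr in zip(x_shards, w_rows)]
out = sum(partial)                # the sum an all-reduce would compute
assert torch.allclose(out, x @ w, atol=1e-5)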
@@ -364,8 +364,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix="visual", config=config.vision_config, weights=weights
         )
         self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
-        self.lm_head = FastLinear.load(
-            prefix="lm_head", weights=weights, config=config, bias=False
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
+        self.lm_head = SpeculativeHead.load(
+            config,
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
+            weights=weights,
         )
         self.norm = FastRMSNorm.load(
             prefix="model.norm",
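On the lm-head change: when `tie_word_embeddings` is set, the checkpoint stores one matrix for both the input embedding and the output head, so the loader must read the head's weights from `model.embed_tokens` instead of a separate `lm_head` tensor. A minimal illustration of weight tying (shapes assumed, not the repository's loader):

import torch.nn as nn

vocab, hidden = 32000, 4096
embed = nn.Embedding(vocab, hidden)
lm_head = nn.Linear(hidden, vocab, bias=False)
lm_head.weight = embed.weight  # tie: one (vocab, hidden) matrix serves both directions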
@@ -377,9 +384,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_position_ids(
         self,
         batch_input_ids: torch.Tensor,
-        image_grid_thw: Optional[torch.LongTensor],
+        image_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if batch_input_ids.dim() == 1:
+            batch_input_ids = batch_input_ids.unsqueeze(0)
+
         position_ids = torch.ones(
             3,
             batch_input_ids.shape[0],
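Usage note (illustrative tensors only): the new `dim()` guard means callers no longer need to pre-batch their input ids, and `image_grid_thw` can be omitted for text-only warmup calls. Both shapes below would be normalized to the same batch before position ids are built:

import torch

flat = torch.arange(10)       # (seq,)   -- what the cuda-graph warmup passes
batched = flat.unsqueeze(0)   # (1, seq) -- what prefill used to have to pass
# model.get_position_ids(flat) and model.get_position_ids(batched)
# now both see a (1, 10) tensor internally.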
@@ -505,5 +515,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefill_cache_indices=prefill_cache_indices,
         )
         hidden_states, _ = self.norm(hidden_states)
-        logits = self.lm_head(hidden_states)
-        return logits, None
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
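Why the `lm_head_indices` gather matters: during prefill, only the last position of each packed sequence needs logits, so filtering hidden states before the vocab projection both saves compute and returns correctly shaped logits (one of this commit's fixes). A self-contained sketch, with boundary indices built the same way a later hunk does (`cu_seqlen_prefill[1:] - 1`):

import torch

hidden_states = torch.randn(7, 16)           # 7 packed tokens, hidden size 16
cu_seqlen_prefill = torch.tensor([0, 3, 7])  # two sequences, lengths 3 and 4
lm_head_indices = cu_seqlen_prefill[1:] - 1  # last token of each sequence
picked = hidden_states[lm_head_indices]      # shape (2, 16): one row per sequence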
@@ -1430,6 +1430,14 @@ class FlashCausalLM(Model):
         else:
             state = None

+        if (
+            hasattr(self.model, "config")
+            and hasattr(self.model.config, "model_type")
+            and self.model.config.model_type == "qwen2_vl"
+        ):
+            if position_ids.dim() == 1:
+                position_ids = self.model.get_position_ids(input_ids)
+
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
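For readers unfamiliar with the graph runner: a CUDA graph records kernels against fixed-shape static tensors and replays them after in-place updates, which is why the multidimensional position ids must be materialized before capture, as the hunk above does. A minimal standalone pattern following the PyTorch docs (not TGI's runner; requires a CUDA device):

import torch

static_in = torch.zeros(3, 8, device="cuda")
static_out = torch.zeros(3, 8, device="cuda")

# Warm-up pass on a side stream, as recommended before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)  # shapes and dtypes are frozen here

static_in.fill_(21.0)  # mutate the static input in place...
g.replay()             # ...and replay the recorded kernels
assert torch.all(static_out == 42.0)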
@@ -1806,7 +1814,7 @@ class FlashCausalLM(Model):
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
@@ -1981,7 +1989,7 @@ class FlashCausalLM(Model):
         # instantly become of shape [BATCH_SIZE]
         if prefill and finished_prefilling:
             indices = batch.cu_seqlen_prefill[1:] - 1
-            batch.position_ids = batch.position_ids[indices]
+            batch.position_ids = batch.position_ids[(..., indices)]
             batch.slot_indices = batch.slot_indices[indices]
             batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                 indices
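The tuple-with-Ellipsis index is what lets the same line serve both classic text models (1D position ids) and qwen2-vl (a leading dimension of 3): the gather always applies to the last axis. Illustration with toy tensors:

import torch

indices = torch.tensor([2, 5])
flat = torch.arange(8)                  # classic position ids: (seq,)
mrope = torch.arange(24).reshape(3, 8)  # qwen2-vl position ids: (3, seq)

assert flat[(..., indices)].shape == (2,)
assert mrope[(..., indices)].shape == (3, 2)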
@@ -363,15 +363,12 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices

-        if hasattr(self.model, "get_position_ids"):
-            if position_ids.shape[0] != 1:
-                position_ids = self.model.get_position_ids(
-                    input_ids.unsqueeze(0), batch.image_grid_thw
-                )
-                batch.position_ids = position_ids[0, 0, :]
-            else:
-                position_ids = position_ids.repeat(3, 1, 1).clone()
-                batch.position_ids = position_ids[0, 0, :]
+        if self.model.config.model_type == "qwen2_vl":
+            if position_ids.dim() == 1 and batch.prefilling:
+                position_ids = self.model.get_position_ids(
+                    input_ids, batch.image_grid_thw
+                )
+                batch.position_ids = position_ids

         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache