diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json
new file mode 100644
index 00000000..f9a414fa
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json
@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": "",
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1730416361,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-7B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.4.1-dev0-native",
+  "usage": null
+}
diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py
index 357de2b1..946ab2f1 100644
--- a/integration-tests/models/test_flash_qwen2_vl.py
+++ b/integration-tests/models/test_flash_qwen2_vl.py
@@ -3,7 +3,7 @@ import pytest
 
 @pytest.fixture(scope="module")
 def flash_qwen2_vl_handle(launcher):
-    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
+    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
         yield handle
 
 
@@ -40,3 +40,41 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
     )
 
     assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
+    responses = await flash_qwen2.chat(
+        max_tokens=100,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                        },
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        generated += response.choices[0].delta.content
+        last_response = response
+
+    assert (
+        generated
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+    )
+    assert count == 58
+    assert last_response == response_snapshot
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 6ebc3d4e..5936c6fe 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -34,7 +34,7 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
     TensorParallelEmbedding,
-    FastLinear,
+    SpeculativeHead,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
@@ -69,7 +69,7 @@ def apply_rotary_pos_emb_vision(
 class Qwen2VLSdpaAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
-        self.embed_dim = config.embed_dim
+        self.embed_dim = config.embed_dim // weights.process_group.size()
         self.head_dim = config.hidden_size // config.num_heads
         self.num_heads = config.num_heads // weights.process_group.size()
 
@@ -82,7 +82,7 @@ class Qwen2VLSdpaAttention(nn.Module):
             num_key_value_heads=self.num_heads,
         )
         self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
-        self.proj = TensorParallelColumnLinear.load(
+        self.proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.proj",
             weights=weights,
@@ -364,8 +364,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix="visual", config=config.vision_config, weights=weights
         )
         self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
-        self.lm_head = FastLinear.load(
-            prefix="lm_head", weights=weights, config=config, bias=False
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
+        self.lm_head = SpeculativeHead.load(
+            config,
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
+            weights=weights,
         )
         self.norm = FastRMSNorm.load(
             prefix="model.norm",
@@ -377,9 +384,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_position_ids(
         self,
         batch_input_ids: torch.Tensor,
-        image_grid_thw: Optional[torch.LongTensor],
+        image_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if batch_input_ids.dim() == 1:
+            batch_input_ids = batch_input_ids.unsqueeze(0)
+
         position_ids = torch.ones(
             3,
             batch_input_ids.shape[0],
@@ -505,5 +515,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefill_cache_indices=prefill_cache_indices,
         )
         hidden_states, _ = self.norm(hidden_states)
-        logits = self.lm_head(hidden_states)
-        return logits, None
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 8ab1a811..52ab5d6a 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1430,6 +1430,14 @@ class FlashCausalLM(Model):
         else:
             state = None
 
+        if (
+            hasattr(self.model, "config")
+            and hasattr(self.model.config, "model_type")
+            and self.model.config.model_type == "qwen2_vl"
+        ):
+            if position_ids.dim() == 1:
+                position_ids = self.model.get_position_ids(input_ids)
+
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
@@ -1806,7 +1814,7 @@ class FlashCausalLM(Model):
             # Copy inputs to the static inputs of the cuda graph
             # Static inputs are potentially padded
             cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-            cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+            cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
@@ -1981,7 +1989,7 @@ class FlashCausalLM(Model):
         # instantly become of shape [BATCH_SIZE]
         if prefill and finished_prefilling:
             indices = batch.cu_seqlen_prefill[1:] - 1
-            batch.position_ids = batch.position_ids[indices]
+            batch.position_ids = batch.position_ids[(..., indices)]
             batch.slot_indices = batch.slot_indices[indices]
             batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                 indices
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 9a3db502..aa0fe107 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -363,15 +363,12 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
-        if hasattr(self.model, "get_position_ids"):
-            if position_ids.shape[0] != 1:
+        if self.model.config.model_type == "qwen2_vl":
+            if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
-                    input_ids.unsqueeze(0), batch.image_grid_thw
+                    input_ids, batch.image_grid_thw
                 )
-                batch.position_ids = position_ids[0, 0, :]
-            else:
-                position_ids = position_ids.repeat(3, 1, 1).clone()
-                batch.position_ids = position_ids[0, 0, :]
+            batch.position_ids = position_ids
 
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache