From befd9f6735ed8d7f5d8e9110b1f921e16d856a8b Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 30 Oct 2024 12:40:51 -0400 Subject: [PATCH] Support qwen2 vl (#2689) * feat: add support for qwen2 vl model * feat: fix token padding, enable warmup and process basic request * fix: improve get_position_ids, add lift embed_tokens * fix: remove get_cos_sin_hack dev function * feat: add simple test chat with meesage and text * fix: lint test * fix: adjust positional embeddings for multi dimensional position ids * fix: update docs and lint unused vars * fix: include linted file * fix: add norm after text output * fix: format model file * fix: adjust for ruff lints * fix: remove unused rotate_half * feat: refactors and calc num features * fix: prefer position_ids passed from vlm causal lm and reset ids on batch * fix: adjust get_position_ids if not available and add required args to signatures * fix: adjust resize case for qwen2_vl warmup * fix: avoid qwen2 vl specific paths with qwen2 --- docs/source/supported_models.md | 1 + .../test_flash_qwen2_vl_simple.json | 26 + .../models/test_flash_qwen2_vl.py | 42 ++ router/src/config.rs | 29 + router/src/validation.rs | 8 +- .../text_generation_server/layers/rotary.py | 2 + .../text_generation_server/models/__init__.py | 20 + .../flash_pali_gemma_modeling.py | 1 + .../custom_modeling/flash_qwen2_modeling.py | 42 +- .../models/custom_modeling/idefics2.py | 1 + .../models/custom_modeling/llava_next.py | 1 + .../models/custom_modeling/qwen2_vl.py | 509 ++++++++++++++++++ .../models/vlm_causal_lm.py | 33 ++ 13 files changed, 705 insertions(+), 10 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json create mode 100644 integration-tests/models/test_flash_qwen2_vl.py create mode 100644 server/text_generation_server/models/custom_modeling/qwen2_vl.py diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index ede1fc77..55449e47 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -24,6 +24,7 @@ Text Generation Inference enables serving optimized models. The following sectio - [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct) - [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) - [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f) +- [Qwen 2 VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d) - [Opt](https://huggingface.co/facebook/opt-6.7b) - [T5](https://huggingface.co/google/flan-t5-xxl) - [Galactica](https://huggingface.co/facebook/galactica-120b) diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json new file mode 100644 index 00000000..2f7ffb08 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1730164250, + "id": "", + "model": "Qwen/Qwen2-VL-7B-Instruct", + "object": "chat.completion", + "system_fingerprint": "2.4.1-dev0-native", + "usage": { + "completion_tokens": 58, + "prompt_tokens": 349, + "total_tokens": 407 + } +} diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py new file mode 100644 index 00000000..357de2b1 --- /dev/null +++ b/integration-tests/models/test_flash_qwen2_vl.py @@ -0,0 +1,42 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_qwen2_vl_handle(launcher): + with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_qwen2(flash_qwen2_vl_handle): + await flash_qwen2_vl_handle.health(300) + return flash_qwen2_vl_handle.client + + +@pytest.mark.private +async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): + response = await flash_qwen2.chat( + max_tokens=100, + seed=42, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + {"type": "text", "text": "Describe this image."}, + ], + }, + ], + ) + + assert ( + response.choices[0].message.content + == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." + ) + + assert response == response_snapshot diff --git a/router/src/config.rs b/router/src/config.rs index ce066ad0..9c31e6e8 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -138,10 +138,39 @@ impl Paligemma { } } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Qwen2VlVisionConfig { + pub(crate) depth: usize, + pub(crate) embed_dim: usize, + pub(crate) mlp_ratio: usize, + pub(crate) num_heads: usize, + pub(crate) in_chans: usize, + pub(crate) hidden_size: usize, + pub(crate) patch_size: usize, + pub(crate) spatial_merge_size: usize, + pub(crate) spatial_patch_size: usize, + pub(crate) temporal_patch_size: usize, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Qwen2Vl { + pub(crate) vision_config: Qwen2VlVisionConfig, +} + +impl Qwen2Vl { + pub fn get_number_of_features(&self, height: usize, width: usize) -> usize { + let num_pixels = height * width; + num_pixels / self.vision_config.patch_size.pow(2) + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(tag = "model_type")] #[serde(rename_all = "snake_case")] pub enum Config { + Qwen2Vl(Qwen2Vl), LlavaNext(LlavaNext), ClipVisionModel(ClipVisionModel), Mistral, diff --git a/router/src/validation.rs b/router/src/validation.rs index 8159ede4..5b2a153c 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -594,6 +594,10 @@ fn image_tokens( } Paligemma(config) => "".repeat(config.get_number_of_features(height, width)), LlavaNext(config) => "".repeat(config.get_number_of_features(height, width)), + Qwen2Vl(config) => format!( + "<|vision_start|>{:?}<|vision_end|>", + "<|image_pad|>".repeat(config.get_number_of_features(height, width)) + ), _ => unimplemented!("Images tokens are not supported for this model configuration"), } } @@ -620,7 +624,9 @@ fn prepare_input( use Config::*; static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); let (tokenizer_query, input_chunks) = match config { - Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => { + Some( + config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)), + ) => { let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index a2076bb2..123bbadb 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -89,6 +89,8 @@ class PositionRotaryEmbedding(nn.Module): if rope_type == "linear": pass + elif rope_type == "default": + pass elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 99e3d343..6c633521 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -146,6 +146,9 @@ try: from text_generation_server.models.custom_modeling.idefics2 import ( Idefics2ForConditionalGeneration, ) + from text_generation_server.models.custom_modeling.qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") @@ -275,6 +278,11 @@ class ModelType(enum.Enum): "name": "Qwen 2", "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f", } + QWEN2_VL = { + "type": "qwen2_vl", + "name": "Qwen 2 VL", + "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d", + } OPT = { "type": "opt", "name": "Opt", @@ -1193,6 +1201,18 @@ def get_model( ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if model_type == QWEN2_VL: + return VlmCausalLM( + model_id=model_id, + model_class=Qwen2VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) if model_type == MLLAMA: if FLASH_ATTENTION: return MllamaCausalLM( diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 0024f2bb..b1f89eff 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -80,6 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): pixel_attention_mask: Optional[torch.BoolTensor] = None, image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index ab2a177d..cc4039b1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -61,6 +61,11 @@ class Qwen2Attention(torch.nn.Module): config.sliding_window if config.sliding_window is not None else -1 ) self.num_heads = config.num_attention_heads + self.mrope_section = ( + config.rope_scaling.get("mrope_section", None) + if config.rope_scaling is not None + else None + ) self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads @@ -122,6 +127,17 @@ class Qwen2Attention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + if self.mrope_section is not None: + # if mrope_section is set, we need to split the cos and sin into 3 parts and concatenate them in a specific order + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) if prefill_cache_indices is not None: @@ -270,9 +286,6 @@ class Qwen2Model(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embed_tokens", weights=weights - ) self.layers = nn.ModuleList( [ Qwen2Layer( @@ -296,7 +309,7 @@ class Qwen2Model(torch.nn.Module): def forward( self, - input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -307,13 +320,16 @@ class Qwen2Model(torch.nn.Module): true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = inputs_embeds - # Get rotary cos and sin for this forward - # Avoid to index in each layer + # flatten position ids from 2D to 1D cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype + position_ids.flatten(), true_max_s, hidden_states.dtype ) + # reshape back to 2D if the position_ids were 2D + if position_ids.size(0) != cos.size(0): + cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) + sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) residual = None for i, layer in enumerate(self.layers): @@ -352,6 +368,12 @@ class Qwen2ForCausalLM(torch.nn.Module): prefix=f"{prefix}.{suffix}" if prefix else suffix, weights=weights, ) + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens", + weights=weights, + ) + self.max_past = config.sliding_window self.max_past_tensor = ( torch.tensor(config.sliding_window, device=weights.device) @@ -382,8 +404,10 @@ class Qwen2ForCausalLM(torch.nn.Module): # kernel requires the true values seqlen = seqlen.clamp(max=self.max_past_tensor) + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( - input_ids, + inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index a829c374..923123d6 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -750,6 +750,7 @@ class Idefics2ForConditionalGeneration(nn.Module): # Unused here image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None: diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 32e9d334..df7366ea 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -180,6 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module): pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py new file mode 100644 index 00000000..6ebc3d4e --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -0,0 +1,509 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 VL model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn +from text_generation_server.utils.import_utils import SYSTEM + +if SYSTEM == "ipex": + pass +else: + pass + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, + FastLinear, +) +from text_generation_server.layers.attention import ( + Seqlen, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2Model, +) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2VLSdpaAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.embed_dim = config.embed_dim + self.head_dim = config.hidden_size // config.num_heads + self.num_heads = config.num_heads // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv", + weights=weights, + bias=False, + num_heads=self.num_heads, + num_key_value_heads=self.num_heads, + ) + self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0) + self.proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.proj", + weights=weights, + bias=True, + ) + + def forward( + self, + hidden_state: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # apply the qkv linear layer to the hidden state + qkv = self.qkv(hidden_state) + query, key, value = qkv.split( + [self.embed_dim, self.embed_dim, self.embed_dim], dim=1 + ) + + # reshape the query, key, and value tensors + _shape = ( + hidden_state.shape[0], + self.num_heads, + self.embed_dim // self.num_heads, + ) + query = query.view(*_shape) + key = key.view(*_shape) + value = value.view(*_shape) + + # apply rotary positional embeddings + query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( + 0 + ) + key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + # TODO: make use of existing RotatoryPositionEmbedding class + + # create the attention mask + attention_mask = torch.zeros( + [1, hidden_state.shape[0], hidden_state.shape[0]], + device=hidden_state.device, + dtype=torch.bool, + ) + # TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it + + # apply the cu_seqlens to the attention mask + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = True + + # transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim) + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + # apply attention + attn_output = F.scaled_dot_product_attention( + query, key, value, attention_mask, dropout_p=0.0 + ) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(hidden_state.shape[0], -1) + # TODO: prefer flash attention + + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2VLVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VLVisionBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.attn = Qwen2VLSdpaAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.norm1 = FastLayerNorm.load( + prefix=f"{prefix}.norm1", + weights=weights, + eps=1e-6, + ) + self.norm2 = FastLayerNorm.load( + prefix=f"{prefix}.norm2", + weights=weights, + eps=1e-6, + ) + self.mlp = Qwen2VLVisionMLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states_post_norm1, res = self.norm1(hidden_states) + hidden_states = hidden_states + self.attn( + hidden_states_post_norm1, cu_seqlens, rotary_pos_emb + ) + hidden_states_post_norm2, res = self.norm2(hidden_states) + hidden_states = hidden_states + self.mlp(hidden_states_post_norm2) + return hidden_states + + +class Qwen2VLPatchMerger(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.hidden_size = config.embed_dim * (config.spatial_merge_size**2) + self.patch_merger_ln_q = FastLayerNorm.load( + prefix=f"{prefix}.ln_q", + weights=weights, + eps=1e-6, + ) + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states, grid_thw) -> torch.Tensor: + hidden_states, _ = self.patch_merger_ln_q(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.fc1(hidden_states) + hidden_states = F.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.spatial_merge_size = config.spatial_merge_size + kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size] + self.patch_embedding = nn.Conv3d( + in_channels=config.in_chans, + out_channels=config.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False + ) + head_dim = config.embed_dim // config.num_heads + # TODO: replace with static positional embeddings once implemented + theta = 10000.0 + dim = head_dim // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self.blocks = nn.ModuleList( + [ + Qwen2VLVisionBlock( + prefix=f"{prefix}.blocks.{i}", + config=config, + weights=weights, + ) + for i in range(config.depth) + ] + ) + self.merger = Qwen2VLPatchMerger( + prefix=f"{prefix}.merger", + config=config, + weights=weights, + ) + + self.temporal_patch_size = config.temporal_patch_size + self.spatial_patch_size = config.spatial_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.embed_dim + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: Optional[torch.Tensor] = None, + grid_thw: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # reshape the input tensor for processing + shape = ( + -1, + self.in_channels, + self.temporal_patch_size, + self.spatial_patch_size, + self.spatial_patch_size, + ) + pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype) + hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim) + # TODO: revisit to see if we can avoid some of these reshapes + + # find the position ids for the input tensor based on the grid_thw + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + + # apply the positional embeddings to the position ids + seq = torch.arange( + max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype) + + # create a cu_seqlens tensor to be used in the attention mask + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # iterately apply the blocks to the hidden states + for block in self.blocks: + hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb) + + # apply the final patch merger to the hidden states + hidden_states = self.merger(hidden_states, grid_thw) + return hidden_states + + +class Qwen2VLForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + self.hidden_size = config.hidden_size + self.vision_start_token_id = config.vision_start_token_id + self.image_token_id = config.image_token_id + self.video_token_id = config.video_token_id + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.visual = Qwen2VisionModel( + prefix="visual", config=config.vision_config, weights=weights + ) + self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + self.lm_head = FastLinear.load( + prefix="lm_head", weights=weights, config=config, bias=False + ) + self.norm = FastRMSNorm.load( + prefix="model.norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.device = weights.device + + def get_position_ids( + self, + batch_input_ids: torch.Tensor, + image_grid_thw: Optional[torch.LongTensor], + # video_grid_thw is not implemented yet as we do not accept video inputs at the moment + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = torch.ones( + 3, + batch_input_ids.shape[0], + batch_input_ids.shape[1], + dtype=batch_input_ids.dtype, + device=batch_input_ids.device, + ) + d = batch_input_ids.device + if image_grid_thw is not None: + image_index = 0 + llm_pos_ids_list = [] + + for i, input_ids in enumerate(batch_input_ids): + vision_start_indices = torch.argwhere( + input_ids == self.vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + # only copy the sum of the image tokens GPU<->CPU + image_count = (vision_tokens == self.image_token_id).sum().item() + + current_pos = 0 + for _ in range(image_count): + # copy the value position of the next image token from GPU<->CPU + next_image_pos = ( + (input_ids[current_pos:] == self.image_token_id) + .nonzero()[0] + .item() + ) + # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop + time_steps, height, width = image_grid_thw[image_index].clone() + height //= self.spatial_merge_size + width //= self.spatial_merge_size + + # calculate the length of the text and image tokens + text_length = next_image_pos - current_pos + start_idx = ( + llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + ) + + # text position ids + text_pos_ids = torch.arange(text_length, device=d) + text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx + llm_pos_ids_list.append(text_pos_ids) + + # image position ids + t_indices = torch.arange(time_steps, device=d).repeat_interleave( + height * width + ) + h_indices = ( + torch.arange(height, device=d) + .repeat_interleave(width) + .repeat(time_steps) + ) + w_indices = torch.arange(width, device=d).repeat( + height * time_steps + ) + + image_pos_ids = ( + torch.stack([t_indices, h_indices, w_indices]) + + text_length + + start_idx + ) + llm_pos_ids_list.append(image_pos_ids) + + current_pos = next_image_pos + time_steps * height * width + image_index += 1 + + if current_pos < batch_input_ids.size(1): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + text_len = batch_input_ids.size(1) - current_pos + llm_pos_ids_list.append( + torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[:, i, :] = llm_positions.to(position_ids.device) + + return position_ids + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor], + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if pixel_values is not None and len(pixel_values) > 0: + if pixel_values is not None: + image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw + ).squeeze(0) + inputs_embeds[input_ids == self.image_token_id] = image_embeds + + hidden_states = self.text_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + ) + hidden_states, _ = self.norm(hidden_states) + logits = self.lm_head(hidden_states) + return logits, None diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 4bbddcfb..9a3db502 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -67,6 +67,10 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str elif config.model_type == "paligemma": return "" * config.text_config.num_image_tokens + elif config.model_type == "qwen2_vl": + num_pads = image_input.pixel_values.shape[0] // 4 + padding = "<|image_pad|>" * num_pads + return f"<|vision_start|>{padding}<|vision_end|>" else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -137,6 +141,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] + image_grid_thw: Optional[torch.Tensor] @classmethod @tracer.start_as_current_span("concatenate") @@ -145,6 +150,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None + batch.image_grid_thw = None return batch @tracer.start_as_current_span("filter") @@ -153,6 +159,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None + batch.image_grid_thw = None return batch @classmethod @@ -170,6 +177,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch): pass elif chunk_type == "image": image = Image.open(BytesIO(chunk.image.data)) + # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the + # default warmup image is 20x20 + if config.model_type == "qwen2_vl": + if image.width <= 20: + w = image.width * 2 + h = image.height * 2 + image = image.resize((w, h)) + if config.model_type == "llava_next": images.append(image) else: @@ -237,10 +252,15 @@ class VlmCausalLMBatch(FlashCausalLMBatch): batch.image_sizes = image_inputs["image_sizes"].to(device=device) else: batch.image_sizes = None + if "image_grid_thw" in image_inputs: + batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) + else: + batch.image_grid_thw = None else: batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None + batch.image_grid_thw = None return batch @@ -343,6 +363,16 @@ class VlmCausalLM(FlashCausalLM): max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices + if hasattr(self.model, "get_position_ids"): + if position_ids.shape[0] != 1: + position_ids = self.model.get_position_ids( + input_ids.unsqueeze(0), batch.image_grid_thw + ) + batch.position_ids = position_ids[0, 0, :] + else: + position_ids = position_ids.repeat(3, 1, 1).clone() + batch.position_ids = position_ids[0, 0, :] + if cu_seqlen_prefill is None and self.max_past() is not None: # In decode, not prefill, we're actually overwriting the KV-cache # in a circular buffer mode. @@ -394,6 +424,7 @@ class VlmCausalLM(FlashCausalLM): pixel_values=batch.pixel_values, pixel_attention_mask=batch.pixel_attention_mask, image_sizes=batch.image_sizes, + image_grid_thw=batch.image_grid_thw, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None @@ -403,6 +434,8 @@ class VlmCausalLM(FlashCausalLM): batch.pixel_attention_mask = None if batch.image_sizes is not None: batch.image_sizes = None + if batch.image_grid_thw is not None: + batch.image_grid_thw = None return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph