Support qwen2 vl (#2689)

* feat: add support for qwen2 vl model

* feat: fix token padding, enable warmup and process basic request

* fix: improve get_position_ids, add lift embed_tokens

* fix: remove get_cos_sin_hack dev function

* feat: add simple test chat with meesage and text

* fix: lint test

* fix: adjust positional embeddings for multi dimensional position ids

* fix: update docs and lint unused vars

* fix: include linted file

* fix: add norm after text output

* fix: format model file

* fix: adjust for ruff lints

* fix: remove unused rotate_half

* feat: refactors and calc num features

* fix: prefer position_ids passed from vlm causal lm and reset ids on batch

* fix: adjust get_position_ids if not available and add required args to signatures

* fix: adjust resize case for qwen2_vl warmup

* fix: avoid qwen2 vl specific paths with qwen2
This commit is contained in:
drbh 2024-10-30 12:40:51 -04:00 committed by GitHub
parent 46aeb0860d
commit befd9f6735
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 705 additions and 10 deletions

View File

@ -24,6 +24,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
- [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f)
- [Qwen 2 VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
- [Opt](https://huggingface.co/facebook/opt-6.7b)
- [T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b)

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1730164250,
"id": "",
"model": "Qwen/Qwen2-VL-7B-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.4.1-dev0-native",
"usage": {
"completion_tokens": 58,
"prompt_tokens": 349,
"total_tokens": 407
}
}

View File

@ -0,0 +1,42 @@
import pytest
@pytest.fixture(scope="module")
def flash_qwen2_vl_handle(launcher):
with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_qwen2(flash_qwen2_vl_handle):
await flash_qwen2_vl_handle.health(300)
return flash_qwen2_vl_handle.client
@pytest.mark.private
async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
response = await flash_qwen2.chat(
max_tokens=100,
seed=42,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
},
},
{"type": "text", "text": "Describe this image."},
],
},
],
)
assert (
response.choices[0].message.content
== "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
)
assert response == response_snapshot

View File

@ -138,10 +138,39 @@ impl Paligemma {
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2VlVisionConfig {
pub(crate) depth: usize,
pub(crate) embed_dim: usize,
pub(crate) mlp_ratio: usize,
pub(crate) num_heads: usize,
pub(crate) in_chans: usize,
pub(crate) hidden_size: usize,
pub(crate) patch_size: usize,
pub(crate) spatial_merge_size: usize,
pub(crate) spatial_patch_size: usize,
pub(crate) temporal_patch_size: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2Vl {
pub(crate) vision_config: Qwen2VlVisionConfig,
}
impl Qwen2Vl {
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let num_pixels = height * width;
num_pixels / self.vision_config.patch_size.pow(2)
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub enum Config {
Qwen2Vl(Qwen2Vl),
LlavaNext(LlavaNext),
ClipVisionModel(ClipVisionModel),
Mistral,

View File

@ -594,6 +594,10 @@ fn image_tokens(
}
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
Qwen2Vl(config) => format!(
"<|vision_start|>{:?}<|vision_end|>",
"<|image_pad|>".repeat(config.get_number_of_features(height, width))
),
_ => unimplemented!("Images tokens are not supported for this model configuration"),
}
}
@ -620,7 +624,9 @@ fn prepare_input<T: TokenizerTrait>(
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
Some(
config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;

View File

@ -89,6 +89,8 @@ class PositionRotaryEmbedding(nn.Module):
if rope_type == "linear":
pass
elif rope_type == "default":
pass
elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding(

View File

@ -146,6 +146,9 @@ try:
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_vl import (
Qwen2VLForConditionalGeneration,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
@ -275,6 +278,11 @@ class ModelType(enum.Enum):
"name": "Qwen 2",
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
}
QWEN2_VL = {
"type": "qwen2_vl",
"name": "Qwen 2 VL",
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
}
OPT = {
"type": "opt",
"name": "Opt",
@ -1193,6 +1201,18 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == QWEN2_VL:
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
if model_type == MLLAMA:
if FLASH_ATTENTION:
return MllamaCausalLM(

View File

@ -80,6 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.

View File

@ -61,6 +61,11 @@ class Qwen2Attention(torch.nn.Module):
config.sliding_window if config.sliding_window is not None else -1
)
self.num_heads = config.num_attention_heads
self.mrope_section = (
config.rope_scaling.get("mrope_section", None)
if config.rope_scaling is not None
else None
)
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
@ -122,6 +127,17 @@ class Qwen2Attention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
if self.mrope_section is not None:
# if mrope_section is set, we need to split the cos and sin into 3 parts and concatenate them in a specific order
cos = torch.cat(
[m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
dim=-1,
)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
@ -270,9 +286,6 @@ class Qwen2Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
Qwen2Layer(
@ -296,7 +309,7 @@ class Qwen2Model(torch.nn.Module):
def forward(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -307,13 +320,16 @@ class Qwen2Model(torch.nn.Module):
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
# flatten position ids from 2D to 1D
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
position_ids.flatten(), true_max_s, hidden_states.dtype
)
# reshape back to 2D if the position_ids were 2D
if position_ids.size(0) != cos.size(0):
cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
residual = None
for i, layer in enumerate(self.layers):
@ -352,6 +368,12 @@ class Qwen2ForCausalLM(torch.nn.Module):
prefix=f"{prefix}.{suffix}" if prefix else suffix,
weights=weights,
)
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
weights=weights,
)
self.max_past = config.sliding_window
self.max_past_tensor = (
torch.tensor(config.sliding_window, device=weights.device)
@ -382,8 +404,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
input_ids,
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,

View File

@ -750,6 +750,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:

View File

@ -180,6 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:

View File

@ -0,0 +1,509 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2 VL model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
pass
else:
pass
from transformers.activations import ACT2FN
import torch.nn.functional as F
from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
TensorParallelEmbedding,
FastLinear,
)
from text_generation_server.layers.attention import (
Seqlen,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class Qwen2VLSdpaAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.embed_dim = config.embed_dim
self.head_dim = config.hidden_size // config.num_heads
self.num_heads = config.num_heads // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv",
weights=weights,
bias=False,
num_heads=self.num_heads,
num_key_value_heads=self.num_heads,
)
self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
self.proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.proj",
weights=weights,
bias=True,
)
def forward(
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
qkv = self.qkv(hidden_state)
query, key, value = qkv.split(
[self.embed_dim, self.embed_dim, self.embed_dim], dim=1
)
# reshape the query, key, and value tensors
_shape = (
hidden_state.shape[0],
self.num_heads,
self.embed_dim // self.num_heads,
)
query = query.view(*_shape)
key = key.view(*_shape)
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
# TODO: make use of existing RotatoryPositionEmbedding class
# create the attention mask
attention_mask = torch.zeros(
[1, hidden_state.shape[0], hidden_state.shape[0]],
device=hidden_state.device,
dtype=torch.bool,
)
# TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it
# apply the cu_seqlens to the attention mask
for i in range(1, len(cu_seqlens)):
attention_mask[
...,
cu_seqlens[i - 1] : cu_seqlens[i],
cu_seqlens[i - 1] : cu_seqlens[i],
] = True
# transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim)
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
# apply attention
attn_output = F.scaled_dot_product_attention(
query, key, value, attention_mask, dropout_p=0.0
)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
# TODO: prefer flash attention
attn_output = self.proj(attn_output)
return attn_output
class Qwen2VLVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Qwen2VLVisionBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.attn = Qwen2VLSdpaAttention(
prefix=f"{prefix}.attn",
config=config,
weights=weights,
)
self.norm1 = FastLayerNorm.load(
prefix=f"{prefix}.norm1",
weights=weights,
eps=1e-6,
)
self.norm2 = FastLayerNorm.load(
prefix=f"{prefix}.norm2",
weights=weights,
eps=1e-6,
)
self.mlp = Qwen2VLVisionMLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
)
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
hidden_states_post_norm1, res = self.norm1(hidden_states)
hidden_states = hidden_states + self.attn(
hidden_states_post_norm1, cu_seqlens, rotary_pos_emb
)
hidden_states_post_norm2, res = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
return hidden_states
class Qwen2VLPatchMerger(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.hidden_size = config.embed_dim * (config.spatial_merge_size**2)
self.patch_merger_ln_q = FastLayerNorm.load(
prefix=f"{prefix}.ln_q",
weights=weights,
eps=1e-6,
)
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states, grid_thw) -> torch.Tensor:
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
hidden_states = hidden_states.view(-1, self.hidden_size)
hidden_states = self.fc1(hidden_states)
hidden_states = F.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Qwen2VisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.spatial_merge_size = config.spatial_merge_size
kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
self.patch_embedding = nn.Conv3d(
in_channels=config.in_chans,
out_channels=config.embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
)
head_dim = config.embed_dim // config.num_heads
# TODO: replace with static positional embeddings once implemented
theta = 10000.0
dim = head_dim // 2
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.blocks = nn.ModuleList(
[
Qwen2VLVisionBlock(
prefix=f"{prefix}.blocks.{i}",
config=config,
weights=weights,
)
for i in range(config.depth)
]
)
self.merger = Qwen2VLPatchMerger(
prefix=f"{prefix}.merger",
config=config,
weights=weights,
)
self.temporal_patch_size = config.temporal_patch_size
self.spatial_patch_size = config.spatial_patch_size
self.in_channels = config.in_channels
self.embed_dim = config.embed_dim
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def forward(
self,
pixel_values: torch.Tensor,
aspect_ratio_ids: Optional[torch.Tensor] = None,
grid_thw: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
# reshape the input tensor for processing
shape = (
-1,
self.in_channels,
self.temporal_patch_size,
self.spatial_patch_size,
self.spatial_patch_size,
)
pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
# TODO: revisit to see if we can avoid some of these reshapes
# find the position ids for the input tensor based on the grid_thw
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
# apply the positional embeddings to the position ids
seq = torch.arange(
max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
# create a cu_seqlens tensor to be used in the attention mask
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
# iterately apply the blocks to the hidden states
for block in self.blocks:
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states, grid_thw)
return hidden_states
class Qwen2VLForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
self.hidden_size = config.hidden_size
self.vision_start_token_id = config.vision_start_token_id
self.image_token_id = config.image_token_id
self.video_token_id = config.video_token_id
self.spatial_merge_size = config.vision_config.spatial_merge_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.visual = Qwen2VisionModel(
prefix="visual", config=config.vision_config, weights=weights
)
self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
self.lm_head = FastLinear.load(
prefix="lm_head", weights=weights, config=config, bias=False
)
self.norm = FastRMSNorm.load(
prefix="model.norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.device = weights.device
def get_position_ids(
self,
batch_input_ids: torch.Tensor,
image_grid_thw: Optional[torch.LongTensor],
# video_grid_thw is not implemented yet as we do not accept video inputs at the moment
) -> Tuple[torch.Tensor, torch.Tensor]:
position_ids = torch.ones(
3,
batch_input_ids.shape[0],
batch_input_ids.shape[1],
dtype=batch_input_ids.dtype,
device=batch_input_ids.device,
)
d = batch_input_ids.device
if image_grid_thw is not None:
image_index = 0
llm_pos_ids_list = []
for i, input_ids in enumerate(batch_input_ids):
vision_start_indices = torch.argwhere(
input_ids == self.vision_start_token_id
).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
# only copy the sum of the image tokens GPU<->CPU
image_count = (vision_tokens == self.image_token_id).sum().item()
current_pos = 0
for _ in range(image_count):
# copy the value position of the next image token from GPU<->CPU
next_image_pos = (
(input_ids[current_pos:] == self.image_token_id)
.nonzero()[0]
.item()
)
# TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
time_steps, height, width = image_grid_thw[image_index].clone()
height //= self.spatial_merge_size
width //= self.spatial_merge_size
# calculate the length of the text and image tokens
text_length = next_image_pos - current_pos
start_idx = (
llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
)
# text position ids
text_pos_ids = torch.arange(text_length, device=d)
text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
llm_pos_ids_list.append(text_pos_ids)
# image position ids
t_indices = torch.arange(time_steps, device=d).repeat_interleave(
height * width
)
h_indices = (
torch.arange(height, device=d)
.repeat_interleave(width)
.repeat(time_steps)
)
w_indices = torch.arange(width, device=d).repeat(
height * time_steps
)
image_pos_ids = (
torch.stack([t_indices, h_indices, w_indices])
+ text_length
+ start_idx
)
llm_pos_ids_list.append(image_pos_ids)
current_pos = next_image_pos + time_steps * height * width
image_index += 1
if current_pos < batch_input_ids.size(1):
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
text_len = batch_input_ids.size(1) - current_pos
llm_pos_ids_list.append(
torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[:, i, :] = llm_positions.to(position_ids.device)
return position_ids
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
)
hidden_states, _ = self.norm(hidden_states)
logits = self.lm_head(hidden_states)
return logits, None

View File

@ -67,6 +67,10 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens
elif config.model_type == "qwen2_vl":
num_pads = image_input.pixel_values.shape[0] // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@ -137,6 +141,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor]
@classmethod
@tracer.start_as_current_span("concatenate")
@ -145,6 +150,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@tracer.start_as_current_span("filter")
@ -153,6 +159,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@classmethod
@ -170,6 +177,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
pass
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
# default warmup image is 20x20
if config.model_type == "qwen2_vl":
if image.width <= 20:
w = image.width * 2
h = image.height * 2
image = image.resize((w, h))
if config.model_type == "llava_next":
images.append(image)
else:
@ -237,10 +252,15 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
if "image_grid_thw" in image_inputs:
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@ -343,6 +363,16 @@ class VlmCausalLM(FlashCausalLM):
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if hasattr(self.model, "get_position_ids"):
if position_ids.shape[0] != 1:
position_ids = self.model.get_position_ids(
input_ids.unsqueeze(0), batch.image_grid_thw
)
batch.position_ids = position_ids[0, 0, :]
else:
position_ids = position_ids.repeat(3, 1, 1).clone()
batch.position_ids = position_ids[0, 0, :]
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
@ -394,6 +424,7 @@ class VlmCausalLM(FlashCausalLM):
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
@ -403,6 +434,8 @@ class VlmCausalLM(FlashCausalLM):
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph