feat: support video input chunks and enable qwen2 vl to process video

This commit is contained in:
David Holtz 2024-11-18 21:16:21 +00:00
parent 6b4697e9d1
commit a9c2d28a3a
10 changed files with 310 additions and 63 deletions

View File

@ -9,7 +9,7 @@ use thiserror::Error;
use tonic::transport;
use tonic::Status;
pub use v3::{Chunk, Image, Input, InputChunk};
pub use v3::{Chunk, Image, Input, InputChunk, Video};
#[async_trait]
pub trait Health {
@ -79,8 +79,9 @@ impl ChunksToString for Vec<InputChunk> {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
Some(Chunk::Video(url)) => {
output.push_str(&format!("<video>({})", url))
Some(Chunk::Video(Video { data, mimetype })) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
}
// We don't create empty chunks, so this should be unreachable.
None => unreachable!("Chunks should never be empty"),
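For backwards compatibility, a video chunk is stringified the same way as an image chunk: the raw bytes are re-encoded as a base64 data URI, here wrapped in the `<video>(...)` marker. A minimal Python sketch of the resulting string (the byte payload is a placeholder):

```python
import base64

# Placeholder bytes standing in for a real MP4 payload.
data, mimetype = b"\x00\x00\x00\x18ftypmp42", "video/mp4"
encoded = base64.b64encode(data).decode()
chunk_str = f"<video>(data:{mimetype};base64,{encoded})"
# -> "<video>(data:video/mp4;base64,AAAAGGZ0eXBtcDQy)"
```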

View File

@ -8,6 +8,6 @@ pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, Tokens,
StoppingCriteriaParameters, Tokens, Video,
};
pub use sharded_client::ShardedClient;

View File

@ -15,7 +15,7 @@ pub use grpc_client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters,
StoppingCriteriaParameters, Video,
};
pub use sharded_client::ShardedClient;

View File

@ -439,7 +439,10 @@ impl State {
data: image.data,
mimetype: image.mimetype,
}),
Chunk::Video(url) => client::Chunk::Video(url),
Chunk::Video(video) => client::Chunk::Video(client::Video {
data: video.data,
mimetype: video.mimetype,
}),
}),
})
.collect(),

View File

@ -1926,6 +1926,24 @@
]
}
}
},
{
"type": "object",
"required": [
"video_url",
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"video_url"
]
},
"video_url": {
"$ref": "#/components/schemas/Url"
}
}
}
],
"discriminator": {

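The new `video_url` variant slots into the existing message-content union. A hypothetical request exercising it (host, model name, and clip URL are placeholders; the endpoint is TGI's OpenAI-compatible chat route):

```python
import requests

resp = requests.post(
    "http://localhost:3000/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this clip."},
                    {
                        "type": "video_url",
                        "video_url": {"url": "https://example.com/clip.mp4"},
                    },
                ],
            }
        ],
        "max_tokens": 128,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```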
View File

@ -62,16 +62,15 @@ Options:
[env: QUANTIZE=]
Possible values:
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
- compressed-tensors: Compressed tensors, which can be a mixture of different quantization methods
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
- gptq: 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use the triton kernel (wider support) when it's not. AWQ has faster kernels
- marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for your model
- fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above. This dtype has native ops and should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
- gptq: 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use the triton kernel (wider support) when it's not. AWQ has faster kernels
- marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for your model
- fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above. This dtype has native ops and should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
```
## SPECULATE

View File

@ -1191,7 +1191,9 @@ impl From<Message> for TextMessage {
.map(|chunk| match chunk {
MessageChunk::Text { text } => text,
MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
MessageChunk::VideoUrl { video_url } => format!("![]({})", video_url.url),
MessageChunk::VideoUrl { video_url } => {
format!("<video>({})", video_url.url)
}
})
.collect::<Vec<_>>()
.join(""),
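A short sketch of the flattening performed here: text chunks pass through unchanged, image URLs keep the `![](...)` markdown form, and video URLs now use the distinct `<video>(...)` marker so the two modalities can be told apart downstream.

```python
chunks = [
    {"type": "text", "text": "What happens in this clip? "},
    {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
]
text = "".join(
    c["text"] if c["type"] == "text" else f"<video>({c['video_url']['url']})"
    for c in chunks
)
# -> "What happens in this clip? <video>(https://example.com/clip.mp4)"
```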

View File

@ -516,6 +516,70 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string()
}
/// Fetch a video from a `<video>(...)` marker (remote URL or base64 data URI)
/// and probe its MP4 container for dimensions and frame count.
pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32), ValidationError> {
let (data, mimetype) =
if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
let url = &input["<video>(".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?.to_vec();
// assume MP4 for remote URLs; data URIs carry their own mimetype
(data, "video/mp4".to_string())
} else if input.starts_with("<video>(data:") {
let content = &input["<video>(data:".len()..input.len() - 1];
let tokens: Vec<_> = content.split(';').collect();
if tokens.len() != 2 {
return Err(ValidationError::InvalidVideoContent(content.to_string()));
}
let mimetype = tokens[0];
let content = tokens[1];
if !content.starts_with("base64,") {
return Err(ValidationError::InvalidVideoContent(content.to_string()));
}
let data = STANDARD.decode(&content["base64,".len()..])?;
(data, mimetype.to_string())
} else {
return Err(ValidationError::InvalidVideoContent(input.to_string()));
};
let mut cursor = Cursor::new(&data);
let context = mp4parse::read_mp4(&mut cursor).map_err(|_| ValidationError::MP4Error)?;
let video_track = context
.tracks
.iter()
.find(|track| track.track_type == mp4parse::TrackType::Video)
.ok_or(ValidationError::NoVideoStream)?;
let video_info = video_track
.tkhd
.as_ref()
.ok_or(ValidationError::NoVideoStream)?;
// tkhd stores width and height as 16.16 fixed-point values
let width = (video_info.width >> 16) as usize;
let height = (video_info.height >> 16) as usize;
// timescale units per second
let timescale = video_track.timescale.map(|t| t.0 as f32).unwrap_or(600.0);
// TODO: revisit if we need duration in seconds
let _duration = video_track
.duration
.map(|d| d.0 as f32 / timescale)
.unwrap_or(0.0);
let time_to_sample = video_track
.stts
.as_ref()
.ok_or(ValidationError::NoVideoStream)?;
let num_samples = time_to_sample
.samples
.iter()
.map(|entry| entry.sample_count)
.sum::<u32>();
let total_frames = num_samples as f32;
Ok((data, mimetype, height, width, total_frames))
}
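To summarize the two accepted input forms, a minimal Python sketch of the same marker parsing (illustrative only; the Rust above is authoritative, and remote URLs are assumed to be `video/mp4`):

```python
import base64

def parse_video_markup(marker: str) -> tuple[bytes, str]:
    # Handles the data-URI form; http(s) URLs would be fetched instead.
    assert marker.startswith("<video>(data:") and marker.endswith(")")
    content = marker[len("<video>(data:"):-1]
    mimetype, payload = content.split(";", 1)
    if not payload.startswith("base64,"):
        raise ValueError(content)
    return base64.b64decode(payload[len("base64,"):]), mimetype

data, mime = parse_video_markup("<video>(data:video/mp4;base64,AAAA)")
assert mime == "video/mp4" and data == b"\x00\x00\x00"
```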
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1];
@ -604,6 +668,31 @@ fn image_tokens(
}
}
fn video_tokens(config: &Config, height: usize, width: usize, total_frames: f32) -> String {
use Config::*;
match config {
// TODO: use the config to better estimate the number of tokens
Qwen2Vl(_config) => {
let video_fps = 30_f32;
let fps = 30_f32;
let min_frames = 16_f32;
let max_frames = 64_f32;
// make sure the frames are within the range and are even
let nframes = (total_frames / video_fps * fps)
.max(min_frames)
.min(max_frames);
let nframes = (nframes / 2.0).round() as usize * 2;
let num_tokens = nframes * height * width / 1541;
format!(
"<|vision_start|>{:?}<|vision_end|>",
"<|video_pad|>".repeat(num_tokens)
)
}
_ => unimplemented!("Video tokens are not supported for this model configuration"),
}
}
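A worked instance of the estimate above, sketched in Python (the 1541 divisor is taken as-is from the code; its derivation is not given in the commit):

```python
def qwen2_vl_video_tokens(total_frames: float, height: int, width: int) -> int:
    video_fps = fps = 30.0
    nframes = min(max(total_frames / video_fps * fps, 16.0), 64.0)  # clamp to [16, 64]
    nframes = round(nframes / 2) * 2  # force an even frame count
    return nframes * height * width // 1541

# A 240-frame clip probed at 224x224 clamps to 64 frames:
# 64 * 224 * 224 // 1541 == 2083 <|video_pad|> tokens.
assert qwen2_vl_video_tokens(240.0, 224, 224) == 2083
```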
fn image_tokens_fixup(config: &Config, text: String) -> String {
match config {
Config::Idefics2(_) => {
@ -626,7 +715,8 @@ fn prepare_input<T: TokenizerTrait>(
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
// Add video regex
static VIDEO_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
static VIDEO_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(
@ -636,7 +726,7 @@ fn prepare_input<T: TokenizerTrait>(
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
// Process videos first
// handle video content first
for chunk in VIDEO_RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
@ -644,14 +734,15 @@ fn prepare_input<T: TokenizerTrait>(
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let video_url = &inputs[chunk_start + 8..chunk_end - 1]; // Remove <video>( and )
input_chunks.push(Chunk::Video(video_url.to_string()));
// For videos, we use the default size as height/width don't matter for the initial processing
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, 1, 1));
let (data, mimetype, height, width, total_frames) =
fetch_video(&inputs[chunk_start..chunk_end])?;
input_chunks.push(Chunk::Video(Video { data, mimetype }));
let video_tokens = video_tokens(config, height, width, total_frames);
tokenizer_query.push_str(&video_tokens);
start = chunk_end;
}
// Process remaining content for images
// clip remaining inputs and process images
let remaining_input = &inputs[start..];
for chunk in RE.find_iter(remaining_input) {
let chunk_start = chunk.start() + start;
@ -699,11 +790,17 @@ pub struct Image {
pub mimetype: String,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Video {
pub data: Vec<u8>,
pub mimetype: String,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Chunk {
Text(String),
Image(Image),
Video(String),
Video(Video),
}
/// Convert input chunks to a stringly-typed input for backwards
@ -722,8 +819,9 @@ impl ChunksToString for Vec<Chunk> {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
Chunk::Video(url) => {
output.push_str(&format!("<video>({})", url))
Chunk::Video(Video { data, mimetype }) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
}
});
output
@ -851,6 +949,10 @@ pub enum ValidationError {
UnsupportedModality(&'static str),
#[error("invalid video content: {0}")]
InvalidVideoContent(String),
#[error("could not parse MP4 file")]
MP4Error,
#[error("no video stream found")]
NoVideoStream,
}
#[cfg(test)]

View File

@ -14,20 +14,12 @@
# limitations under the License.
"""PyTorch Qwen2 VL model."""
__all__ = ['Qwen2VLForConditionalGeneration', 'process_qwen_video']
from typing import Dict, Optional, Tuple, List
import os
import tempfile
import requests
import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
from contextlib import contextmanager
from qwen_vl_utils import process_vision_info
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
@ -444,39 +436,49 @@ class Qwen2VLForConditionalGeneration(nn.Module):
input_ids == self.vision_start_token_id
).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
# Count both image and video tokens
# count image and video tokens so only their sums are copied GPU->CPU
image_count = (vision_tokens == self.image_token_id).sum().item()
video_count = (vision_tokens == self.video_token_id).sum().item()
current_pos = 0
for _ in range(image_count + video_count):
# Find next vision token position (either image or video)
# locate the next image or video token (a single GPU->CPU copy per step)
next_vision_pos = (
((input_ids[current_pos:] == self.image_token_id) |
(input_ids[current_pos:] == self.video_token_id))
(
(input_ids[current_pos:] == self.image_token_id)
| (input_ids[current_pos:] == self.video_token_id)
)
.nonzero()[0]
.item()
)
# Determine if current token is video or image
is_video = input_ids[current_pos + next_vision_pos] == self.video_token_id
grid_thw = video_grid_thw[vision_index] if is_video else image_grid_thw[vision_index]
# TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
is_video = (
input_ids[current_pos + next_vision_pos] == self.video_token_id
)
grid_thw = (
video_grid_thw[vision_index]
if is_video
else image_grid_thw[vision_index]
)
time_steps, height, width = grid_thw.clone()
height //= self.spatial_merge_size
width //= self.spatial_merge_size
# Calculate lengths and indices same as before
# calculate the length of the text and image tokens
text_length = next_vision_pos - current_pos
start_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
start_idx = (
llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
)
# Text position ids
# text position ids
text_pos_ids = torch.arange(text_length, device=d)
text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
llm_pos_ids_list.append(text_pos_ids)
# Vision position ids
# vision position ids
t_indices = torch.arange(time_steps, device=d).repeat_interleave(
height * width
)
@ -485,7 +487,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
.repeat_interleave(width)
.repeat(time_steps)
)
w_indices = torch.arange(width, device=d).repeat(height * time_steps)
w_indices = torch.arange(width, device=d).repeat(
height * time_steps
)
vision_pos_ids = (
torch.stack([t_indices, h_indices, w_indices])
@ -499,7 +503,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
# Handle remaining text if any
if current_pos < batch_input_ids.size(1):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
text_len = batch_input_ids.size(1) - current_pos
llm_pos_ids_list.append(
torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
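A minimal sketch (assuming it mirrors the index construction above) of the three-row position ids for one vision block after spatial merging: row 0 counts time steps, row 1 rows, row 2 columns.

```python
import torch

time_steps, height, width = 2, 2, 3  # toy grid after spatial merging
t = torch.arange(time_steps).repeat_interleave(height * width)
h = torch.arange(height).repeat_interleave(width).repeat(time_steps)
w = torch.arange(width).repeat(height * time_steps)
vision_pos_ids = torch.stack([t, h, w])  # shape (3, 12)
# t: 0 0 0 0 0 0 1 1 1 1 1 1
# h: 0 0 0 1 1 1 0 0 0 1 1 1
# w: 0 1 2 0 1 2 0 1 2 0 1 2
```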
@ -528,6 +534,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
video_pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
@ -538,15 +545,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
):
inputs_embeds = self.embed_tokens(input_ids)
if video_pixel_values is not None and len(video_pixel_values) > 0:
vision_embeds = self.visual(
video_pixel_values, grid_thw=video_grid_thw
).squeeze(0)
vision_token_mask = input_ids == self.video_token_id
inputs_embeds[vision_token_mask] = vision_embeds
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
vision_embeds = self.visual(
pixel_values,
grid_thw=torch.cat([image_grid_thw, video_grid_thw]) if video_grid_thw is not None else image_grid_thw
pixel_values,
grid_thw=(
torch.cat([image_grid_thw, video_grid_thw])
if video_grid_thw is not None
else image_grid_thw
),
).squeeze(0)
# Apply embeddings to both image and video tokens
vision_token_mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id)
# Apply embeddings to image tokens
vision_token_mask = input_ids == self.image_token_id
inputs_embeds[vision_token_mask] = vision_embeds
hidden_states = self.text_model(

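The video path mirrors the image path: vision embeddings overwrite the rows of `inputs_embeds` selected by a boolean mask over the pad-token ids. A toy sketch (the id 151656 for `<|video_pad|>` is an assumption about Qwen2-VL's vocabulary):

```python
import torch

VIDEO_TOKEN_ID = 151656  # assumed id of <|video_pad|>
input_ids = torch.tensor([1, VIDEO_TOKEN_ID, VIDEO_TOKEN_ID, 2])
inputs_embeds = torch.zeros(4, 8)   # (seq_len, hidden)
vision_embeds = torch.ones(2, 8)    # one row per video token
mask = input_ids == VIDEO_TOKEN_ID
inputs_embeds[mask] = vision_embeds  # rows 1 and 2 are replaced
```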
View File

@ -18,6 +18,8 @@ from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
from torchvision import io
import math
tracer = trace.get_tracer(__name__)
@ -76,6 +78,20 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def video_text_replacement(processor, video_input, config) -> str:
if config.model_type == "qwen2_vl":
# each group of 4 flattened patches merges into one <|video_pad|> token
# (assumption: Qwen2-VL's 2x2 spatial patch merging)
num_pads = video_input.pixel_values.shape[0] // 4
padding = "<|video_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
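A worked instance of the replacement, assuming a clip whose processed `pixel_values` has 9556 flattened patches (the `// 4` is the assumed 2x2 spatial merge):

```python
num_pads = 9556 // 4  # 2389 <|video_pad|> tokens
text = "<|vision_start|>" + "<|video_pad|>" * num_pads + "<|vision_end|>"
assert text.count("<|video_pad|>") == 2389
```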
def image_text_replacement_fixup(config, text: str) -> str:
if config.model_type == "idefics2":
return text.replace(
@ -138,29 +154,59 @@ def get_number_of_features(height: int, width: int, config) -> int:
return unpadded_features + newline_features + base_features
# copied from: https://github.com/QwenLM/Qwen2-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def smart_nframes(
fps: int,
nframes: Optional[int],
min_frames: int,
max_frames: int,
total_frames: int,
video_fps: int | float,
) -> int:
if nframes:
nframes = round(nframes / 2) * 2
else:
min_frames = math.ceil(min_frames / 2) * 2
max_frames = math.floor(max_frames / 2) * 2
nframes = total_frames / video_fps * fps
nframes = min(max(nframes, min_frames), max_frames)
nframes = round(nframes / 2) * 2
if not (2 <= nframes <= total_frames):
raise ValueError(
f"nframes should in interval [{2}, {total_frames}], but got {nframes}."
)
return nframes
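A hypothetical call mirroring how the batching code below invokes it: a 10-second clip at 24 fps (240 frames) is resampled toward 30 fps but clamped to the even [16, 64] window.

```python
nframes = smart_nframes(
    fps=30,
    nframes=None,
    min_frames=16,
    max_frames=64,
    total_frames=240,
    video_fps=24.0,
)
# 240 / 24 * 30 = 300 -> clamped to 64 -> already even -> 64
assert nframes == 64
```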
class VlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
video_pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor]
video_grid_thw: Optional[torch.Tensor]
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.video_pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.video_grid_thw = None
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.video_pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.video_grid_thw = None
return batch
@classmethod
@ -171,6 +217,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
# can make the image splits the same size. And we need the final
# sizes to insert correct number of image tokens.
images = []
videos = []
for r in requests:
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
@ -192,7 +239,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
images.append([image])
elif chunk_type == "video":
if config.model_type == "qwen2_vl":
pass
videos.append(chunk.video)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
@ -201,6 +248,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
else:
image_inputs = None
video_inputs = None
if videos:
try:
tensor_videos = []
# NOTE: only the first video in the batch is decoded for now
video = videos[0]
video_buffer = BytesIO(video.data)
video, _audio, info = io.read_video(
video_buffer,
start_pts=0.0,
end_pts=None,
pts_unit="sec",
output_format="TCHW",
)
total_frames, video_fps = video.size(0), info["video_fps"]
nframes = smart_nframes(
fps=30,
nframes=None,
min_frames=16,
max_frames=64,
total_frames=total_frames,
video_fps=video_fps,
)
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
video = video[idx]
tensor_videos.append(video)
video_inputs = processor.image_processor(
tensor_videos, return_tensors="pt"
)
except Exception as e:
# leave video_inputs as None if decoding or preprocessing fails
print(f"Failed to process video: {e}")
else:
video_inputs = None
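The frame selection above is plain uniform sampling; a standalone sketch:

```python
import torch

total_frames, nframes = 300, 64
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
# 64 evenly spaced indices covering the whole clip
assert idx[0].item() == 0 and idx[-1].item() == total_frames - 1
assert len(idx) == nframes
```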
batch_inputs = []
max_truncation = 0
image_id = 0
@ -215,10 +297,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
processor, image_inputs, config, image_id
)
image_id += 1
elif chunk_type == "video" and config.model_type == "qwen2_vl":
from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
text, _ = process_qwen_video(chunk.video)
full_text += text
elif chunk_type == "video":
full_text += video_text_replacement(processor, video_inputs, config)
full_text = image_text_replacement_fixup(config, full_text)
batch_inputs.append(full_text)
@ -231,7 +311,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
add_special_tokens=not config.model_type == "paligemma",
)["input_ids"]
return batch_tokenized_inputs, image_inputs
return batch_tokenized_inputs, image_inputs, video_inputs
@classmethod
def from_pb_processor(
@ -243,10 +323,23 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
batch_tokenized_inputs, image_inputs, video_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if video_inputs is not None:
if "pixel_values" in video_inputs:
batch.video_pixel_values = video_inputs["pixel_values"].to(
device=device
)
if "image_grid_thw" in video_inputs:
batch.video_grid_thw = video_inputs["image_grid_thw"].to(device=device)
else:
batch.video_grid_thw = None
else:
batch.video_pixel_values = None
batch.video_grid_thw = None
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
@ -263,6 +356,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
if "video_grid_thw" in image_inputs:
batch.video_grid_thw = image_inputs["video_grid_thw"].to(device=device)
else:
batch.video_grid_thw = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
@ -270,6 +367,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_grid_thw = None
return batch
class VlmCausalLM(FlashCausalLM):
def __init__(
self,
@ -372,7 +470,7 @@ class VlmCausalLM(FlashCausalLM):
if self.model.config.model_type == "qwen2_vl":
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw
input_ids, batch.image_grid_thw, batch.video_grid_thw
)
batch.position_ids = position_ids
@ -425,20 +523,26 @@ class VlmCausalLM(FlashCausalLM):
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
video_pixel_values=batch.video_pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
video_grid_thw=batch.video_grid_thw,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.video_pixel_values is not None:
batch.video_pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
if batch.video_grid_thw is not None:
batch.video_grid_thw = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph