feat: support video input chunks and enable qwen2 vl to process video

This commit is contained in:
David Holtz 2024-11-18 21:16:21 +00:00
parent 6b4697e9d1
commit a9c2d28a3a
10 changed files with 310 additions and 63 deletions

View File

@ -9,7 +9,7 @@ use thiserror::Error;
use tonic::transport; use tonic::transport;
use tonic::Status; use tonic::Status;
pub use v3::{Chunk, Image, Input, InputChunk}; pub use v3::{Chunk, Image, Input, InputChunk, Video};
#[async_trait] #[async_trait]
pub trait Health { pub trait Health {
@ -79,8 +79,9 @@ impl ChunksToString for Vec<InputChunk> {
let encoded = STANDARD.encode(data); let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded)) output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
} }
Some(Chunk::Video(url)) => { Some(Chunk::Video(Video { data, mimetype })) => {
output.push_str(&format!("<video>({})", url)) let encoded = STANDARD.encode(data);
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
} }
// We don't create empty chunks, so this should be unreachable. // We don't create empty chunks, so this should be unreachable.
None => unreachable!("Chunks should never be empty"), None => unreachable!("Chunks should never be empty"),

View File

@ -8,6 +8,6 @@ pub use client::Client;
pub use pb::generate::v3::{ pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, Tokens, StoppingCriteriaParameters, Tokens, Video,
}; };
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;

View File

@ -15,7 +15,7 @@ pub use grpc_client::Client;
pub use pb::generate::v3::{ pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, StoppingCriteriaParameters, Video,
}; };
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;

View File

@ -439,7 +439,10 @@ impl State {
data: image.data, data: image.data,
mimetype: image.mimetype, mimetype: image.mimetype,
}), }),
Chunk::Video(url) => client::Chunk::Video(url), Chunk::Video(video) => client::Chunk::Video(client::Video {
data: video.data,
mimetype: video.mimetype,
}),
}), }),
}) })
.collect(), .collect(),

View File

@ -1926,6 +1926,24 @@
] ]
} }
} }
},
{
"type": "object",
"required": [
"video_url",
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"video_url"
]
},
"video_url": {
"$ref": "#/components/schemas/Url"
}
}
} }
], ],
"discriminator": { "discriminator": {

View File

@ -62,16 +62,15 @@ Options:
[env: QUANTIZE=] [env: QUANTIZE=]
Possible values: Possible values:
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency - awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
- compressed-tensors: Compressed tensors, which can be a mixture of different quantization methods - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git> - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1) - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels - marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>
- marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin> - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model - fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
- fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
``` ```
## SPECULATE ## SPECULATE

View File

@ -1191,7 +1191,9 @@ impl From<Message> for TextMessage {
.map(|chunk| match chunk { .map(|chunk| match chunk {
MessageChunk::Text { text } => text, MessageChunk::Text { text } => text,
MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url), MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
MessageChunk::VideoUrl { video_url } => format!("![]({})", video_url.url), MessageChunk::VideoUrl { video_url } => {
format!("<video>({})", video_url.url)
}
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(""), .join(""),

View File

@ -516,6 +516,70 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string() .to_string()
} }
pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32), ValidationError> {
let (data, mimetype) =
if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
let url = &input["<video>(".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?.to_vec();
(data, "video/mp4".to_string())
} else if input.starts_with("<video>(data:") {
let content = &input["<video>(data:".len()..input.len() - 1];
let tokens: Vec<_> = content.split(';').collect();
if tokens.len() != 2 {
return Err(ValidationError::InvalidVideoContent(content.to_string()));
}
let mimetype = tokens[0];
let content = tokens[1];
if !content.starts_with("base64,") {
return Err(ValidationError::InvalidVideoContent(content.to_string()));
}
let data = STANDARD.decode(&content["base64,".len()..])?;
(data, mimetype.to_string())
} else {
return Err(ValidationError::InvalidVideoContent(input.to_string()));
};
let mut cursor = Cursor::new(&data);
let context = mp4parse::read_mp4(&mut cursor).map_err(|_| ValidationError::MP4Error)?;
let video_track = context
.tracks
.iter()
.find(|track| track.track_type == mp4parse::TrackType::Video)
.ok_or(ValidationError::NoVideoStream)?;
let video_info = video_track
.tkhd
.as_ref()
.ok_or(ValidationError::NoVideoStream)?;
let width = (video_info.width >> 16) as usize;
let height = (video_info.height >> 16) as usize;
// timescale units per second
let timescale = video_track.timescale.map(|t| t.0 as f32).unwrap_or(600.0);
// TODO: revisit if we need duration in seconds
let _duration = video_track
.duration
.map(|d| d.0 as f32 / timescale)
.unwrap_or(0.0);
let time_to_sample = video_track
.stts
.as_ref()
.ok_or(ValidationError::NoVideoStream)?;
let num_samples = time_to_sample
.samples
.iter()
.map(|entry| entry.sample_count)
.sum::<u32>();
let total_frames = num_samples as f32;
Ok((data, mimetype, height, width, total_frames))
}
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> { fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") { if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1]; let url = &input["![](".len()..input.len() - 1];
@ -604,6 +668,31 @@ fn image_tokens(
} }
} }
fn video_tokens(config: &Config, height: usize, width: usize, total_frames: f32) -> String {
use Config::*;
match config {
// TOOD: improve to use the config to better estimate the number of tokens
Qwen2Vl(_config) => {
let video_fps = 30_f32;
let fps = 30_f32;
let min_frames = 16_f32;
let max_frames = 64_f32;
// make sure the frames are within the range and are even
let nframes = (total_frames / video_fps * fps)
.max(min_frames)
.min(max_frames);
let nframes = (nframes / 2.0).round() as usize * 2;
let num_tokens = nframes * height * width / 1541;
format!(
"<|vision_start|>{:?}<|vision_end|>",
"<|video_pad|>".repeat(num_tokens)
)
}
_ => unimplemented!("Video tokens are not supported for this model configuration"),
}
}
fn image_tokens_fixup(config: &Config, text: String) -> String { fn image_tokens_fixup(config: &Config, text: String) -> String {
match config { match config {
Config::Idefics2(_) => { Config::Idefics2(_) => {
@ -626,7 +715,8 @@ fn prepare_input<T: TokenizerTrait>(
use Config::*; use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
// Add video regex // Add video regex
static VIDEO_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap()); static VIDEO_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
let (tokenizer_query, input_chunks) = match config { let (tokenizer_query, input_chunks) = match config {
Some( Some(
@ -636,7 +726,7 @@ fn prepare_input<T: TokenizerTrait>(
let mut tokenizer_query = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0; let mut start = 0;
// Process videos first // handle video content first
for chunk in VIDEO_RE.find_iter(&inputs) { for chunk in VIDEO_RE.find_iter(&inputs) {
let chunk_start = chunk.start(); let chunk_start = chunk.start();
let chunk_end = chunk.end(); let chunk_end = chunk.end();
@ -644,14 +734,15 @@ fn prepare_input<T: TokenizerTrait>(
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string())); input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]); tokenizer_query.push_str(&inputs[start..chunk_start]);
} }
let video_url = &inputs[chunk_start + 8..chunk_end - 1]; // Remove <video>( and ) let (data, mimetype, height, width, total_frames) =
input_chunks.push(Chunk::Video(video_url.to_string())); fetch_video(&inputs[chunk_start..chunk_end])?;
// For videos, we use the default size as height/width don't matter for the initial processing input_chunks.push(Chunk::Video(Video { data, mimetype }));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, 1, 1)); let video_tokens = video_tokens(config, height, width, total_frames);
tokenizer_query.push_str(&video_tokens);
start = chunk_end; start = chunk_end;
} }
// Process remaining content for images // clip remaining inputs and process images
let remaining_input = &inputs[start..]; let remaining_input = &inputs[start..];
for chunk in RE.find_iter(remaining_input) { for chunk in RE.find_iter(remaining_input) {
let chunk_start = chunk.start() + start; let chunk_start = chunk.start() + start;
@ -699,11 +790,17 @@ pub struct Image {
pub mimetype: String, pub mimetype: String,
} }
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Video {
pub data: Vec<u8>,
pub mimetype: String,
}
#[derive(Debug, Clone, Eq, PartialEq)] #[derive(Debug, Clone, Eq, PartialEq)]
pub enum Chunk { pub enum Chunk {
Text(String), Text(String),
Image(Image), Image(Image),
Video(String), Video(Video),
} }
/// Convert input chunks to a stringly-typed input for backwards /// Convert input chunks to a stringly-typed input for backwards
@ -722,8 +819,9 @@ impl ChunksToString for Vec<Chunk> {
let encoded = STANDARD.encode(data); let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded)) output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
} }
Chunk::Video(url) => { Chunk::Video(Video { data, mimetype }) => {
output.push_str(&format!("<video>({})", url)) let encoded = STANDARD.encode(data);
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
} }
}); });
output output
@ -851,6 +949,10 @@ pub enum ValidationError {
UnsupportedModality(&'static str), UnsupportedModality(&'static str),
#[error("invalid video content: {0}")] #[error("invalid video content: {0}")]
InvalidVideoContent(String), InvalidVideoContent(String),
#[error("could not parse MP4 file")]
MP4Error,
#[error("no video stream found")]
NoVideoStream,
} }
#[cfg(test)] #[cfg(test)]

View File

@ -14,20 +14,12 @@
# limitations under the License. # limitations under the License.
"""PyTorch Qwen2 VL model.""" """PyTorch Qwen2 VL model."""
__all__ = ['Qwen2VLForConditionalGeneration', 'process_qwen_video']
from typing import Dict, Optional, Tuple, List from typing import Dict, Optional, Tuple, List
import os
import tempfile
import requests
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from contextlib import contextmanager
from qwen_vl_utils import process_vision_info
if SYSTEM == "ipex": if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
@ -445,38 +437,48 @@ class Qwen2VLForConditionalGeneration(nn.Module):
).squeeze(1) ).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1] vision_tokens = input_ids[vision_start_indices + 1]
# Count both image and video tokens # only copy the sum of the image and video tokens GPU<->CPU
image_count = (vision_tokens == self.image_token_id).sum().item() image_count = (vision_tokens == self.image_token_id).sum().item()
video_count = (vision_tokens == self.video_token_id).sum().item() video_count = (vision_tokens == self.video_token_id).sum().item()
current_pos = 0 current_pos = 0
for _ in range(image_count + video_count): for _ in range(image_count + video_count):
# Find next vision token position (either image or video) # copy the value position of the next image or video token from GPU<->CPU
next_vision_pos = ( next_vision_pos = (
((input_ids[current_pos:] == self.image_token_id) | (
(input_ids[current_pos:] == self.video_token_id)) (input_ids[current_pos:] == self.image_token_id)
| (input_ids[current_pos:] == self.video_token_id)
)
.nonzero()[0] .nonzero()[0]
.item() .item()
) )
# Determine if current token is video or image # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
is_video = input_ids[current_pos + next_vision_pos] == self.video_token_id is_video = (
grid_thw = video_grid_thw[vision_index] if is_video else image_grid_thw[vision_index] input_ids[current_pos + next_vision_pos] == self.video_token_id
)
grid_thw = (
video_grid_thw[vision_index]
if is_video
else image_grid_thw[vision_index]
)
time_steps, height, width = grid_thw.clone() time_steps, height, width = grid_thw.clone()
height //= self.spatial_merge_size height //= self.spatial_merge_size
width //= self.spatial_merge_size width //= self.spatial_merge_size
# Calculate lengths and indices same as before # calculate the length of the text and image tokens
text_length = next_vision_pos - current_pos text_length = next_vision_pos - current_pos
start_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 start_idx = (
llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
)
# Text position ids # text position ids
text_pos_ids = torch.arange(text_length, device=d) text_pos_ids = torch.arange(text_length, device=d)
text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
llm_pos_ids_list.append(text_pos_ids) llm_pos_ids_list.append(text_pos_ids)
# Vision position ids # vision position ids
t_indices = torch.arange(time_steps, device=d).repeat_interleave( t_indices = torch.arange(time_steps, device=d).repeat_interleave(
height * width height * width
) )
@ -485,7 +487,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
.repeat_interleave(width) .repeat_interleave(width)
.repeat(time_steps) .repeat(time_steps)
) )
w_indices = torch.arange(width, device=d).repeat(height * time_steps) w_indices = torch.arange(width, device=d).repeat(
height * time_steps
)
vision_pos_ids = ( vision_pos_ids = (
torch.stack([t_indices, h_indices, w_indices]) torch.stack([t_indices, h_indices, w_indices])
@ -499,7 +503,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
# Handle remaining text if any # Handle remaining text if any
if current_pos < batch_input_ids.size(1): if current_pos < batch_input_ids.size(1):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
text_len = batch_input_ids.size(1) - current_pos text_len = batch_input_ids.size(1) - current_pos
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
@ -528,6 +534,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor = None,
video_pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None, pixel_attention_mask=None,
@ -538,15 +545,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
): ):
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
if video_pixel_values is not None and len(video_pixel_values) > 0:
vision_embeds = self.visual(
video_pixel_values, grid_thw=video_grid_thw
).squeeze(0)
vision_token_mask = input_ids == self.video_token_id
inputs_embeds[vision_token_mask] = vision_embeds
# apply the visual model to the pixel values if they are provided # apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0: if pixel_values is not None and len(pixel_values) > 0:
vision_embeds = self.visual( vision_embeds = self.visual(
pixel_values, pixel_values,
grid_thw=torch.cat([image_grid_thw, video_grid_thw]) if video_grid_thw is not None else image_grid_thw grid_thw=(
torch.cat([image_grid_thw, video_grid_thw])
if video_grid_thw is not None
else image_grid_thw
),
).squeeze(0) ).squeeze(0)
# Apply embeddings to both image and video tokens # Apply embeddings to image tokens
vision_token_mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id) vision_token_mask = input_ids == self.image_token_id
inputs_embeds[vision_token_mask] = vision_embeds inputs_embeds[vision_token_mask] = vision_embeds
hidden_states = self.text_model( hidden_states = self.text_model(

View File

@ -18,6 +18,8 @@ from text_generation_server.utils.log import log_master
from transformers import AutoProcessor from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.metadata_kernels import block_tables_to_ragged from text_generation_server.models.metadata_kernels import block_tables_to_ragged
from torchvision import io
import math
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -76,6 +78,20 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
raise RuntimeError(f"Unknown config {config.model_type} for multimodal") raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def video_text_replacement(processor, video_input, config) -> str:
if config.model_type == "qwen2_vl":
# num_pads = video_input['pixel_values'].size(0)
# num_pads = 1206
# import ipdb; ipdb.set_trace()
# num_pads = 9556 + 10
num_pads = video_input.pixel_values.shape[0] // 4
padding = "<|video_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def image_text_replacement_fixup(config, text: str) -> str: def image_text_replacement_fixup(config, text: str) -> str:
if config.model_type == "idefics2": if config.model_type == "idefics2":
return text.replace( return text.replace(
@ -138,29 +154,59 @@ def get_number_of_features(height: int, width: int, config) -> int:
return unpadded_features + newline_features + base_features return unpadded_features + newline_features + base_features
# copied from: https://github.com/QwenLM/Qwen2-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def smart_nframes(
fps: int,
nframes: int,
min_frames: int,
max_frames: int,
total_frames: int,
video_fps: int | float,
) -> int:
if nframes:
nframes = round(nframes / 2) * 2
else:
min_frames = math.ceil(min_frames / 2) * 2
max_frames = math.floor(max_frames / 2) * 2
nframes = total_frames / video_fps * fps
nframes = min(max(nframes, min_frames), max_frames)
nframes = round(nframes / 2) * 2
if not (2 <= nframes and nframes <= total_frames):
raise ValueError(
f"nframes should in interval [{2}, {total_frames}], but got {nframes}."
)
return nframes
class VlmCausalLMBatch(FlashCausalLMBatch): class VlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]] pixel_values: Optional[List[torch.Tensor]]
video_pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]] image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor] image_grid_thw: Optional[torch.Tensor]
video_grid_thw: Optional[torch.Tensor]
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches): def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches) batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None batch.pixel_values = None
batch.video_pixel_values = None
batch.pixel_attention_mask = None batch.pixel_attention_mask = None
batch.image_sizes = None batch.image_sizes = None
batch.image_grid_thw = None batch.image_grid_thw = None
batch.video_grid_thw = None
return batch return batch
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]): def filter(self, request_ids: List[int]):
batch = super().filter(request_ids) batch = super().filter(request_ids)
batch.pixel_values = None batch.pixel_values = None
batch.video_pixel_values = None
batch.pixel_attention_mask = None batch.pixel_attention_mask = None
batch.image_sizes = None batch.image_sizes = None
batch.image_grid_thw = None batch.image_grid_thw = None
batch.video_grid_thw = None
return batch return batch
@classmethod @classmethod
@ -171,6 +217,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
# can make the image splits the same size. And we need the final # can make the image splits the same size. And we need the final
# sizes to insert correct number of image tokens. # sizes to insert correct number of image tokens.
images = [] images = []
videos = []
for r in requests: for r in requests:
for chunk in r.input_chunks.chunks: for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk") chunk_type = chunk.WhichOneof("chunk")
@ -192,7 +239,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
images.append([image]) images.append([image])
elif chunk_type == "video": elif chunk_type == "video":
if config.model_type == "qwen2_vl": if config.model_type == "qwen2_vl":
pass videos.append(chunk.video)
else: else:
raise RuntimeError(f"Invalid chunk type {chunk_type}") raise RuntimeError(f"Invalid chunk type {chunk_type}")
@ -201,6 +248,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
else: else:
image_inputs = None image_inputs = None
video_inputs = None
if videos:
try:
tensor_videos = []
video = videos[0]
video_buffer = BytesIO(video.data)
video, _audio, info = io.read_video(
video_buffer,
start_pts=0.0,
end_pts=None,
pts_unit="sec",
output_format="TCHW",
)
total_frames, video_fps = video.size(0), info["video_fps"]
nframes = smart_nframes(
fps=30,
nframes=None,
min_frames=16,
max_frames=64,
total_frames=total_frames,
video_fps=video_fps,
)
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
video = video[idx]
tensor_videos.append(video)
video_inputs = processor.image_processor(
tensor_videos, return_tensors="pt"
)
except Exception as e:
print(f"Failed to process video: {e}")
pass
else:
video_inputs = None
batch_inputs = [] batch_inputs = []
max_truncation = 0 max_truncation = 0
image_id = 0 image_id = 0
@ -215,10 +297,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
processor, image_inputs, config, image_id processor, image_inputs, config, image_id
) )
image_id += 1 image_id += 1
elif chunk_type == "video" and config.model_type == "qwen2_vl": elif chunk_type == "video":
from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video full_text += video_text_replacement(processor, video_inputs, config)
text, _ = process_qwen_video(chunk.video)
full_text += text
full_text = image_text_replacement_fixup(config, full_text) full_text = image_text_replacement_fixup(config, full_text)
batch_inputs.append(full_text) batch_inputs.append(full_text)
@ -231,7 +311,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
add_special_tokens=not config.model_type == "paligemma", add_special_tokens=not config.model_type == "paligemma",
)["input_ids"] )["input_ids"]
return batch_tokenized_inputs, image_inputs return batch_tokenized_inputs, image_inputs, video_inputs
@classmethod @classmethod
def from_pb_processor( def from_pb_processor(
@ -243,10 +323,23 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "VlmCausalLMBatch": ) -> "VlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( batch_tokenized_inputs, image_inputs, video_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config pb.requests, tokenizer, processor, config
) )
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if video_inputs is not None:
if "pixel_values" in video_inputs:
batch.video_pixel_values = video_inputs["pixel_values"].to(
device=device
)
if "image_grid_thw" in video_inputs:
batch.video_grid_thw = video_inputs["image_grid_thw"].to(device=device)
else:
batch.video_grid_thw = None
else:
batch.video_pixel_values = None
batch.video_grid_thw = None
if image_inputs is not None: if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device) batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs: if "pixel_attention_mask" in image_inputs:
@ -263,6 +356,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else: else:
batch.image_grid_thw = None batch.image_grid_thw = None
if "video_grid_thw" in image_inputs:
batch.video_grid_thw = image_inputs["video_grid_thw"].to(device=device)
else:
batch.video_grid_thw = None
else: else:
batch.pixel_values = None batch.pixel_values = None
batch.pixel_attention_mask = None batch.pixel_attention_mask = None
@ -270,6 +367,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_grid_thw = None batch.image_grid_thw = None
return batch return batch
class VlmCausalLM(FlashCausalLM): class VlmCausalLM(FlashCausalLM):
def __init__( def __init__(
self, self,
@ -372,7 +470,7 @@ class VlmCausalLM(FlashCausalLM):
if self.model.config.model_type == "qwen2_vl": if self.model.config.model_type == "qwen2_vl":
if position_ids.dim() == 1 and batch.prefilling: if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids( position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw input_ids, batch.image_grid_thw, batch.video_grid_thw
) )
batch.position_ids = position_ids batch.position_ids = position_ids
@ -425,20 +523,26 @@ class VlmCausalLM(FlashCausalLM):
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values, pixel_values=batch.pixel_values,
video_pixel_values=batch.video_pixel_values,
pixel_attention_mask=batch.pixel_attention_mask, pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes, image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw, image_grid_thw=batch.image_grid_thw,
video_grid_thw=batch.video_grid_thw,
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None batch.prefill_cache_indices = None
if batch.pixel_values is not None: if batch.pixel_values is not None:
batch.pixel_values = None batch.pixel_values = None
if batch.video_pixel_values is not None:
batch.video_pixel_values = None
if batch.pixel_attention_mask is not None: if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None batch.pixel_attention_mask = None
if batch.image_sizes is not None: if batch.image_sizes is not None:
batch.image_sizes = None batch.image_sizes = None
if batch.image_grid_thw is not None: if batch.image_grid_thw is not None:
batch.image_grid_thw = None batch.image_grid_thw = None
if batch.video_grid_thw is not None:
batch.video_grid_thw = None
return logits, speculative_logits return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph # Copy inputs to the static inputs of the cuda graph