feat: support video input chunks and enable qwen2 vl to process video

David Holtz 2024-11-18 21:16:21 +00:00
parent 6b4697e9d1
commit a9c2d28a3a
10 changed files with 310 additions and 63 deletions

View File

@ -9,7 +9,7 @@ use thiserror::Error;
use tonic::transport;
use tonic::Status;
pub use v3::{Chunk, Image, Input, InputChunk};
pub use v3::{Chunk, Image, Input, InputChunk, Video};
#[async_trait]
pub trait Health {
@ -79,8 +79,9 @@ impl ChunksToString for Vec<InputChunk> {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
Some(Chunk::Video(url)) => {
output.push_str(&format!("<video>({})", url))
Some(Chunk::Video(Video { data, mimetype })) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
}
// We don't create empty chunks, so this should be unreachable.
None => unreachable!("Chunks should never be empty"),
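For reference, a minimal Python sketch (not part of the commit) of the string a Video chunk is flattened into, assuming raw MP4 bytes and the base64 data-URI marker used above:
import base64

def video_chunk_to_string(data: bytes, mimetype: str = "video/mp4") -> str:
    # mirrors ChunksToString: embed the raw bytes as a base64 data URI inside <video>(...)
    encoded = base64.b64encode(data).decode("ascii")
    return f"<video>(data:{mimetype};base64,{encoded})"

# video_chunk_to_string(b"\x00\x00\x00\x18ftypmp42") ->
# "<video>(data:video/mp4;base64,AAAAGGZ0eXBtcDQy)"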

View File

@ -8,6 +8,6 @@ pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, Tokens,
StoppingCriteriaParameters, Tokens, Video,
};
pub use sharded_client::ShardedClient;

View File

@ -15,7 +15,7 @@ pub use grpc_client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters,
StoppingCriteriaParameters, Video,
};
pub use sharded_client::ShardedClient;

View File

@ -439,7 +439,10 @@ impl State {
data: image.data,
mimetype: image.mimetype,
}),
Chunk::Video(url) => client::Chunk::Video(url),
Chunk::Video(video) => client::Chunk::Video(client::Video {
data: video.data,
mimetype: video.mimetype,
}),
}),
})
.collect(),

View File

@ -1926,6 +1926,24 @@
]
}
}
},
{
"type": "object",
"required": [
"video_url",
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"video_url"
]
},
"video_url": {
"$ref": "#/components/schemas/Url"
}
}
}
],
"discriminator": {

View File

@ -63,7 +63,6 @@ Options:
Possible values:
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
- compressed-tensors: Compressed tensors, which can be a mixture of different quantization methods
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
- gptq: 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels

View File

@ -1191,7 +1191,9 @@ impl From<Message> for TextMessage {
.map(|chunk| match chunk {
MessageChunk::Text { text } => text,
MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
MessageChunk::VideoUrl { video_url } => format!("![]({})", video_url.url),
MessageChunk::VideoUrl { video_url } => {
format!("<video>({})", video_url.url)
}
})
.collect::<Vec<_>>()
.join(""),
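As a hedged illustration (the URL is hypothetical), a chat message chunk using the new video_url variant and the plain-text marker the From<Message> conversion above flattens it into:
chunk = {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}}
text = f"<video>({chunk['video_url']['url']})"   # "<video>(https://example.com/clip.mp4)"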

View File

@ -516,6 +516,70 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string()
}
pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32), ValidationError> {
let (data, mimetype) =
if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
let url = &input["<video>(".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?.to_vec();
(data, "video/mp4".to_string())
} else if input.starts_with("<video>(data:") {
let content = &input["<video>(data:".len()..input.len() - 1];
let tokens: Vec<_> = content.split(';').collect();
if tokens.len() != 2 {
return Err(ValidationError::InvalidVideoContent(content.to_string()));
}
let mimetype = tokens[0];
let content = tokens[1];
if !content.starts_with("base64,") {
return Err(ValidationError::InvalidVideoContent(content.to_string()));
}
let data = STANDARD.decode(&content["base64,".len()..])?;
(data, mimetype.to_string())
} else {
return Err(ValidationError::InvalidVideoContent(input.to_string()));
};
let mut cursor = Cursor::new(&data);
let context = mp4parse::read_mp4(&mut cursor).map_err(|_| ValidationError::MP4Error)?;
let video_track = context
.tracks
.iter()
.find(|track| track.track_type == mp4parse::TrackType::Video)
.ok_or(ValidationError::NoVideoStream)?;
let video_info = video_track
.tkhd
.as_ref()
.ok_or(ValidationError::NoVideoStream)?;
let width = (video_info.width >> 16) as usize;
let height = (video_info.height >> 16) as usize;
// timescale units per second
let timescale = video_track.timescale.map(|t| t.0 as f32).unwrap_or(600.0);
// TODO: revisit if we need duration in seconds
let _duration = video_track
.duration
.map(|d| d.0 as f32 / timescale)
.unwrap_or(0.0);
let time_to_sample = video_track
.stts
.as_ref()
.ok_or(ValidationError::NoVideoStream)?;
let num_samples = time_to_sample
.samples
.iter()
.map(|entry| entry.sample_count)
.sum::<u32>();
let total_frames = num_samples as f32;
Ok((data, mimetype, height, width, total_frames))
}
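An illustrative Python sketch (values taken straight from the parsed MP4 boxes; helper name is hypothetical) of the metadata math fetch_video performs: tkhd stores width and height as 16.16 fixed point, and the frame count is the sum of the stts sample counts:
def mp4_dims_and_frames(tkhd_width: int, tkhd_height: int, stts_sample_counts: list[int]):
    width = tkhd_width >> 16            # e.g. 0x05000000 -> 1280
    height = tkhd_height >> 16          # e.g. 0x02D00000 -> 720
    total_frames = float(sum(stts_sample_counts))
    return width, height, total_frames

# mp4_dims_and_frames(0x05000000, 0x02D00000, [240]) -> (1280, 720, 240.0)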
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1];
@ -604,6 +668,31 @@ fn image_tokens(
}
}
fn video_tokens(config: &Config, height: usize, width: usize, total_frames: f32) -> String {
use Config::*;
match config {
// TODO: use the config to better estimate the number of tokens
Qwen2Vl(_config) => {
let video_fps = 30_f32;
let fps = 30_f32;
let min_frames = 16_f32;
let max_frames = 64_f32;
// make sure the frames are within the range and are even
let nframes = (total_frames / video_fps * fps)
.max(min_frames)
.min(max_frames);
let nframes = (nframes / 2.0).round() as usize * 2;
let num_tokens = nframes * height * width / 1541;
format!(
"<|vision_start|>{:?}<|vision_end|>",
"<|video_pad|>".repeat(num_tokens)
)
}
_ => unimplemented!("Video tokens are not supported for this model configuration"),
}
}
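A worked example of the estimate above, re-derived in Python for a hypothetical 1280x720, 30 fps, 300-frame clip (the 1541 divisor is the placeholder flagged by the TODO, not a config-derived value):
total_frames, height, width = 300.0, 720, 1280
video_fps, fps, min_frames, max_frames = 30.0, 30.0, 16.0, 64.0
nframes = min(max(total_frames / video_fps * fps, min_frames), max_frames)  # 64.0
nframes = round(nframes / 2) * 2                                            # 64
num_tokens = nframes * height * width // 1541                               # 38275
# -> "<|vision_start|>" + "<|video_pad|>" * 38275 + "<|vision_end|>"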
fn image_tokens_fixup(config: &Config, text: String) -> String {
match config {
Config::Idefics2(_) => {
@ -626,7 +715,8 @@ fn prepare_input<T: TokenizerTrait>(
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
// Add video regex
static VIDEO_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
static VIDEO_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(
@ -636,7 +726,7 @@ fn prepare_input<T: TokenizerTrait>(
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
// Process videos first
// handle video content first
for chunk in VIDEO_RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
@ -644,14 +734,15 @@ fn prepare_input<T: TokenizerTrait>(
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let video_url = &inputs[chunk_start + 8..chunk_end - 1]; // Remove <video>( and )
input_chunks.push(Chunk::Video(video_url.to_string()));
// For videos, we use the default size as height/width don't matter for the initial processing
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, 1, 1));
let (data, mimetype, height, width, total_frames) =
fetch_video(&inputs[chunk_start..chunk_end])?;
input_chunks.push(Chunk::Video(Video { data, mimetype }));
let video_tokens = video_tokens(config, height, width, total_frames);
tokenizer_query.push_str(&video_tokens);
start = chunk_end;
}
// Process remaining content for images
// clip remaining inputs and process images
let remaining_input = &inputs[start..];
for chunk in RE.find_iter(remaining_input) {
let chunk_start = chunk.start() + start;
@ -699,11 +790,17 @@ pub struct Image {
pub mimetype: String,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Video {
pub data: Vec<u8>,
pub mimetype: String,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Chunk {
Text(String),
Image(Image),
Video(String),
Video(Video),
}
/// Convert input chunks to a stringly-typed input for backwards
@ -722,8 +819,9 @@ impl ChunksToString for Vec<Chunk> {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
Chunk::Video(url) => {
output.push_str(&format!("<video>({})", url))
Chunk::Video(Video { data, mimetype }) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
}
});
output
@ -851,6 +949,10 @@ pub enum ValidationError {
UnsupportedModality(&'static str),
#[error("invalid video content: {0}")]
InvalidVideoContent(String),
#[error("could not parse MP4 file")]
MP4Error,
#[error("no video stream found")]
NoVideoStream,
}
#[cfg(test)]

View File

@ -14,20 +14,12 @@
# limitations under the License.
"""PyTorch Qwen2 VL model."""
__all__ = ['Qwen2VLForConditionalGeneration', 'process_qwen_video']
from typing import Dict, Optional, Tuple, List
import os
import tempfile
import requests
import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
from contextlib import contextmanager
from qwen_vl_utils import process_vision_info
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
@ -445,38 +437,48 @@ class Qwen2VLForConditionalGeneration(nn.Module):
).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
# Count both image and video tokens
# copy only the summed image and video token counts from GPU to CPU
image_count = (vision_tokens == self.image_token_id).sum().item()
video_count = (vision_tokens == self.video_token_id).sum().item()
current_pos = 0
for _ in range(image_count + video_count):
# Find next vision token position (either image or video)
# copy the position of the next image or video token from GPU to CPU
next_vision_pos = (
((input_ids[current_pos:] == self.image_token_id) |
(input_ids[current_pos:] == self.video_token_id))
(
(input_ids[current_pos:] == self.image_token_id)
| (input_ids[current_pos:] == self.video_token_id)
)
.nonzero()[0]
.item()
)
# Determine if current token is video or image
is_video = input_ids[current_pos + next_vision_pos] == self.video_token_id
grid_thw = video_grid_thw[vision_index] if is_video else image_grid_thw[vision_index]
# TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
is_video = (
input_ids[current_pos + next_vision_pos] == self.video_token_id
)
grid_thw = (
video_grid_thw[vision_index]
if is_video
else image_grid_thw[vision_index]
)
time_steps, height, width = grid_thw.clone()
height //= self.spatial_merge_size
width //= self.spatial_merge_size
# Calculate lengths and indices same as before
# calculate the length of the text and image tokens
text_length = next_vision_pos - current_pos
start_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
start_idx = (
llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
)
# Text position ids
# text position ids
text_pos_ids = torch.arange(text_length, device=d)
text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
llm_pos_ids_list.append(text_pos_ids)
# Vision position ids
# vision position ids
t_indices = torch.arange(time_steps, device=d).repeat_interleave(
height * width
)
@ -485,7 +487,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
.repeat_interleave(width)
.repeat(time_steps)
)
w_indices = torch.arange(width, device=d).repeat(height * time_steps)
w_indices = torch.arange(width, device=d).repeat(
height * time_steps
)
vision_pos_ids = (
torch.stack([t_indices, h_indices, w_indices])
@ -499,7 +503,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
# Handle remaining text if any
if current_pos < batch_input_ids.size(1):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
text_len = batch_input_ids.size(1) - current_pos
llm_pos_ids_list.append(
torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
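A tiny standalone example (CPU-only, illustrative) of the vision position ids built above, for a post-merge (time_steps, height, width) grid of (1, 2, 2); the real code additionally offsets these by the preceding text length:
import torch

t, h, w = 1, 2, 2
t_idx = torch.arange(t).repeat_interleave(h * w)        # tensor([0, 0, 0, 0])
h_idx = torch.arange(h).repeat_interleave(w).repeat(t)  # tensor([0, 0, 1, 1])
w_idx = torch.arange(w).repeat(h * t)                   # tensor([0, 1, 0, 1])
vision_pos_ids = torch.stack([t_idx, h_idx, w_idx])     # shape (3, 4)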
@ -528,6 +534,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
video_pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
@ -538,15 +545,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
):
inputs_embeds = self.embed_tokens(input_ids)
if video_pixel_values is not None and len(video_pixel_values) > 0:
vision_embeds = self.visual(
video_pixel_values, grid_thw=video_grid_thw
).squeeze(0)
vision_token_mask = input_ids == self.video_token_id
inputs_embeds[vision_token_mask] = vision_embeds
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
vision_embeds = self.visual(
pixel_values,
grid_thw=torch.cat([image_grid_thw, video_grid_thw]) if video_grid_thw is not None else image_grid_thw
grid_thw=(
torch.cat([image_grid_thw, video_grid_thw])
if video_grid_thw is not None
else image_grid_thw
),
).squeeze(0)
# Apply embeddings to both image and video tokens
vision_token_mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id)
# Apply embeddings to image tokens
vision_token_mask = input_ids == self.image_token_id
inputs_embeds[vision_token_mask] = vision_embeds
hidden_states = self.text_model(
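A minimal sketch of the embedding merge above: vision embeddings replace the rows of inputs_embeds whose token id is the placeholder. The id 151656 is assumed here for illustration; the forward pass uses self.video_token_id / self.image_token_id from the config.
import torch

VIDEO_TOKEN_ID = 151656                                  # assumption for illustration
input_ids = torch.tensor([1, VIDEO_TOKEN_ID, VIDEO_TOKEN_ID, 2])
inputs_embeds = torch.zeros(4, 8)
vision_embeds = torch.ones(2, 8)                         # one row per placeholder token
mask = input_ids == VIDEO_TOKEN_ID
inputs_embeds[mask] = vision_embeds                      # rows 1 and 2 are replaced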

View File

@ -18,6 +18,8 @@ from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
from torchvision import io
import math
tracer = trace.get_tracer(__name__)
@ -76,6 +78,20 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def video_text_replacement(processor, video_input, config) -> str:
if config.model_type == "qwen2_vl":
# one <|video_pad|> placeholder per merged patch group (spatial_merge_size**2 == 4)
num_pads = video_input.pixel_values.shape[0] // 4
padding = "<|video_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
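A rough worked example of the replacement above (the patch count is hypothetical; it assumes the standard Qwen2-VL flattened-patch layout):
num_patches = 4816                       # hypothetical video_inputs.pixel_values.shape[0]
num_pads = num_patches // 4              # 1204 tokens after 2x2 spatial merging
replacement = "<|vision_start|>" + "<|video_pad|>" * num_pads + "<|vision_end|>"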
def image_text_replacement_fixup(config, text: str) -> str:
if config.model_type == "idefics2":
return text.replace(
@ -138,29 +154,59 @@ def get_number_of_features(height: int, width: int, config) -> int:
return unpadded_features + newline_features + base_features
# copied from: https://github.com/QwenLM/Qwen2-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def smart_nframes(
fps: int,
nframes: Optional[int],
min_frames: int,
max_frames: int,
total_frames: int,
video_fps: int | float,
) -> int:
if nframes:
nframes = round(nframes / 2) * 2
else:
min_frames = math.ceil(min_frames / 2) * 2
max_frames = math.floor(max_frames / 2) * 2
nframes = total_frames / video_fps * fps
nframes = min(max(nframes, min_frames), max_frames)
nframes = round(nframes / 2) * 2
if not (2 <= nframes and nframes <= total_frames):
raise ValueError(
f"nframes should in interval [{2}, {total_frames}], but got {nframes}."
)
return nframes
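Example usage (values hypothetical): a 10-second clip decoded at 24 fps gives 240 frames; the target sampling rate of 30 fps is clamped to the 64-frame maximum and rounded to an even count:
nframes = smart_nframes(
    fps=30, nframes=None, min_frames=16, max_frames=64,
    total_frames=240, video_fps=24,
)
# 240 / 24 * 30 = 300 -> clamped to 64 -> rounded to even -> 64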
class VlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
video_pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor]
video_grid_thw: Optional[torch.Tensor]
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.video_pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.video_grid_thw = None
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.video_pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.video_grid_thw = None
return batch
@classmethod
@ -171,6 +217,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
# can make the image splits the same size. And we need the final
# sizes to insert correct number of image tokens.
images = []
videos = []
for r in requests:
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
@ -192,7 +239,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
images.append([image])
elif chunk_type == "video":
if config.model_type == "qwen2_vl":
pass
videos.append(chunk.video)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
@ -201,6 +248,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
else:
image_inputs = None
video_inputs = None
if videos:
try:
tensor_videos = []
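# NOTE: only the first video in the batch is decoded below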
video = videos[0]
video_buffer = BytesIO(video.data)
video, _audio, info = io.read_video(
video_buffer,
start_pts=0.0,
end_pts=None,
pts_unit="sec",
output_format="TCHW",
)
total_frames, video_fps = video.size(0), info["video_fps"]
nframes = smart_nframes(
fps=30,
nframes=None,
min_frames=16,
max_frames=64,
total_frames=total_frames,
video_fps=video_fps,
)
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
video = video[idx]
tensor_videos.append(video)
video_inputs = processor.image_processor(
tensor_videos, return_tensors="pt"
)
except Exception as e:
print(f"Failed to process video: {e}")
else:
video_inputs = None
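The uniform frame selection above, as a standalone sketch (frame counts hypothetical):
import torch

total_frames, nframes = 240, 64            # decoded length and target from smart_nframes
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
# idx[:4] -> tensor([0, 4, 8, 11]); the decoded video is then indexed as video[idx]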
batch_inputs = []
max_truncation = 0
image_id = 0
@ -215,10 +297,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
processor, image_inputs, config, image_id
)
image_id += 1
elif chunk_type == "video" and config.model_type == "qwen2_vl":
from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
text, _ = process_qwen_video(chunk.video)
full_text += text
elif chunk_type == "video":
full_text += video_text_replacement(processor, video_inputs, config)
full_text = image_text_replacement_fixup(config, full_text)
batch_inputs.append(full_text)
@ -231,7 +311,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
add_special_tokens=not config.model_type == "paligemma",
)["input_ids"]
return batch_tokenized_inputs, image_inputs
return batch_tokenized_inputs, image_inputs, video_inputs
@classmethod
def from_pb_processor(
@ -243,10 +323,23 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
batch_tokenized_inputs, image_inputs, video_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if video_inputs is not None:
if "pixel_values" in video_inputs:
batch.video_pixel_values = video_inputs["pixel_values"].to(
device=device
)
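# the Qwen2-VL image processor reports the video grid under the image_grid_thw key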
if "image_grid_thw" in video_inputs:
batch.video_grid_thw = video_inputs["image_grid_thw"].to(device=device)
else:
batch.video_grid_thw = None
else:
batch.video_pixel_values = None
batch.video_grid_thw = None
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
@ -263,6 +356,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
if "video_grid_thw" in image_inputs:
batch.video_grid_thw = image_inputs["video_grid_thw"].to(device=device)
else:
batch.video_grid_thw = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
@ -270,6 +367,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_grid_thw = None
return batch
class VlmCausalLM(FlashCausalLM):
def __init__(
self,
@ -372,7 +470,7 @@ class VlmCausalLM(FlashCausalLM):
if self.model.config.model_type == "qwen2_vl":
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw
input_ids, batch.image_grid_thw, batch.video_grid_thw
)
batch.position_ids = position_ids
@ -425,20 +523,26 @@ class VlmCausalLM(FlashCausalLM):
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
video_pixel_values=batch.video_pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
video_grid_thw=batch.video_grid_thw,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.video_pixel_values is not None:
batch.video_pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
if batch.video_grid_thw is not None:
batch.video_grid_thw = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph