feat: support video input chunks and enable qwen2 vl to process video
Parent: 6b4697e9d1
Commit: a9c2d28a3a
@@ -9,7 +9,7 @@ use thiserror::Error;
 use tonic::transport;
 use tonic::Status;

-pub use v3::{Chunk, Image, Input, InputChunk};
+pub use v3::{Chunk, Image, Input, InputChunk, Video};

 #[async_trait]
 pub trait Health {
@@ -79,8 +79,9 @@ impl ChunksToString for Vec<InputChunk> {
                     let encoded = STANDARD.encode(data);
                     output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
                 }
-                Some(Chunk::Video(url)) => {
-                    output.push_str(&format!("<video>({})", url))
+                Some(Chunk::Video(Video { data, mimetype })) => {
+                    let encoded = STANDARD.encode(data);
+                    output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
                 }
                 // We don't create empty chunks, so this should be unreachable.
                 None => unreachable!("Chunks should never be empty"),
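A note on the wire format above: a `Video` chunk is flattened into the same kind of inline marker already used for images, only wrapped in `<video>(...)` instead of `![](...)`. A minimal Python sketch of the string being built (the bytes and mimetype are made-up placeholder values):

```python
import base64

# Hypothetical stand-in for raw MP4 bytes.
data = b"\x00\x00\x00\x18ftypmp42"
mimetype = "video/mp4"

# Mirrors the format! call above: <video>(data:<mimetype>;base64,<payload>)
marker = f"<video>(data:{mimetype};base64,{base64.b64encode(data).decode()})"
print(marker)  # <video>(data:video/mp4;base64,AAAAGGZ0eXBtcDQy)
```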
@@ -8,6 +8,6 @@ pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens,
+    StoppingCriteriaParameters, Tokens, Video,
 };
 pub use sharded_client::ShardedClient;
@@ -15,7 +15,7 @@ pub use grpc_client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters,
+    StoppingCriteriaParameters, Video,
 };
 pub use sharded_client::ShardedClient;
@@ -439,7 +439,10 @@ impl State {
                         data: image.data,
                         mimetype: image.mimetype,
                     }),
-                    Chunk::Video(url) => client::Chunk::Video(url),
+                    Chunk::Video(video) => client::Chunk::Video(client::Video {
+                        data: video.data,
+                        mimetype: video.mimetype,
+                    }),
                 }),
             })
             .collect(),
@@ -1926,6 +1926,24 @@
            ]
          }
        }
      },
+      {
+        "type": "object",
+        "required": [
+          "video_url",
+          "type"
+        ],
+        "properties": {
+          "type": {
+            "type": "string",
+            "enum": [
+              "video_url"
+            ]
+          },
+          "video_url": {
+            "$ref": "#/components/schemas/Url"
+          }
+        }
+      }
    ],
    "discriminator": {
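With the `video_url` variant added to the MessageChunk schema, a chat request can mix text and video parts. A hedged sketch of such a request; the endpoint, port, model id, and clip URL below are placeholders, not values taken from this commit:

```python
import requests

payload = {
    "model": "Qwen/Qwen2-VL-7B-Instruct",  # placeholder model id
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe what happens in this clip."},
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
            ],
        }
    ],
    "max_tokens": 128,
}

# Assumes a local text-generation-inference instance listening on port 3000.
resp = requests.post("http://localhost:3000/v1/chat/completions", json=payload, timeout=120)
print(resp.json()["choices"][0]["message"]["content"])
```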
@@ -63,7 +63,6 @@ Options:

 Possible values:
 - awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
 - compressed-tensors: Compressed tensors, which can be a mixture of different quantization methods
 - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
 - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
 - gptq: 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
@@ -1191,7 +1191,9 @@ impl From<Message> for TextMessage {
                 .map(|chunk| match chunk {
                     MessageChunk::Text { text } => text,
                     MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
-                    MessageChunk::VideoUrl { video_url } => format!("![]({})", video_url.url),
+                    MessageChunk::VideoUrl { video_url } => {
+                        format!("<video>({})", video_url.url)
+                    }
                 })
                 .collect::<Vec<_>>()
                 .join(""),
@@ -516,6 +516,70 @@ fn format_to_mimetype(format: ImageFormat) -> String {
     .to_string()
 }

+pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32), ValidationError> {
+    let (data, mimetype) =
+        if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
+            let url = &input["<video>(".len()..input.len() - 1];
+            let data = reqwest::blocking::get(url)?.bytes()?.to_vec();
+            (data, "video/mp4".to_string())
+        } else if input.starts_with("<video>(data:") {
+            let content = &input["<video>(data:".len()..input.len() - 1];
+            let tokens: Vec<_> = content.split(';').collect();
+            if tokens.len() != 2 {
+                return Err(ValidationError::InvalidVideoContent(content.to_string()));
+            }
+            let mimetype = tokens[0];
+            let content = tokens[1];
+            if !content.starts_with("base64,") {
+                return Err(ValidationError::InvalidVideoContent(content.to_string()));
+            }
+            let data = STANDARD.decode(&content["base64,".len()..])?;
+            (data, mimetype.to_string())
+        } else {
+            return Err(ValidationError::InvalidVideoContent(input.to_string()));
+        };
+
+    let mut cursor = Cursor::new(&data);
+    let context = mp4parse::read_mp4(&mut cursor).map_err(|_| ValidationError::MP4Error)?;
+
+    let video_track = context
+        .tracks
+        .iter()
+        .find(|track| track.track_type == mp4parse::TrackType::Video)
+        .ok_or(ValidationError::NoVideoStream)?;
+
+    let video_info = video_track
+        .tkhd
+        .as_ref()
+        .ok_or(ValidationError::NoVideoStream)?;
+    let width = (video_info.width >> 16) as usize;
+    let height = (video_info.height >> 16) as usize;
+
+    // timescale units per second
+    let timescale = video_track.timescale.map(|t| t.0 as f32).unwrap_or(600.0);
+
+    // TODO: revisit if we need duration in seconds
+    let _duration = video_track
+        .duration
+        .map(|d| d.0 as f32 / timescale)
+        .unwrap_or(0.0);
+
+    let time_to_sample = video_track
+        .stts
+        .as_ref()
+        .ok_or(ValidationError::NoVideoStream)?;
+
+    let num_samples = time_to_sample
+        .samples
+        .iter()
+        .map(|entry| entry.sample_count)
+        .sum::<u32>();
+
+    let total_frames = num_samples as f32;
+
+    Ok((data, mimetype, height, width, total_frames))
+}
+
 fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
     if input.starts_with("![](http://") || input.starts_with("![](https://") {
         let url = &input["![](".len()..input.len() - 1];
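Two details in `fetch_video` are easy to miss: the MP4 `tkhd` box stores width and height as 16.16 fixed-point numbers (hence the `>> 16`), and the frame count is recovered by summing the `sample_count` entries of the `stts` (time-to-sample) box. A small Python sketch of the same arithmetic on made-up box values:

```python
# tkhd stores 16.16 fixed-point dimensions; the integer part lives in the upper 16 bits.
tkhd_width = 0x0280_0000   # 640.0 in 16.16 fixed point
tkhd_height = 0x0168_0000  # 360.0 in 16.16 fixed point
width, height = tkhd_width >> 16, tkhd_height >> 16
assert (width, height) == (640, 360)

# stts holds (sample_count, sample_delta) entries; total frames is the sum of the counts.
stts_entries = [(240, 512), (60, 512)]  # made-up entries
total_frames = float(sum(count for count, _delta in stts_entries))
print(width, height, total_frames)  # 640 360 300.0
```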
@@ -604,6 +668,31 @@ fn image_tokens(
     }
 }

+fn video_tokens(config: &Config, height: usize, width: usize, total_frames: f32) -> String {
+    use Config::*;
+
+    match config {
+        // TODO: improve to use the config to better estimate the number of tokens
+        Qwen2Vl(_config) => {
+            let video_fps = 30_f32;
+            let fps = 30_f32;
+            let min_frames = 16_f32;
+            let max_frames = 64_f32;
+            // make sure the frames are within the range and are even
+            let nframes = (total_frames / video_fps * fps)
+                .max(min_frames)
+                .min(max_frames);
+            let nframes = (nframes / 2.0).round() as usize * 2;
+            let num_tokens = nframes * height * width / 1541;
+            format!(
+                "<|vision_start|>{:?}<|vision_end|>",
+                "<|video_pad|>".repeat(num_tokens)
+            )
+        }
+        _ => unimplemented!("Video tokens are not supported for this model configuration"),
+    }
+}
+
 fn image_tokens_fixup(config: &Config, text: String) -> String {
     match config {
         Config::Idefics2(_) => {
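To make the Qwen2-VL estimate concrete: with `video_fps` and `fps` both hard-coded to 30, the sampled frame count is simply `total_frames` clamped to [16, 64] and rounded to an even number, and the pad-token estimate divides `nframes * height * width` by 1541. A worked example in Python, assuming a 640x360 clip with 300 frames (illustrative numbers only):

```python
total_frames, height, width = 300.0, 360, 640
video_fps, fps, min_frames, max_frames = 30.0, 30.0, 16.0, 64.0

nframes = min(max(total_frames / video_fps * fps, min_frames), max_frames)  # 64.0
nframes = round(nframes / 2) * 2                                            # 64
num_tokens = nframes * height * width // 1541                               # 9568
print(nframes, num_tokens)
```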
@@ -626,7 +715,8 @@ fn prepare_input<T: TokenizerTrait>(
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     // Add video regex
-    static VIDEO_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
+    static VIDEO_RE: Lazy<Regex> =
+        Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());

     let (tokenizer_query, input_chunks) = match config {
         Some(
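For reference, `VIDEO_RE` only matches http(s) video markers (the base64 `data:` form is understood by `fetch_video` but is not captured by this pattern). The same pattern checked quickly with Python's `re` module; the source uses the Rust `regex` crate, but the pattern string is identical:

```python
import re

VIDEO_RE = re.compile(r"<video>\((https?://[^\)]+)\)")

text = "What happens here? <video>(https://example.com/clip.mp4) Answer briefly."
m = VIDEO_RE.search(text)
print(m.group(0))  # <video>(https://example.com/clip.mp4)
print(m.group(1))  # https://example.com/clip.mp4
```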
@@ -636,7 +726,7 @@ fn prepare_input<T: TokenizerTrait>(
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;

-            // Process videos first
+            // handle video content first
             for chunk in VIDEO_RE.find_iter(&inputs) {
                 let chunk_start = chunk.start();
                 let chunk_end = chunk.end();
@@ -644,14 +734,15 @@ fn prepare_input<T: TokenizerTrait>(
                     input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
-                let video_url = &inputs[chunk_start + 8..chunk_end - 1]; // Remove <video>( and )
-                input_chunks.push(Chunk::Video(video_url.to_string()));
-                // For videos, we use the default size as height/width don't matter for the initial processing
-                tokenizer_query.push_str(&image_tokens(config, preprocessor_config, 1, 1));
+                let (data, mimetype, height, width, total_frames) =
+                    fetch_video(&inputs[chunk_start..chunk_end])?;
+                input_chunks.push(Chunk::Video(Video { data, mimetype }));
+                let video_tokens = video_tokens(config, height, width, total_frames);
+                tokenizer_query.push_str(&video_tokens);
                 start = chunk_end;
             }

-            // Process remaining content for images
+            // clip remaining inputs and process images
            let remaining_input = &inputs[start..];
            for chunk in RE.find_iter(remaining_input) {
                let chunk_start = chunk.start() + start;
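The loop above splits the raw prompt into interleaved text and video chunks while the tokenizer-side query gets the `<|video_pad|>` estimate in place of each marker. A rough Python sketch of just the splitting step, under the simplifying assumption that only the marker text matters (the real code fetches the bytes and computes the token estimate per match):

```python
import re

VIDEO_RE = re.compile(r"<video>\((https?://[^\)]+)\)")

def split_video_chunks(inputs: str):
    chunks, start = [], 0
    for m in VIDEO_RE.finditer(inputs):
        if m.start() != start:
            chunks.append(("text", inputs[start:m.start()]))
        chunks.append(("video", m.group(1)))  # the Rust code stores fetched bytes + mimetype here
        start = m.end()
    if start != len(inputs):
        chunks.append(("text", inputs[start:]))
    return chunks

print(split_video_chunks("Describe <video>(https://example.com/a.mp4) in one sentence."))
# [('text', 'Describe '), ('video', 'https://example.com/a.mp4'), ('text', ' in one sentence.')]
```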
@@ -699,11 +790,17 @@ pub struct Image {
     pub mimetype: String,
 }

+#[derive(Debug, Clone, Eq, PartialEq)]
+pub struct Video {
+    pub data: Vec<u8>,
+    pub mimetype: String,
+}
+
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub enum Chunk {
     Text(String),
     Image(Image),
-    Video(String),
+    Video(Video),
 }

 /// Convert input chunks to a stringly-typed input for backwards
@@ -722,8 +819,9 @@ impl ChunksToString for Vec<Chunk> {
                 let encoded = STANDARD.encode(data);
                 output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
             }
-            Chunk::Video(url) => {
-                output.push_str(&format!("<video>({})", url))
+            Chunk::Video(Video { data, mimetype }) => {
+                let encoded = STANDARD.encode(data);
+                output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
             }
         });
         output
@@ -851,6 +949,10 @@ pub enum ValidationError {
     UnsupportedModality(&'static str),
     #[error("invalid video content: {0}")]
     InvalidVideoContent(String),
+    #[error("could not parse MP4 file")]
+    MP4Error,
+    #[error("no video stream found")]
+    NoVideoStream,
 }

 #[cfg(test)]
@@ -14,20 +14,12 @@
 # limitations under the License.
 """PyTorch Qwen2 VL model."""

-__all__ = ['Qwen2VLForConditionalGeneration', 'process_qwen_video']
-
 from typing import Dict, Optional, Tuple, List

-import os
-import tempfile
-import requests
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from text_generation_server.utils.import_utils import SYSTEM
-from contextlib import contextmanager
-from qwen_vl_utils import process_vision_info
-

 if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
@@ -445,38 +437,48 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             ).squeeze(1)
             vision_tokens = input_ids[vision_start_indices + 1]

-            # Count both image and video tokens
+            # only copy the sum of the image and video tokens GPU<->CPU
             image_count = (vision_tokens == self.image_token_id).sum().item()
             video_count = (vision_tokens == self.video_token_id).sum().item()

             current_pos = 0
             for _ in range(image_count + video_count):
-                # Find next vision token position (either image or video)
+                # copy the value position of the next image or video token from GPU<->CPU
                 next_vision_pos = (
-                    ((input_ids[current_pos:] == self.image_token_id) |
-                    (input_ids[current_pos:] == self.video_token_id))
+                    (
+                        (input_ids[current_pos:] == self.image_token_id)
+                        | (input_ids[current_pos:] == self.video_token_id)
+                    )
                     .nonzero()[0]
                     .item()
                 )

-                # Determine if current token is video or image
-                is_video = input_ids[current_pos + next_vision_pos] == self.video_token_id
-                grid_thw = video_grid_thw[vision_index] if is_video else image_grid_thw[vision_index]
+                # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
+                is_video = (
+                    input_ids[current_pos + next_vision_pos] == self.video_token_id
+                )
+                grid_thw = (
+                    video_grid_thw[vision_index]
+                    if is_video
+                    else image_grid_thw[vision_index]
+                )

                 time_steps, height, width = grid_thw.clone()
                 height //= self.spatial_merge_size
                 width //= self.spatial_merge_size

-                # Calculate lengths and indices same as before
+                # calculate the length of the text and image tokens
                 text_length = next_vision_pos - current_pos
-                start_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                start_idx = (
+                    llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                )

-                # Text position ids
+                # text position ids
                 text_pos_ids = torch.arange(text_length, device=d)
                 text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
                 llm_pos_ids_list.append(text_pos_ids)

-                # Vision position ids
+                # vision position ids
                 t_indices = torch.arange(time_steps, device=d).repeat_interleave(
                     height * width
                 )
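For each image or video placeholder, the block of vision position ids built above has `t * (h / spatial_merge_size) * (w / spatial_merge_size)` entries, one per merged patch. A small numeric sketch with an assumed grid and the Qwen2-VL merge size of 2:

```python
import torch

spatial_merge_size = 2
grid_thw = torch.tensor([16, 36, 52])  # assumed (t, h, w) for one video

t, h, w = grid_thw.tolist()
h //= spatial_merge_size
w //= spatial_merge_size

t_idx = torch.arange(t).repeat_interleave(h * w)
h_idx = torch.arange(h).repeat_interleave(w).repeat(t)
w_idx = torch.arange(w).repeat(h * t)
vision_pos_ids = torch.stack([t_idx, h_idx, w_idx])
print(vision_pos_ids.shape)  # torch.Size([3, 7488]); 16 * 18 * 26 positions
```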
@@ -485,7 +487,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     .repeat_interleave(width)
                     .repeat(time_steps)
                 )
-                w_indices = torch.arange(width, device=d).repeat(height * time_steps)
+                w_indices = torch.arange(width, device=d).repeat(
+                    height * time_steps
+                )

                 vision_pos_ids = (
                     torch.stack([t_indices, h_indices, w_indices])
@@ -499,7 +503,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):

             # Handle remaining text if any
             if current_pos < batch_input_ids.size(1):
-                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+                st_idx = (
+                    llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+                )
                 text_len = batch_input_ids.size(1) - current_pos
                 llm_pos_ids_list.append(
                     torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
@@ -528,6 +534,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         pixel_values: torch.FloatTensor = None,
+        video_pixel_values: torch.FloatTensor = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
         pixel_attention_mask=None,
@@ -538,15 +545,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     ):
         inputs_embeds = self.embed_tokens(input_ids)

+        if video_pixel_values is not None and len(video_pixel_values) > 0:
+            vision_embeds = self.visual(
+                video_pixel_values, grid_thw=video_grid_thw
+            ).squeeze(0)
+            vision_token_mask = input_ids == self.video_token_id
+            inputs_embeds[vision_token_mask] = vision_embeds
+
         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
             vision_embeds = self.visual(
                 pixel_values,
-                grid_thw=torch.cat([image_grid_thw, video_grid_thw]) if video_grid_thw is not None else image_grid_thw
+                grid_thw=(
+                    torch.cat([image_grid_thw, video_grid_thw])
+                    if video_grid_thw is not None
+                    else image_grid_thw
+                ),
             ).squeeze(0)

-            # Apply embeddings to both image and video tokens
-            vision_token_mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id)
+            # Apply embeddings to image tokens
+            vision_token_mask = input_ids == self.image_token_id
             inputs_embeds[vision_token_mask] = vision_embeds

         hidden_states = self.text_model(
@@ -18,6 +18,8 @@ from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.metadata_kernels import block_tables_to_ragged
+from torchvision import io
+import math

 tracer = trace.get_tracer(__name__)
@@ -76,6 +78,20 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")


+def video_text_replacement(processor, video_input, config) -> str:
+    if config.model_type == "qwen2_vl":
+        # num_pads = video_input['pixel_values'].size(0)
+        # num_pads = 1206
+
+        # import ipdb; ipdb.set_trace()
+        # num_pads = 9556 + 10
+        num_pads = video_input.pixel_values.shape[0] // 4
+        padding = "<|video_pad|>" * num_pads
+        return f"<|vision_start|>{padding}<|vision_end|>"
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
 def image_text_replacement_fixup(config, text: str) -> str:
     if config.model_type == "idefics2":
         return text.replace(
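A rough way to read the `// 4`: the processor emits one row of `pixel_values` per vision patch, and Qwen2-VL's 2x2 spatial merge collapses every four patches into a single `<|video_pad|>` position. A quick sketch with an assumed patch count:

```python
# Assumed processor output: one row per patch before spatial merging.
num_patches = 14976  # e.g. video_input.pixel_values.shape[0] for a short clip
spatial_merge_size = 2

num_pads = num_patches // spatial_merge_size**2  # == num_patches // 4 == 3744
text = "<|vision_start|>" + "<|video_pad|>" * num_pads + "<|vision_end|>"
print(num_pads)
```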
@@ -138,29 +154,59 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features


+# copied from: https://github.com/QwenLM/Qwen2-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
+def smart_nframes(
+    fps: int,
+    nframes: int,
+    min_frames: int,
+    max_frames: int,
+    total_frames: int,
+    video_fps: int | float,
+) -> int:
+    if nframes:
+        nframes = round(nframes / 2) * 2
+    else:
+        min_frames = math.ceil(min_frames / 2) * 2
+        max_frames = math.floor(max_frames / 2) * 2
+        nframes = total_frames / video_fps * fps
+        nframes = min(max(nframes, min_frames), max_frames)
+        nframes = round(nframes / 2) * 2
+    if not (2 <= nframes and nframes <= total_frames):
+        raise ValueError(
+            f"nframes should in interval [{2}, {total_frames}], but got {nframes}."
+        )
+    return nframes
+
+
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
+    video_pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
     image_grid_thw: Optional[torch.Tensor]
+    video_grid_thw: Optional[torch.Tensor]

     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
+        batch.video_pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.video_grid_thw = None
         return batch

     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         batch = super().filter(request_ids)
         batch.pixel_values = None
+        batch.video_pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.video_grid_thw = None
         return batch

     @classmethod
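A couple of worked calls against the `smart_nframes` helper added above, using the defaults passed later in this file (fps=30, 16 to 64 frames); the clip lengths are illustrative, not values from the commit:

```python
# 300 frames at 30 fps: 300 * (30 / 30) = 300, clamped to max_frames -> 64
assert smart_nframes(fps=30, nframes=None, min_frames=16, max_frames=64,
                     total_frames=300, video_fps=30) == 64

# 100 frames at 50 fps: 100 * (30 / 50) = 60, already even and within [16, 64] -> 60
assert smart_nframes(fps=30, nframes=None, min_frames=16, max_frames=64,
                     total_frames=100, video_fps=50) == 60
```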
@@ -171,6 +217,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         # can make the image splits the same size. And we need the final
         # sizes to insert correct number of image tokens.
         images = []
+        videos = []
         for r in requests:
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
@@ -192,7 +239,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         images.append([image])
                 elif chunk_type == "video":
                     if config.model_type == "qwen2_vl":
-                        pass
+                        videos.append(chunk.video)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
@@ -201,6 +248,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         else:
             image_inputs = None

+        video_inputs = None
+        if videos:
+            try:
+                tensor_videos = []
+                video = videos[0]
+                video_buffer = BytesIO(video.data)
+                video, _audio, info = io.read_video(
+                    video_buffer,
+                    start_pts=0.0,
+                    end_pts=None,
+                    pts_unit="sec",
+                    output_format="TCHW",
+                )
+                total_frames, video_fps = video.size(0), info["video_fps"]
+                nframes = smart_nframes(
+                    fps=30,
+                    nframes=None,
+                    min_frames=16,
+                    max_frames=64,
+                    total_frames=total_frames,
+                    video_fps=video_fps,
+                )
+                idx = torch.linspace(0, total_frames - 1, nframes).round().long()
+                video = video[idx]
+                tensor_videos.append(video)
+                video_inputs = processor.image_processor(
+                    tensor_videos, return_tensors="pt"
+                )
+
+            except Exception as e:
+                print(f"Failed to process video: {e}")
+                pass
+        else:
+            video_inputs = None
+
         batch_inputs = []
         max_truncation = 0
         image_id = 0
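The `torch.linspace(...).round().long()` line selects `nframes` indices spread evenly over the decoded clip, so longer videos are uniformly subsampled rather than truncated. The index selection in isolation:

```python
import torch

total_frames, nframes = 300, 64
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
print(len(idx), idx[:5].tolist(), idx[-3:].tolist())
# 64 [0, 5, 9, 14, 19] [290, 294, 299]
```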
@@ -215,10 +297,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         processor, image_inputs, config, image_id
                     )
                     image_id += 1
-                elif chunk_type == "video" and config.model_type == "qwen2_vl":
-                    from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
-                    text, _ = process_qwen_video(chunk.video)
-                    full_text += text
+                elif chunk_type == "video":
+                    full_text += video_text_replacement(processor, video_inputs, config)

             full_text = image_text_replacement_fixup(config, full_text)
             batch_inputs.append(full_text)
@@ -231,7 +311,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             add_special_tokens=not config.model_type == "paligemma",
         )["input_ids"]

-        return batch_tokenized_inputs, image_inputs
+        return batch_tokenized_inputs, image_inputs, video_inputs

     @classmethod
     def from_pb_processor(
@@ -243,10 +323,23 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "VlmCausalLMBatch":
-        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+        batch_tokenized_inputs, image_inputs, video_inputs = cls.batch_tokenized_inputs(
             pb.requests, tokenizer, processor, config
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+        if video_inputs is not None:
+            if "pixel_values" in video_inputs:
+                batch.video_pixel_values = video_inputs["pixel_values"].to(
+                    device=device
+                )
+            if "image_grid_thw" in video_inputs:
+                batch.video_grid_thw = video_inputs["image_grid_thw"].to(device=device)
+            else:
+                batch.video_grid_thw = None
+        else:
+            batch.video_pixel_values = None
+            batch.video_grid_thw = None
+
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
             if "pixel_attention_mask" in image_inputs:
@@ -263,6 +356,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
             else:
                 batch.image_grid_thw = None
+            if "video_grid_thw" in image_inputs:
+                batch.video_grid_thw = image_inputs["video_grid_thw"].to(device=device)
+            else:
+                batch.video_grid_thw = None
         else:
             batch.pixel_values = None
             batch.pixel_attention_mask = None
@@ -270,6 +367,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             batch.image_grid_thw = None
         return batch

+
 class VlmCausalLM(FlashCausalLM):
     def __init__(
         self,
@@ -372,7 +470,7 @@ class VlmCausalLM(FlashCausalLM):
         if self.model.config.model_type == "qwen2_vl":
             if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
-                    input_ids, batch.image_grid_thw
+                    input_ids, batch.image_grid_thw, batch.video_grid_thw
                 )
                 batch.position_ids = position_ids
@@ -425,20 +523,26 @@ class VlmCausalLM(FlashCausalLM):
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
                 pixel_values=batch.pixel_values,
+                video_pixel_values=batch.video_pixel_values,
                 pixel_attention_mask=batch.pixel_attention_mask,
                 image_sizes=batch.image_sizes,
                 image_grid_thw=batch.image_grid_thw,
+                video_grid_thw=batch.video_grid_thw,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             if batch.pixel_values is not None:
                 batch.pixel_values = None
+            if batch.video_pixel_values is not None:
+                batch.video_pixel_values = None
             if batch.pixel_attention_mask is not None:
                 batch.pixel_attention_mask = None
             if batch.image_sizes is not None:
                 batch.image_sizes = None
             if batch.image_grid_thw is not None:
                 batch.image_grid_thw = None
+            if batch.video_grid_thw is not None:
+                batch.video_grid_thw = None
             return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph