feat: support video input chunks and enable qwen2 vl to process video
This commit is contained in:
parent
6b4697e9d1
commit
a9c2d28a3a
|
@ -9,7 +9,7 @@ use thiserror::Error;
|
||||||
use tonic::transport;
|
use tonic::transport;
|
||||||
use tonic::Status;
|
use tonic::Status;
|
||||||
|
|
||||||
pub use v3::{Chunk, Image, Input, InputChunk};
|
pub use v3::{Chunk, Image, Input, InputChunk, Video};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Health {
|
pub trait Health {
|
||||||
|
@ -79,8 +79,9 @@ impl ChunksToString for Vec<InputChunk> {
|
||||||
let encoded = STANDARD.encode(data);
|
let encoded = STANDARD.encode(data);
|
||||||
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
||||||
}
|
}
|
||||||
Some(Chunk::Video(url)) => {
|
Some(Chunk::Video(Video { data, mimetype })) => {
|
||||||
output.push_str(&format!("<video>({})", url))
|
let encoded = STANDARD.encode(data);
|
||||||
|
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
|
||||||
}
|
}
|
||||||
// We don't create empty chunks, so this should be unreachable.
|
// We don't create empty chunks, so this should be unreachable.
|
||||||
None => unreachable!("Chunks should never be empty"),
|
None => unreachable!("Chunks should never be empty"),
|
||||||
|
|
|
@ -8,6 +8,6 @@ pub use client::Client;
|
||||||
pub use pb::generate::v3::{
|
pub use pb::generate::v3::{
|
||||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||||
StoppingCriteriaParameters, Tokens,
|
StoppingCriteriaParameters, Tokens, Video,
|
||||||
};
|
};
|
||||||
pub use sharded_client::ShardedClient;
|
pub use sharded_client::ShardedClient;
|
||||||
|
|
|
@ -15,7 +15,7 @@ pub use grpc_client::Client;
|
||||||
pub use pb::generate::v3::{
|
pub use pb::generate::v3::{
|
||||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||||
StoppingCriteriaParameters,
|
StoppingCriteriaParameters, Video,
|
||||||
};
|
};
|
||||||
pub use sharded_client::ShardedClient;
|
pub use sharded_client::ShardedClient;
|
||||||
|
|
||||||
|
|
|
@ -439,7 +439,10 @@ impl State {
|
||||||
data: image.data,
|
data: image.data,
|
||||||
mimetype: image.mimetype,
|
mimetype: image.mimetype,
|
||||||
}),
|
}),
|
||||||
Chunk::Video(url) => client::Chunk::Video(url),
|
Chunk::Video(video) => client::Chunk::Video(client::Video {
|
||||||
|
data: video.data,
|
||||||
|
mimetype: video.mimetype,
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
|
|
|
@ -1926,6 +1926,24 @@
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"video_url",
|
||||||
|
"type"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"video_url"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"video_url": {
|
||||||
|
"$ref": "#/components/schemas/Url"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"discriminator": {
|
"discriminator": {
|
||||||
|
|
|
@ -62,16 +62,15 @@ Options:
|
||||||
[env: QUANTIZE=]
|
[env: QUANTIZE=]
|
||||||
|
|
||||||
Possible values:
|
Possible values:
|
||||||
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
|
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
|
||||||
- compressed-tensors: Compressed tensors, which can be a mixture of different quantization methods
|
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
|
||||||
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
|
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
|
||||||
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
|
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
|
||||||
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
|
- marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>
|
||||||
- marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>
|
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
|
||||||
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
|
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
|
||||||
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
|
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
|
||||||
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
|
- fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
|
||||||
- fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
|
|
||||||
|
|
||||||
```
|
```
|
||||||
## SPECULATE
|
## SPECULATE
|
||||||
|
|
|
@ -1191,7 +1191,9 @@ impl From<Message> for TextMessage {
|
||||||
.map(|chunk| match chunk {
|
.map(|chunk| match chunk {
|
||||||
MessageChunk::Text { text } => text,
|
MessageChunk::Text { text } => text,
|
||||||
MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
|
MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
|
||||||
MessageChunk::VideoUrl { video_url } => format!("![]({})", video_url.url),
|
MessageChunk::VideoUrl { video_url } => {
|
||||||
|
format!("<video>({})", video_url.url)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(""),
|
.join(""),
|
||||||
|
|
|
@ -516,6 +516,70 @@ fn format_to_mimetype(format: ImageFormat) -> String {
|
||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32), ValidationError> {
|
||||||
|
let (data, mimetype) =
|
||||||
|
if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
|
||||||
|
let url = &input["<video>(".len()..input.len() - 1];
|
||||||
|
let data = reqwest::blocking::get(url)?.bytes()?.to_vec();
|
||||||
|
(data, "video/mp4".to_string())
|
||||||
|
} else if input.starts_with("<video>(data:") {
|
||||||
|
let content = &input["<video>(data:".len()..input.len() - 1];
|
||||||
|
let tokens: Vec<_> = content.split(';').collect();
|
||||||
|
if tokens.len() != 2 {
|
||||||
|
return Err(ValidationError::InvalidVideoContent(content.to_string()));
|
||||||
|
}
|
||||||
|
let mimetype = tokens[0];
|
||||||
|
let content = tokens[1];
|
||||||
|
if !content.starts_with("base64,") {
|
||||||
|
return Err(ValidationError::InvalidVideoContent(content.to_string()));
|
||||||
|
}
|
||||||
|
let data = STANDARD.decode(&content["base64,".len()..])?;
|
||||||
|
(data, mimetype.to_string())
|
||||||
|
} else {
|
||||||
|
return Err(ValidationError::InvalidVideoContent(input.to_string()));
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut cursor = Cursor::new(&data);
|
||||||
|
let context = mp4parse::read_mp4(&mut cursor).map_err(|_| ValidationError::MP4Error)?;
|
||||||
|
|
||||||
|
let video_track = context
|
||||||
|
.tracks
|
||||||
|
.iter()
|
||||||
|
.find(|track| track.track_type == mp4parse::TrackType::Video)
|
||||||
|
.ok_or(ValidationError::NoVideoStream)?;
|
||||||
|
|
||||||
|
let video_info = video_track
|
||||||
|
.tkhd
|
||||||
|
.as_ref()
|
||||||
|
.ok_or(ValidationError::NoVideoStream)?;
|
||||||
|
let width = (video_info.width >> 16) as usize;
|
||||||
|
let height = (video_info.height >> 16) as usize;
|
||||||
|
|
||||||
|
// timescale units per second
|
||||||
|
let timescale = video_track.timescale.map(|t| t.0 as f32).unwrap_or(600.0);
|
||||||
|
|
||||||
|
// TODO: revisit if we need duration in seconds
|
||||||
|
let _duration = video_track
|
||||||
|
.duration
|
||||||
|
.map(|d| d.0 as f32 / timescale)
|
||||||
|
.unwrap_or(0.0);
|
||||||
|
|
||||||
|
let time_to_sample = video_track
|
||||||
|
.stts
|
||||||
|
.as_ref()
|
||||||
|
.ok_or(ValidationError::NoVideoStream)?;
|
||||||
|
|
||||||
|
let num_samples = time_to_sample
|
||||||
|
.samples
|
||||||
|
.iter()
|
||||||
|
.map(|entry| entry.sample_count)
|
||||||
|
.sum::<u32>();
|
||||||
|
|
||||||
|
let total_frames = num_samples as f32;
|
||||||
|
|
||||||
|
Ok((data, mimetype, height, width, total_frames))
|
||||||
|
}
|
||||||
|
|
||||||
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
|
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
|
||||||
if input.starts_with("![](http://") || input.starts_with("![](https://") {
|
if input.starts_with("![](http://") || input.starts_with("![](https://") {
|
||||||
let url = &input["![](".len()..input.len() - 1];
|
let url = &input["![](".len()..input.len() - 1];
|
||||||
|
@ -604,6 +668,31 @@ fn image_tokens(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn video_tokens(config: &Config, height: usize, width: usize, total_frames: f32) -> String {
|
||||||
|
use Config::*;
|
||||||
|
|
||||||
|
match config {
|
||||||
|
// TOOD: improve to use the config to better estimate the number of tokens
|
||||||
|
Qwen2Vl(_config) => {
|
||||||
|
let video_fps = 30_f32;
|
||||||
|
let fps = 30_f32;
|
||||||
|
let min_frames = 16_f32;
|
||||||
|
let max_frames = 64_f32;
|
||||||
|
// make sure the frames are within the range and are even
|
||||||
|
let nframes = (total_frames / video_fps * fps)
|
||||||
|
.max(min_frames)
|
||||||
|
.min(max_frames);
|
||||||
|
let nframes = (nframes / 2.0).round() as usize * 2;
|
||||||
|
let num_tokens = nframes * height * width / 1541;
|
||||||
|
format!(
|
||||||
|
"<|vision_start|>{:?}<|vision_end|>",
|
||||||
|
"<|video_pad|>".repeat(num_tokens)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => unimplemented!("Video tokens are not supported for this model configuration"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn image_tokens_fixup(config: &Config, text: String) -> String {
|
fn image_tokens_fixup(config: &Config, text: String) -> String {
|
||||||
match config {
|
match config {
|
||||||
Config::Idefics2(_) => {
|
Config::Idefics2(_) => {
|
||||||
|
@ -626,7 +715,8 @@ fn prepare_input<T: TokenizerTrait>(
|
||||||
use Config::*;
|
use Config::*;
|
||||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||||
// Add video regex
|
// Add video regex
|
||||||
static VIDEO_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
|
static VIDEO_RE: Lazy<Regex> =
|
||||||
|
Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
|
||||||
|
|
||||||
let (tokenizer_query, input_chunks) = match config {
|
let (tokenizer_query, input_chunks) = match config {
|
||||||
Some(
|
Some(
|
||||||
|
@ -636,7 +726,7 @@ fn prepare_input<T: TokenizerTrait>(
|
||||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||||
let mut start = 0;
|
let mut start = 0;
|
||||||
|
|
||||||
// Process videos first
|
// handle video content first
|
||||||
for chunk in VIDEO_RE.find_iter(&inputs) {
|
for chunk in VIDEO_RE.find_iter(&inputs) {
|
||||||
let chunk_start = chunk.start();
|
let chunk_start = chunk.start();
|
||||||
let chunk_end = chunk.end();
|
let chunk_end = chunk.end();
|
||||||
|
@ -644,14 +734,15 @@ fn prepare_input<T: TokenizerTrait>(
|
||||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
|
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
|
||||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||||
}
|
}
|
||||||
let video_url = &inputs[chunk_start + 8..chunk_end - 1]; // Remove <video>( and )
|
let (data, mimetype, height, width, total_frames) =
|
||||||
input_chunks.push(Chunk::Video(video_url.to_string()));
|
fetch_video(&inputs[chunk_start..chunk_end])?;
|
||||||
// For videos, we use the default size as height/width don't matter for the initial processing
|
input_chunks.push(Chunk::Video(Video { data, mimetype }));
|
||||||
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, 1, 1));
|
let video_tokens = video_tokens(config, height, width, total_frames);
|
||||||
|
tokenizer_query.push_str(&video_tokens);
|
||||||
start = chunk_end;
|
start = chunk_end;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process remaining content for images
|
// clip remaining inputs and process images
|
||||||
let remaining_input = &inputs[start..];
|
let remaining_input = &inputs[start..];
|
||||||
for chunk in RE.find_iter(remaining_input) {
|
for chunk in RE.find_iter(remaining_input) {
|
||||||
let chunk_start = chunk.start() + start;
|
let chunk_start = chunk.start() + start;
|
||||||
|
@ -699,11 +790,17 @@ pub struct Image {
|
||||||
pub mimetype: String,
|
pub mimetype: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
|
pub struct Video {
|
||||||
|
pub data: Vec<u8>,
|
||||||
|
pub mimetype: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
pub enum Chunk {
|
pub enum Chunk {
|
||||||
Text(String),
|
Text(String),
|
||||||
Image(Image),
|
Image(Image),
|
||||||
Video(String),
|
Video(Video),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert input chunks to a stringly-typed input for backwards
|
/// Convert input chunks to a stringly-typed input for backwards
|
||||||
|
@ -722,8 +819,9 @@ impl ChunksToString for Vec<Chunk> {
|
||||||
let encoded = STANDARD.encode(data);
|
let encoded = STANDARD.encode(data);
|
||||||
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
||||||
}
|
}
|
||||||
Chunk::Video(url) => {
|
Chunk::Video(Video { data, mimetype }) => {
|
||||||
output.push_str(&format!("<video>({})", url))
|
let encoded = STANDARD.encode(data);
|
||||||
|
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
output
|
output
|
||||||
|
@ -851,6 +949,10 @@ pub enum ValidationError {
|
||||||
UnsupportedModality(&'static str),
|
UnsupportedModality(&'static str),
|
||||||
#[error("invalid video content: {0}")]
|
#[error("invalid video content: {0}")]
|
||||||
InvalidVideoContent(String),
|
InvalidVideoContent(String),
|
||||||
|
#[error("could not parse MP4 file")]
|
||||||
|
MP4Error,
|
||||||
|
#[error("no video stream found")]
|
||||||
|
NoVideoStream,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
@ -14,20 +14,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch Qwen2 VL model."""
|
"""PyTorch Qwen2 VL model."""
|
||||||
|
|
||||||
__all__ = ['Qwen2VLForConditionalGeneration', 'process_qwen_video']
|
|
||||||
|
|
||||||
from typing import Dict, Optional, Tuple, List
|
from typing import Dict, Optional, Tuple, List
|
||||||
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
import requests
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from contextlib import contextmanager
|
|
||||||
from qwen_vl_utils import process_vision_info
|
|
||||||
|
|
||||||
|
|
||||||
if SYSTEM == "ipex":
|
if SYSTEM == "ipex":
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
@ -444,39 +436,49 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||||
input_ids == self.vision_start_token_id
|
input_ids == self.vision_start_token_id
|
||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
vision_tokens = input_ids[vision_start_indices + 1]
|
vision_tokens = input_ids[vision_start_indices + 1]
|
||||||
|
|
||||||
# Count both image and video tokens
|
# only copy the sum of the image and video tokens GPU<->CPU
|
||||||
image_count = (vision_tokens == self.image_token_id).sum().item()
|
image_count = (vision_tokens == self.image_token_id).sum().item()
|
||||||
video_count = (vision_tokens == self.video_token_id).sum().item()
|
video_count = (vision_tokens == self.video_token_id).sum().item()
|
||||||
|
|
||||||
current_pos = 0
|
current_pos = 0
|
||||||
for _ in range(image_count + video_count):
|
for _ in range(image_count + video_count):
|
||||||
# Find next vision token position (either image or video)
|
# copy the value position of the next image or video token from GPU<->CPU
|
||||||
next_vision_pos = (
|
next_vision_pos = (
|
||||||
((input_ids[current_pos:] == self.image_token_id) |
|
(
|
||||||
(input_ids[current_pos:] == self.video_token_id))
|
(input_ids[current_pos:] == self.image_token_id)
|
||||||
|
| (input_ids[current_pos:] == self.video_token_id)
|
||||||
|
)
|
||||||
.nonzero()[0]
|
.nonzero()[0]
|
||||||
.item()
|
.item()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Determine if current token is video or image
|
# TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
|
||||||
is_video = input_ids[current_pos + next_vision_pos] == self.video_token_id
|
is_video = (
|
||||||
grid_thw = video_grid_thw[vision_index] if is_video else image_grid_thw[vision_index]
|
input_ids[current_pos + next_vision_pos] == self.video_token_id
|
||||||
|
)
|
||||||
|
grid_thw = (
|
||||||
|
video_grid_thw[vision_index]
|
||||||
|
if is_video
|
||||||
|
else image_grid_thw[vision_index]
|
||||||
|
)
|
||||||
|
|
||||||
time_steps, height, width = grid_thw.clone()
|
time_steps, height, width = grid_thw.clone()
|
||||||
height //= self.spatial_merge_size
|
height //= self.spatial_merge_size
|
||||||
width //= self.spatial_merge_size
|
width //= self.spatial_merge_size
|
||||||
|
|
||||||
# Calculate lengths and indices same as before
|
# calculate the length of the text and image tokens
|
||||||
text_length = next_vision_pos - current_pos
|
text_length = next_vision_pos - current_pos
|
||||||
start_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
start_idx = (
|
||||||
|
llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||||
|
)
|
||||||
|
|
||||||
# Text position ids
|
# text position ids
|
||||||
text_pos_ids = torch.arange(text_length, device=d)
|
text_pos_ids = torch.arange(text_length, device=d)
|
||||||
text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
|
text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
|
||||||
llm_pos_ids_list.append(text_pos_ids)
|
llm_pos_ids_list.append(text_pos_ids)
|
||||||
|
|
||||||
# Vision position ids
|
# vision position ids
|
||||||
t_indices = torch.arange(time_steps, device=d).repeat_interleave(
|
t_indices = torch.arange(time_steps, device=d).repeat_interleave(
|
||||||
height * width
|
height * width
|
||||||
)
|
)
|
||||||
|
@ -485,7 +487,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||||
.repeat_interleave(width)
|
.repeat_interleave(width)
|
||||||
.repeat(time_steps)
|
.repeat(time_steps)
|
||||||
)
|
)
|
||||||
w_indices = torch.arange(width, device=d).repeat(height * time_steps)
|
w_indices = torch.arange(width, device=d).repeat(
|
||||||
|
height * time_steps
|
||||||
|
)
|
||||||
|
|
||||||
vision_pos_ids = (
|
vision_pos_ids = (
|
||||||
torch.stack([t_indices, h_indices, w_indices])
|
torch.stack([t_indices, h_indices, w_indices])
|
||||||
|
@ -499,7 +503,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||||
|
|
||||||
# Handle remaining text if any
|
# Handle remaining text if any
|
||||||
if current_pos < batch_input_ids.size(1):
|
if current_pos < batch_input_ids.size(1):
|
||||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
st_idx = (
|
||||||
|
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
|
)
|
||||||
text_len = batch_input_ids.size(1) - current_pos
|
text_len = batch_input_ids.size(1) - current_pos
|
||||||
llm_pos_ids_list.append(
|
llm_pos_ids_list.append(
|
||||||
torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
|
torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
|
||||||
|
@ -528,6 +534,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
lm_head_indices: Optional[torch.Tensor],
|
lm_head_indices: Optional[torch.Tensor],
|
||||||
pixel_values: torch.FloatTensor = None,
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
video_pixel_values: torch.FloatTensor = None,
|
||||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
pixel_attention_mask=None,
|
pixel_attention_mask=None,
|
||||||
|
@ -538,15 +545,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||||
):
|
):
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if video_pixel_values is not None and len(video_pixel_values) > 0:
|
||||||
|
vision_embeds = self.visual(
|
||||||
|
video_pixel_values, grid_thw=video_grid_thw
|
||||||
|
).squeeze(0)
|
||||||
|
vision_token_mask = input_ids == self.video_token_id
|
||||||
|
inputs_embeds[vision_token_mask] = vision_embeds
|
||||||
|
|
||||||
# apply the visual model to the pixel values if they are provided
|
# apply the visual model to the pixel values if they are provided
|
||||||
if pixel_values is not None and len(pixel_values) > 0:
|
if pixel_values is not None and len(pixel_values) > 0:
|
||||||
vision_embeds = self.visual(
|
vision_embeds = self.visual(
|
||||||
pixel_values,
|
pixel_values,
|
||||||
grid_thw=torch.cat([image_grid_thw, video_grid_thw]) if video_grid_thw is not None else image_grid_thw
|
grid_thw=(
|
||||||
|
torch.cat([image_grid_thw, video_grid_thw])
|
||||||
|
if video_grid_thw is not None
|
||||||
|
else image_grid_thw
|
||||||
|
),
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
|
|
||||||
# Apply embeddings to both image and video tokens
|
# Apply embeddings to image tokens
|
||||||
vision_token_mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id)
|
vision_token_mask = input_ids == self.image_token_id
|
||||||
inputs_embeds[vision_token_mask] = vision_embeds
|
inputs_embeds[vision_token_mask] = vision_embeds
|
||||||
|
|
||||||
hidden_states = self.text_model(
|
hidden_states = self.text_model(
|
||||||
|
|
|
@ -18,6 +18,8 @@ from text_generation_server.utils.log import log_master
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
|
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
|
||||||
|
from torchvision import io
|
||||||
|
import math
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
@ -76,6 +78,20 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||||
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||||
|
|
||||||
|
|
||||||
|
def video_text_replacement(processor, video_input, config) -> str:
|
||||||
|
if config.model_type == "qwen2_vl":
|
||||||
|
# num_pads = video_input['pixel_values'].size(0)
|
||||||
|
# num_pads = 1206
|
||||||
|
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
# num_pads = 9556 + 10
|
||||||
|
num_pads = video_input.pixel_values.shape[0] // 4
|
||||||
|
padding = "<|video_pad|>" * num_pads
|
||||||
|
return f"<|vision_start|>{padding}<|vision_end|>"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||||
|
|
||||||
|
|
||||||
def image_text_replacement_fixup(config, text: str) -> str:
|
def image_text_replacement_fixup(config, text: str) -> str:
|
||||||
if config.model_type == "idefics2":
|
if config.model_type == "idefics2":
|
||||||
return text.replace(
|
return text.replace(
|
||||||
|
@ -138,29 +154,59 @@ def get_number_of_features(height: int, width: int, config) -> int:
|
||||||
return unpadded_features + newline_features + base_features
|
return unpadded_features + newline_features + base_features
|
||||||
|
|
||||||
|
|
||||||
|
# copied from: https://github.com/QwenLM/Qwen2-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||||
|
def smart_nframes(
|
||||||
|
fps: int,
|
||||||
|
nframes: int,
|
||||||
|
min_frames: int,
|
||||||
|
max_frames: int,
|
||||||
|
total_frames: int,
|
||||||
|
video_fps: int | float,
|
||||||
|
) -> int:
|
||||||
|
if nframes:
|
||||||
|
nframes = round(nframes / 2) * 2
|
||||||
|
else:
|
||||||
|
min_frames = math.ceil(min_frames / 2) * 2
|
||||||
|
max_frames = math.floor(max_frames / 2) * 2
|
||||||
|
nframes = total_frames / video_fps * fps
|
||||||
|
nframes = min(max(nframes, min_frames), max_frames)
|
||||||
|
nframes = round(nframes / 2) * 2
|
||||||
|
if not (2 <= nframes and nframes <= total_frames):
|
||||||
|
raise ValueError(
|
||||||
|
f"nframes should in interval [{2}, {total_frames}], but got {nframes}."
|
||||||
|
)
|
||||||
|
return nframes
|
||||||
|
|
||||||
|
|
||||||
class VlmCausalLMBatch(FlashCausalLMBatch):
|
class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
pixel_values: Optional[List[torch.Tensor]]
|
pixel_values: Optional[List[torch.Tensor]]
|
||||||
|
video_pixel_values: Optional[List[torch.Tensor]]
|
||||||
pixel_attention_mask: Optional[List[torch.Tensor]]
|
pixel_attention_mask: Optional[List[torch.Tensor]]
|
||||||
image_sizes: Optional[List[Tuple[int, int]]]
|
image_sizes: Optional[List[Tuple[int, int]]]
|
||||||
image_grid_thw: Optional[torch.Tensor]
|
image_grid_thw: Optional[torch.Tensor]
|
||||||
|
video_grid_thw: Optional[torch.Tensor]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@tracer.start_as_current_span("concatenate")
|
@tracer.start_as_current_span("concatenate")
|
||||||
def concatenate(cls, batches):
|
def concatenate(cls, batches):
|
||||||
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
|
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
|
||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
|
batch.video_pixel_values = None
|
||||||
batch.pixel_attention_mask = None
|
batch.pixel_attention_mask = None
|
||||||
batch.image_sizes = None
|
batch.image_sizes = None
|
||||||
batch.image_grid_thw = None
|
batch.image_grid_thw = None
|
||||||
|
batch.video_grid_thw = None
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@tracer.start_as_current_span("filter")
|
@tracer.start_as_current_span("filter")
|
||||||
def filter(self, request_ids: List[int]):
|
def filter(self, request_ids: List[int]):
|
||||||
batch = super().filter(request_ids)
|
batch = super().filter(request_ids)
|
||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
|
batch.video_pixel_values = None
|
||||||
batch.pixel_attention_mask = None
|
batch.pixel_attention_mask = None
|
||||||
batch.image_sizes = None
|
batch.image_sizes = None
|
||||||
batch.image_grid_thw = None
|
batch.image_grid_thw = None
|
||||||
|
batch.video_grid_thw = None
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -171,6 +217,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
# can make the image splits the same size. And we need the final
|
# can make the image splits the same size. And we need the final
|
||||||
# sizes to insert correct number of image tokens.
|
# sizes to insert correct number of image tokens.
|
||||||
images = []
|
images = []
|
||||||
|
videos = []
|
||||||
for r in requests:
|
for r in requests:
|
||||||
for chunk in r.input_chunks.chunks:
|
for chunk in r.input_chunks.chunks:
|
||||||
chunk_type = chunk.WhichOneof("chunk")
|
chunk_type = chunk.WhichOneof("chunk")
|
||||||
|
@ -192,7 +239,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
images.append([image])
|
images.append([image])
|
||||||
elif chunk_type == "video":
|
elif chunk_type == "video":
|
||||||
if config.model_type == "qwen2_vl":
|
if config.model_type == "qwen2_vl":
|
||||||
pass
|
videos.append(chunk.video)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||||
|
|
||||||
|
@ -201,6 +248,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
else:
|
else:
|
||||||
image_inputs = None
|
image_inputs = None
|
||||||
|
|
||||||
|
video_inputs = None
|
||||||
|
if videos:
|
||||||
|
try:
|
||||||
|
tensor_videos = []
|
||||||
|
video = videos[0]
|
||||||
|
video_buffer = BytesIO(video.data)
|
||||||
|
video, _audio, info = io.read_video(
|
||||||
|
video_buffer,
|
||||||
|
start_pts=0.0,
|
||||||
|
end_pts=None,
|
||||||
|
pts_unit="sec",
|
||||||
|
output_format="TCHW",
|
||||||
|
)
|
||||||
|
total_frames, video_fps = video.size(0), info["video_fps"]
|
||||||
|
nframes = smart_nframes(
|
||||||
|
fps=30,
|
||||||
|
nframes=None,
|
||||||
|
min_frames=16,
|
||||||
|
max_frames=64,
|
||||||
|
total_frames=total_frames,
|
||||||
|
video_fps=video_fps,
|
||||||
|
)
|
||||||
|
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
|
||||||
|
video = video[idx]
|
||||||
|
tensor_videos.append(video)
|
||||||
|
video_inputs = processor.image_processor(
|
||||||
|
tensor_videos, return_tensors="pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to process video: {e}")
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
video_inputs = None
|
||||||
|
|
||||||
batch_inputs = []
|
batch_inputs = []
|
||||||
max_truncation = 0
|
max_truncation = 0
|
||||||
image_id = 0
|
image_id = 0
|
||||||
|
@ -215,10 +297,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
processor, image_inputs, config, image_id
|
processor, image_inputs, config, image_id
|
||||||
)
|
)
|
||||||
image_id += 1
|
image_id += 1
|
||||||
elif chunk_type == "video" and config.model_type == "qwen2_vl":
|
elif chunk_type == "video":
|
||||||
from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
|
full_text += video_text_replacement(processor, video_inputs, config)
|
||||||
text, _ = process_qwen_video(chunk.video)
|
|
||||||
full_text += text
|
|
||||||
|
|
||||||
full_text = image_text_replacement_fixup(config, full_text)
|
full_text = image_text_replacement_fixup(config, full_text)
|
||||||
batch_inputs.append(full_text)
|
batch_inputs.append(full_text)
|
||||||
|
@ -231,7 +311,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
add_special_tokens=not config.model_type == "paligemma",
|
add_special_tokens=not config.model_type == "paligemma",
|
||||||
)["input_ids"]
|
)["input_ids"]
|
||||||
|
|
||||||
return batch_tokenized_inputs, image_inputs
|
return batch_tokenized_inputs, image_inputs, video_inputs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb_processor(
|
def from_pb_processor(
|
||||||
|
@ -243,10 +323,23 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "VlmCausalLMBatch":
|
) -> "VlmCausalLMBatch":
|
||||||
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
|
batch_tokenized_inputs, image_inputs, video_inputs = cls.batch_tokenized_inputs(
|
||||||
pb.requests, tokenizer, processor, config
|
pb.requests, tokenizer, processor, config
|
||||||
)
|
)
|
||||||
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||||
|
if video_inputs is not None:
|
||||||
|
if "pixel_values" in video_inputs:
|
||||||
|
batch.video_pixel_values = video_inputs["pixel_values"].to(
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
if "image_grid_thw" in video_inputs:
|
||||||
|
batch.video_grid_thw = video_inputs["image_grid_thw"].to(device=device)
|
||||||
|
else:
|
||||||
|
batch.video_grid_thw = None
|
||||||
|
else:
|
||||||
|
batch.video_pixel_values = None
|
||||||
|
batch.video_grid_thw = None
|
||||||
|
|
||||||
if image_inputs is not None:
|
if image_inputs is not None:
|
||||||
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
|
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
|
||||||
if "pixel_attention_mask" in image_inputs:
|
if "pixel_attention_mask" in image_inputs:
|
||||||
|
@ -263,6 +356,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
|
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
|
||||||
else:
|
else:
|
||||||
batch.image_grid_thw = None
|
batch.image_grid_thw = None
|
||||||
|
if "video_grid_thw" in image_inputs:
|
||||||
|
batch.video_grid_thw = image_inputs["video_grid_thw"].to(device=device)
|
||||||
|
else:
|
||||||
|
batch.video_grid_thw = None
|
||||||
else:
|
else:
|
||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
batch.pixel_attention_mask = None
|
batch.pixel_attention_mask = None
|
||||||
|
@ -270,6 +367,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
batch.image_grid_thw = None
|
batch.image_grid_thw = None
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
class VlmCausalLM(FlashCausalLM):
|
class VlmCausalLM(FlashCausalLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -372,7 +470,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||||
if self.model.config.model_type == "qwen2_vl":
|
if self.model.config.model_type == "qwen2_vl":
|
||||||
if position_ids.dim() == 1 and batch.prefilling:
|
if position_ids.dim() == 1 and batch.prefilling:
|
||||||
position_ids = self.model.get_position_ids(
|
position_ids = self.model.get_position_ids(
|
||||||
input_ids, batch.image_grid_thw
|
input_ids, batch.image_grid_thw, batch.video_grid_thw
|
||||||
)
|
)
|
||||||
batch.position_ids = position_ids
|
batch.position_ids = position_ids
|
||||||
|
|
||||||
|
@ -425,20 +523,26 @@ class VlmCausalLM(FlashCausalLM):
|
||||||
prefill_cache_indices=batch.prefill_cache_indices,
|
prefill_cache_indices=batch.prefill_cache_indices,
|
||||||
lm_head_indices=lm_head_indices,
|
lm_head_indices=lm_head_indices,
|
||||||
pixel_values=batch.pixel_values,
|
pixel_values=batch.pixel_values,
|
||||||
|
video_pixel_values=batch.video_pixel_values,
|
||||||
pixel_attention_mask=batch.pixel_attention_mask,
|
pixel_attention_mask=batch.pixel_attention_mask,
|
||||||
image_sizes=batch.image_sizes,
|
image_sizes=batch.image_sizes,
|
||||||
image_grid_thw=batch.image_grid_thw,
|
image_grid_thw=batch.image_grid_thw,
|
||||||
|
video_grid_thw=batch.video_grid_thw,
|
||||||
)
|
)
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
batch.prefill_cache_indices = None
|
batch.prefill_cache_indices = None
|
||||||
if batch.pixel_values is not None:
|
if batch.pixel_values is not None:
|
||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
|
if batch.video_pixel_values is not None:
|
||||||
|
batch.video_pixel_values = None
|
||||||
if batch.pixel_attention_mask is not None:
|
if batch.pixel_attention_mask is not None:
|
||||||
batch.pixel_attention_mask = None
|
batch.pixel_attention_mask = None
|
||||||
if batch.image_sizes is not None:
|
if batch.image_sizes is not None:
|
||||||
batch.image_sizes = None
|
batch.image_sizes = None
|
||||||
if batch.image_grid_thw is not None:
|
if batch.image_grid_thw is not None:
|
||||||
batch.image_grid_thw = None
|
batch.image_grid_thw = None
|
||||||
|
if batch.video_grid_thw is not None:
|
||||||
|
batch.video_grid_thw = None
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
||||||
# Copy inputs to the static inputs of the cuda graph
|
# Copy inputs to the static inputs of the cuda graph
|
||||||
|
|
Loading…
Reference in New Issue